
Simplify module_inject.transpose #8028

Open
xbcReal wants to merge 1 commit into deepspeedai:master from xbcReal:simplify-transpose-utils

xbcReal (Contributor) commented May 26, 2026

Motivation

The transpose() helper in deepspeed/module_inject/utils.py (consolidated from duplicate definitions in #7934) currently performs a temp-buffer + copy_() rewrite of the input tensor's storage:

def transpose(data):
    with torch.no_grad():
        data = data.contiguous()
        data1 = data.transpose(-1, -2).reshape(-1)
        data.reshape(-1).copy_(data1)
        data1 = None
    return data.reshape(data.shape[-1], data.shape[-2])

This PR replaces the body with a one-liner:

def transpose(data):
    return data.transpose(-1, -2).contiguous()
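
As a quick sanity check, the two bodies can be compared on a small 2-D tensor. This is a minimal sketch, not part of the PR: the names `transpose_old` and `transpose_new` are made up here so both versions can coexist in one script, and the reference result is cloned up front because the old body writes into its (contiguous) input.

```python
import torch


def transpose_old(data):
    # Previous body, copied from the PR description (hypothetical name).
    with torch.no_grad():
        data = data.contiguous()
        data1 = data.transpose(-1, -2).reshape(-1)
        data.reshape(-1).copy_(data1)
        data1 = None
    return data.reshape(data.shape[-1], data.shape[-2])


def transpose_new(data):
    # Proposed one-liner (hypothetical name).
    return data.transpose(-1, -2).contiguous()


torch.manual_seed(0)
x = torch.randn(4, 6)
expected = x.t().clone()  # reference taken before either helper runs

# Each helper gets its own clone so the old body's in-place write cannot
# affect the other call or the reference.
out_old = transpose_old(x.clone())
out_new = transpose_new(x.clone())

same_values = torch.equal(out_old, expected) and torch.equal(out_new, expected)
same_shape = out_old.shape == out_new.shape == (6, 4)
```

Both helpers return a contiguous tensor of shape `(6, 4)` with the same values for this input. The check covers only the 2-D case, which is what the old body's final `reshape` assumes.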

Two reasons

1. Subtle memory waste on non-contiguous input.
When the input is non-contiguous, data = data.contiguous() silently allocates new storage. The follow-up copy_() then writes into that fresh storage rather than into the caller's original storage, so the "in-place" intent is already lost. Worse, the temporary data1 buffer raises peak memory to 3N (the input, the contiguous copy, and data1, where N is the size of the tensor's data) instead of the 2N (the input plus one copy) that .contiguous() alone would incur.

load_checkpoint.py:132 is a concrete case: weight_partition can come from torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank] and is non-contiguous when dim == 1.
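
The contiguity claim can be reproduced in isolation. The sketch below does not use the DeepSpeed code path; `tmp_data` is a small stand-in tensor and the split sizes are arbitrary. It assumes only that `torch.split` returns views of the source tensor.

```python
import torch

# Stand-in for a checkpoint weight; shape and values are arbitrary.
tmp_data = torch.arange(24.0).reshape(4, 6)

rows = torch.split(tmp_data, 2, dim=0)[1]  # partition along dim 0
cols = torch.split(tmp_data, 3, dim=1)[1]  # partition along dim 1

rows_contiguous = rows.is_contiguous()
cols_contiguous = cols.is_contiguous()

# The dim-1 partition is a view into tmp_data's storage, starting three
# elements past the beginning of the first row.
offset_bytes = 3 * tmp_data.element_size()
cols_is_view = cols.data_ptr() == tmp_data.data_ptr() + offset_bytes

# Calling .contiguous() on it therefore has to allocate a fresh buffer.
contiguous_allocates = cols.contiguous().data_ptr() != cols.data_ptr()
```

A row-wise partition keeps the row-major strides and stays contiguous. A column-wise partition keeps the parent's row stride of 6 while its own rows are only 3 elements wide, so it is not.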

2. The in-place rewrite is not relied on by any caller.
A full audit of every call site shows all callers use the return value and none depend on the side effect of mutating the input's storage:

| # | Call site | Input source | Return used | Relies on side-effect |
|---|---|---|---|---|
| 1 | policy.py:154 | tmp = sd[src_name] | yes | no |
| 2 | policy.py:157 | tmp = sd[src_name] | yes | no |
| 3 | policy.py:159 | result of _transpose(...).contiguous() | yes | no |
| 4 | policy.py:161 | local tmp | yes | no |
| 5 | policy.py:181 | torch.cat((q,k,v), dim=0) | yes | no |
| 6 | policy.py:184 | same qkv_data | yes | no |
| 7 | policy.py:198 | torch.cat((reg, gate), dim=0) | yes | no |
| 8 | load_checkpoint.py:83 | sd[0][prefix+n].to(device) | yes | no |
| 9 | load_checkpoint.py:132 | torch.split(...) / torch.cat(...) | yes | no |
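
For reference, the side effect in question looks like this on a contiguous input. This is an illustrative sketch only: `transpose_old` and `transpose_new` are hypothetical names for the two bodies quoted in this description.

```python
import torch


def transpose_old(data):
    # Previous body, copied from the PR description (hypothetical name).
    with torch.no_grad():
        data = data.contiguous()
        data1 = data.transpose(-1, -2).reshape(-1)
        data.reshape(-1).copy_(data1)
        data1 = None
    return data.reshape(data.shape[-1], data.shape[-2])


def transpose_new(data):
    # Proposed one-liner (hypothetical name).
    return data.transpose(-1, -2).contiguous()


x = torch.arange(6.0).reshape(2, 3)
before = x.clone()

y_old = transpose_old(x)
# The old body overwrites x's storage with the transposed layout, so x,
# still viewed as (2, 3), no longer holds its original values.
old_mutated_input = not torch.equal(x, before)
old_aliases_input = y_old.data_ptr() == x.data_ptr()

z = before.clone()
y_new = transpose_new(z)
# The new body writes only into a freshly allocated result.
new_left_input_intact = torch.equal(z, before)
new_aliases_input = y_new.data_ptr() == z.data_ptr()
```

A caller that kept using its input tensor after the call would see scrambled values under the old body. The audit above finds no such caller.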

Behavior

| Input | Old peak | New peak | Return semantics |
|---|---|---|---|
| Contiguous | 2N | 2N | unchanged for callers |
| Non-contiguous | 3N | 2N | unchanged for callers |
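
The non-contiguous row can be illustrated by stepping through the old body by hand and counting the distinct buffers that are alive at the same time. This is a rough sketch under stated assumptions: it counts buffers by `data_ptr()` rather than measuring bytes, and the variable names are illustrative. On a CUDA device, `torch.cuda.max_memory_allocated()` would be one way to measure the actual peak.

```python
import torch

base = torch.arange(24.0).reshape(4, 6)
x = base[:, :3]  # non-contiguous input: a column slice of base

# Old body, unrolled. All three tensors stay alive until the end of the call.
c = x.contiguous()                   # fresh buffer: contiguous copy of x
t = c.transpose(-1, -2).reshape(-1)  # fresh buffer: flattened transpose (data1)
old_live_buffers = {x.data_ptr(), c.data_ptr(), t.data_ptr()}

# New body: a single fresh buffer for the result.
r = x.transpose(-1, -2).contiguous()
new_live_buffers = {x.data_ptr(), r.data_ptr()}

old_count = len(old_live_buffers)
new_count = len(new_live_buffers)
```

Here the old path holds three buffers at its peak (input, contiguous copy, temporary) and the new path holds two (input, result), which matches the 3N versus 2N entries in the table.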

Changes

  • deepspeed/module_inject/utils.py: rewrite transpose() body; remove now-unused import torch.

No call site changes.

Tests

pre-commit run --files deepspeed/module_inject/utils.py passes locally (yapf, flake8, check-torchdist, license header, codespell).

The in-place rewrite pattern (allocate temp, copy_ back into the input's
storage) provides no observable benefit at any current call site. All nine
callers use the return value, and none rely on the side effect of mutating
the input tensor's storage.

The current implementation is also subtly worse for non-contiguous inputs.
The leading `data = data.contiguous()` silently allocates a new storage,
which already breaks the "in-place" intent of the subsequent `copy_()`.
The temporary `data1` buffer then pushes peak memory to 3N instead of 2N.
In particular, `load_checkpoint.py:132` can receive a
`torch.split(tmp_data, ..., dim=dim)[rank]` result that is non-contiguous
when `dim == 1`, which triggers exactly this case.

Replacing the body with `data.transpose(-1, -2).contiguous()`:
  - keeps the 2N peak on contiguous input,
  - drops 3N to 2N on non-contiguous input,
  - preserves return-value semantics for every current caller,
  - drops the now-unused `import torch`.

Signed-off-by: binchengxiong <binchengxiong@alibaba-inc.com>