Simplify module_inject.transpose#8028
Open
xbcReal wants to merge 1 commit into
Open
Conversation
The in-place rewrite pattern (allocate temp, copy_ back into the input's storage) provides no observable benefit at any current call site. All nine callers use the return value, and none rely on the side effect of mutating the input tensor's storage. The current implementation is also subtly worse for non-contiguous inputs. The leading `data = data.contiguous()` silently allocates a new storage, which already breaks the "in-place" intent of the subsequent `copy_()`. The temporary `data1` buffer then pushes peak memory to 3N instead of 2N. In particular, `load_checkpoint.py:132` can receive a `torch.split(tmp_data, ..., dim=dim)[rank]` result that is non-contiguous when `dim == 1`, which triggers exactly this case. Replacing the body with `data.transpose(-1, -2).contiguous()`: - keeps the 2N peak on contiguous input, - drops 3N to 2N on non-contiguous input, - preserves return-value semantics for every current caller, - drops the now-unused `import torch`. Signed-off-by: binchengxiong <binchengxiong@alibaba-inc.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
The
transpose()helper indeepspeed/module_inject/utils.py(consolidated from duplicate definitions in #7934) currently performs a temp-buffer +copy_()rewrite of the input tensor's storage:This PR replaces the body with a one-liner:
Two reasons
1. Subtle memory waste on non-contiguous input.
When the input is non-contiguous,
data = data.contiguous()silently allocates a new storage. The follow-upcopy_()then writes into that fresh storage, not into the caller's original storage — so the "in-place" intent is already lost. Worse, the temporarydata1buffer pushes peak memory to 3N instead of the 2N that.contiguous()alone would incur.load_checkpoint.py:132is a concrete case:weight_partitioncan come fromtorch.split(tmp_data, dst_shape[dim1], dim=dim)[rank]and is non-contiguous whendim == 1.2. The in-place rewrite is not relied on by any caller.
A full audit of every call site shows all callers use the return value and none depend on the side effect of mutating the input's storage:
policy.py:154tmp = sd[src_name]policy.py:157tmp = sd[src_name]policy.py:159_transpose(...).contiguous()policy.py:161tmppolicy.py:181torch.cat((q,k,v), dim=0)policy.py:184qkv_datapolicy.py:198torch.cat((reg, gate), dim=0)load_checkpoint.py:83sd[0][prefix+n].to(device)load_checkpoint.py:132torch.split(...)/torch.cat(...)Behavior
Changes
deepspeed/module_inject/utils.py: rewritetranspose()body; remove now-unusedimport torch.No call site changes.
Tests
pre-commit run --files deepspeed/module_inject/utils.pypasses locally (yapf, flake8, check-torchdist, license header, codespell).