Skip to content

feat(zero): enable torch.func transforms on engine for ZeRO 0/1/2#8026

Open
roycho96 wants to merge 5 commits into
deepspeedai:masterfrom
roycho96:feat/zero-engine-torch-func
Open

feat(zero): enable torch.func transforms on engine for ZeRO 0/1/2#8026
roycho96 wants to merge 5 commits into
deepspeedai:masterfrom
roycho96:feat/zero-engine-torch-func

Conversation

@roycho96
Copy link
Copy Markdown
Contributor

Follow-up to #7916 and #8023.

Makes torch.func.grad / grad_and_value / jacrev and vmap(grad) work when called directly on a DeepSpeed engine for ZeRO 0/1/2.

API ZeRO-0 ZeRO-1 ZeRO-2 ZeRO-3
torch.func.grad(lambda x: engine(x))(x) not yet
torch.func.grad_and_value(lambda x: engine(x))(x) not yet
torch.func.jacrev(lambda x: engine(x))(x) not yet
torch.func.vmap(torch.func.grad(...))(x_batch) not yet
torch.func.vmap(lambda x: engine(x))(x_batch) already ✓ already ✓ already ✓ not yet
engine.backward(loss) (regression)

vmap alone runs only the forward graph so it never hit the broken backward hooks and already worked before this PR; included in the table for completeness.

Usage:

engine, _, _, _ = deepspeed.initialize(model=model, ...)

# input gradient
g = torch.func.grad(lambda xi: engine(xi))(x)

# gradient and value in one pass
g, v = torch.func.grad_and_value(lambda xi: engine(xi))(x)

# Jacobian of output w.r.t. input
J = torch.func.jacrev(lambda xi: engine(xi))(x)

# per-sample input gradients (batched)
per_sample_g = torch.func.vmap(torch.func.grad(lambda xi: engine(xi)))(x_batch)

ZeRO-3 hits a separate SIGSEGV from the same APIs and is tracked separately.

Test:
pytest tests/unit/v1/zero/test_zero_torch_func.py

roycho96 added 3 commits May 25, 2026 22:25
torch.func.grad / grad_and_value / jacrev invoke autograd through
torch.autograd.grad, which fires the engine's output-tensor hooks but
intentionally bypasses engine.backward(). The prologue then raises on
ZeRO-0 (the safety net for direct loss.backward() callers) and the
epilogue indexes empty ZeRO-1/2 grad bucket bookkeeping that the
transformed graph never populated. Parameters are not leaves under
the transform, so per-param post-accumulate-grad hooks never fire.

Detect the active functorch interpreter via
torch._C._functorch.peek_interpreter_stack and short-circuit both
hooks early. The existing safety net for non-functorch direct
loss.backward() callers (deepspeedai#7665) is preserved.

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
…0/1/2

Compare each transform's output to a non-DeepSpeed baseline cloned
from the same initialization so a future regression that silently
zeros gradients fails the test. Includes a negative case that locks
in the ZeRO-0 direct-loss.backward() safety net.

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4a4bd2ad5a

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread deepspeed/runtime/engine.py
roycho96 added 2 commits May 26, 2026 00:43
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant