[PyTorch][CP] Reduce P2P forward peak memory: O(C) → O(1) #2916

Draft
sudhakarsingh27 wants to merge 3 commits into NVIDIA:main from sudhakarsingh27:sudhakars/p2p_mem_opt_pr

Conversation

@sudhakarsingh27
Collaborator

Summary

Reduce P2P (ring attention) forward-pass peak memory from O(cp_size) to O(1) (constant in cp_size) via two optimizations:

  • Opt1 — KV comm buffer double-buffering: p2p_comm_buffers held cp_size entries on main, but only two are ever live at once (the buffer being computed on and the one being received into), so 2 slots suffice.
  • Opt2 — Incremental output correction: Replace the post-loop output merge (which required holding cp_size out_per_step tensors) with online softmax merge during the main loop. Each step's partial output is immediately merged into a running accumulator using exp(old_lse − new_lse) rescaling.

Together: p2p_comm_buffers and out_per_step are both reduced from cp_size slots to 2 slots — constant regardless of CP size.
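The two optimizations combine into a loop of roughly the following shape. This is a minimal sketch, not the actual transformer_engine code: `ring_attn_forward_sketch` and `attn_step` are hypothetical names, and a plain list of KV chunks stands in for the blocks that arrive over the ring via async send/recv.

```python
import torch

def ring_attn_forward_sketch(q, kv_chunks, attn_step):
    """Sketch of Opt1 + Opt2 (illustrative, not the TE implementation).

    kv_chunks: list of cp_size KV blocks, standing in for ring P2P traffic.
    attn_step(q, kv) -> (partial_out, partial_lse): one step of attention
    over a single KV block, returning the output and its log-sum-exp.
    """
    cp_size = len(kv_chunks)
    # Opt1: only 2 comm buffers, rotated with i % 2, instead of cp_size.
    bufs = [kv_chunks[0].clone(), torch.empty_like(kv_chunks[0])]
    out = lse = None
    for i in range(cp_size):
        if i + 1 < cp_size:
            # Stands in for the async recv into the "other" buffer.
            bufs[(i + 1) % 2].copy_(kv_chunks[i + 1])
        out_step, lse_step = attn_step(q, bufs[i % 2])
        if out is None:
            out, lse = out_step, lse_step
        else:
            # Opt2: online softmax merge -- rescale the running accumulator
            # with exp(old_lse - new_lse) instead of storing all partials.
            new_lse = torch.logaddexp(lse, lse_step)
            out = (out * torch.exp(lse - new_lse).unsqueeze(-1)
                   + out_step * torch.exp(lse_step - new_lse).unsqueeze(-1))
            lse = new_lse
    return out, lse
```

Merging the partial softmax outputs this way is exact, not approximate: the pairwise log-sum-exp rescaling reproduces the softmax over the full KV sequence.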

Commits

  1. [PyTorch][CP] Double-buffer P2P KV comm buffers in forward pass — Opt1
  2. [PyTorch][CP] Incremental output correction in P2P forward pass — Opt2
  3. [PyTorch][CP] Fix THD dtype mismatch in incremental output correction — bugfix for Opt2 (thd_out_correction kernel requires matching dtype)

Peak memory impact (theoretical)

In units of B × S_r × H × D × elem, i.e. one per-rank Q/K/V tensor (S_r = S / cp_size, elem = bytes per element):

Method                         main (baseline)   This PR (Opt1+2)   Savings
P2P forward peak (thd)         3C + 2 units      ~8 units           O(C) term scales away at large C
P2P forward peak (bshd/sbhd)   3C + 3 units      ~9 units           O(C) term scales away at large C

Concrete (B=2, S=262K, H=16, D=128, bf16): P2P forward peak drops from ~6.3 GiB to ~1.1 GiB at CP=64 (~6× reduction).

At CP=8 on main, P2P baseline is 26 units — worse than AllGather (21 units). With this PR, P2P becomes the lowest-memory option again.

Validation against measured data

We decomposed measured peak memory from a multi-node GB200 benchmark (B=2, S=262144, H=16, D=128, bf16, THD) into four additive terms that account for every byte:

CP   Globals†   Locals   P2P algo (3C+4 units)   cuDNN workspace   Total    Measured
 4   8,192      2,048    8,192                   5,680             24,112   24,112
 8   8,192      1,024    7,168                   2,856             19,240   19,240
16   8,192        512    6,656                   1,444             16,804   16,804
32   8,192        256    6,400                     738             15,586   15,586
64   8,192        128    6,272                     385             14,977   14,977

All values in MiB. †Globals are an artifact of the benchmark script not deleting full-sequence test tensors before reset_peak_memory_stats; they do not reflect real-world usage.

cuDNN flash-attention workspace is ~11 units for P2P, ~14 units for A2A — constant across CP sizes because it scales with S_r × H (same as one unit).
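The "P2P algo" column follows directly from the 3C+4 formula and the unit definition above (S_r = S/CP, 2 bytes per bf16 element); a quick arithmetic check against the table:

```python
# Reproduce the "P2P algo (3C+4 units)" column of the decomposition table.
# One unit = B x (S/CP) x H x D x 2 bytes (bf16), per the unit definition.
B, S, H, D, BYTES = 2, 262144, 16, 128, 2

def p2p_algo_mib(cp: int) -> float:
    unit_mib = B * (S // cp) * H * D * BYTES / 2**20
    return (3 * cp + 4) * unit_mib

# Values from the table above (MiB).
table = {4: 8192, 8: 7168, 16: 6656, 32: 6400, 64: 6272}
for cp, mib in table.items():
    assert p2p_algo_mib(cp) == mib
```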

Changes

  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py:
    • AttnFuncWithCPAndKVP2P.forward: reduce p2p_comm_buffers to 2 slots (i % 2 indexing), add old_softmax_lse clone before LSE correction, replace post-loop output correction with per-step online softmax merge
    • New @jit_fuser helpers: flash_attn_fwd_out_correction_init_incremental (bshd/sbhd), incremental correction variants for causal half-split
    • THD path: Python mul_ rescale of running out + reuse of thd_out_correction kernel
    • Dtype fix: out accumulator clones the first out_per_step (bf16) instead of being cast to float32, matching the kernel's dtype contract
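What such an incremental-correction helper computes can be sketched in plain torch. The function name, shapes, and layout handling here are assumptions for illustration; the real `@jit_fuser` helpers live in context_parallel.py.

```python
import torch

def out_correction_incremental(out, out_step, lse, lse_step):
    """Merge one step's partial attention output into the running accumulator.

    Assumed shapes (bshd layout):
      out, out_step: [b, s, h, d] partial attention outputs
      lse, lse_step: [b, h, s]    log-sum-exp of the attention scores
    """
    new_lse = torch.logaddexp(lse, lse_step)
    # exp(old_lse - new_lse) rescales the running accumulator; the new step
    # contributes with weight exp(step_lse - new_lse).
    scale_old = torch.exp(lse - new_lse).movedim(1, 2).unsqueeze(-1)  # [b, s, h, 1]
    scale_new = torch.exp(lse_step - new_lse).movedim(1, 2).unsqueeze(-1)
    out.mul_(scale_old).add_(out_step * scale_new)
    return out, new_lse
```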

Test plan

  • Unit tests pass: test_cp_with_fused_attention and test_cp_with_flash_attention across p2p/a2a/ag × bshd/sbhd/thd × causal/non-causal
  • Measured peak memory benchmark on GB200 (pending)
  • End-to-end training run validation (pending)

Status

Draft — opening for early feedback on the approach. Will mark ready for review after measured benchmarks.

Commit messages

The forward pass allocated cp_size KV communication buffers but only
ever needed 2 live at any time (current compute + next recv). This
mirrors the backward pass which already uses 2-entry double-buffering.

Convert p2p_comm_buffers from a cp_size-length list to a 2-entry list
with i%2 rotation. Saves (cp_size-3) buffer copies at peak — measured
2.6 GB at cp=8 (S=262k, B=2, H=16, D=128) with zero perf regression.

Also add bariamis benchmark configs (H=16, S=4k/8k) to flash_attn
test suite for correctness coverage at the exact config used in CP
communication benchmarking.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Replace post-loop output correction with online softmax merge during the
main loop. Each step's partial output is immediately merged into a running
accumulator using exp(old_lse - new_lse) rescaling, eliminating the need
to store all cp_size out_per_step and softmax_lse_per_step tensors.

Double-buffers out_per_step, softmax_lse_per_step, and max_logit_per_step
(2 slots each). rng_states and attn_biases remain at cp_size for backward.

All three QKV formats supported:
- bshd/sbhd: new @jit_fuser incremental correction helpers
- THD packed LSE: Python mul_ rescale + existing thd_out_correction kernel
- THD unpacked LSE (legacy): clone+zero+reconstruct fallback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The thd_out_correction CUDA kernel requires out and out_per_step to share
the same dtype. Using .to(torch.float32) for THD init broke this contract
since out_per_step stays in bf16/fp16. Use .clone() instead — the kernel
handles float promotion internally per-element.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
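A few illustrative lines (tensor shapes made up) show the contract the fix restores:

```python
import torch

# Illustrative only: the running accumulator must share out_per_step's dtype
# for the thd_out_correction kernel; the kernel promotes to float internally
# per element, so no float32 accumulator is needed on the Python side.
out_per_step = torch.randn(4, 2, 8).to(torch.bfloat16)
out_wrong = out_per_step.to(torch.float32)  # breaks the kernel's dtype contract
out_fixed = out_per_step.clone()            # bf16 accumulator, contract holds
```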
