
feat(rlix-task2): selective sync sender/receiver — CUDA IPC, NCCL broadcast, bucket cache (F4, F6)#3

Open
zhenyulincs wants to merge 10 commits into rlops:main from zhenyulincs:main

Conversation


zhenyulincs (Collaborator) commented Apr 25, 2026

What does this PR do?

Adds the NeMo RL side of Task 2: the training worker sender logic and the vLLM inference worker receiver logic required for selective weight sync.

On the sender side (megatron_policy_worker.py), build_latest_bucket_cache() gathers model weights from all TP/PP/CP/EP ranks into a CPU BucketRecord cache (only the cache owner stores). selective_sync_active_cache() reads the
cache under _cache_lock and transports each bucket to target inference workers — via CUDA IPC handle for same-GPU targets or NCCL broadcast for cross-GPU.
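
For orientation, here is a minimal sketch of the bucket-packing idea on the sender side. This is not the megatron_policy_worker.py code; the function name, the BUCKET_SIZE_BYTES constant, and the return shape are assumptions for illustration (in the PR the size comes from RLIX_BUCKET_SIZE_BYTES).

```python
import torch

BUCKET_SIZE_BYTES = 256 * 1024 * 1024  # assumed; the PR requires RLIX_BUCKET_SIZE_BYTES to be set

def pack_into_buckets(named_params: dict[str, torch.Tensor]):
    """Greedy packing of parameters into fixed-size CPU buckets (sketch only)."""
    buckets, current, used = [], [], 0
    for name, tensor in named_params.items():
        nbytes = tensor.numel() * tensor.element_size()
        if nbytes > BUCKET_SIZE_BYTES:
            # Fail fast: a single param that can never fit in one bucket is a config error.
            raise RuntimeError(f"{name} ({nbytes} B) exceeds bucket_size_bytes")
        if used + nbytes > BUCKET_SIZE_BYTES and current:
            buckets.append(current)
            current, used = [], 0
        current.append((name, tensor.detach().cpu()))  # cache owner stores on CPU
        used += nbytes
    if current:
        buckets.append(current)
    return buckets
```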

On the receiver side (vllm_backend.py), update_parameter_in_bucket() supports both transport modes: CUDA IPC zero-copy reconstruction via rebuild_cuda_tensor(), or cpu_serialize DMA via pin_memory().to(device). All six required
receiver methods are implemented and exposed through vllm_generation.py with correct Ray actor phase barriers.
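
A rough sketch of the receiver-side transport branching follows. It is not the vllm_backend.py implementation; the payload layout and helper names are assumptions.

```python
import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor

def receive_flat_bucket(payload, transport: str, device: torch.device) -> torch.Tensor:
    """Return the flat bucket tensor on the inference GPU (sketch only)."""
    if transport == "cuda_ipc":
        # Zero-copy: reconstruct a view of the sender's GPU allocation from the
        # IPC handle tuple; no CPU roundtrip.
        return rebuild_cuda_tensor(*payload)
    elif transport == "cpu_serialize":
        # Pinned host buffer, then async DMA copy to the inference GPU.
        return payload.pin_memory().to(device, non_blocking=True)
    raise ValueError(f"unknown transport: {transport}")
```

The flat tensor is then sliced back into per-parameter tensors using the bucket metadata (param_names, dtypes, shapes).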

Issues

  • Implements the 6 receiver API methods required by spec lines 613–649: setup_collective_group, update_parameter_in_bucket, broadcast_parameter, destroy_collective_group, verify_model, finalize_weight_update
  • Fixes CUDA IPC receiver: was copying GPU→CPU→GPU (roundtrip); now uses rebuild_cuda_tensor() directly (zero-copy)
  • Fixes receiver rank mask: was using dist.get_rank() instead of self.rank
  • Fixes _cache_lock scope: now spans cache lookup → all bucket transport → sender NCCL teardown (spec lines 401–402)
  • Fixes oversized tensor: a single param larger than bucket_size_bytes now raises RuntimeError before appending
  • Fixes vllm_generation.py pass-throughs: all 6 methods now call ray.get(futures) before returning, ensuring outer barriers are correct (see the sketch after this list)
  • Wires trajectory collector: grpo.py registers AsyncTrajectoryCollector as named Ray actor rlix:trajectory_collector:{pipeline_id} so the pipeline can publish set_weight_version at the right time
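
As referenced above, a minimal sketch of the pass-through barrier pattern (names assumed; not the actual vllm_generation.py code):

```python
import ray

def update_parameter_in_bucket(self, *args, **kwargs):
    futures = [
        worker.update_parameter_in_bucket.remote(*args, **kwargs)
        for worker in self.workers
    ]
    # Block until every sub-worker has finished, so the caller's phase barrier
    # means "all receivers are done", not merely "all calls were dispatched".
    return ray.get(futures)
```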

Usage

Sender — called by pipeline after train step
worker.build_latest_bucket_cache.remote(checkpoint_version=step)
worker.promote_active_checkpoint.remote(version=step)

Sync — called by ModelUpdateService
worker.selective_sync_active_cache.remote(
    sync_id=..., comm_plan=..., tgt_dp_ranks=...,
    tgt_workers=..., tgt_device_mapping=...,
    tgt_num_gpus_per_worker=...,
    model_update_transport="cpu_serialize",  # or "cuda_ipc"
)

Set bucket size (required, no default)
export RLIX_BUCKET_SIZE_BYTES=$((256 * 1024 * 1024))

Run CUDA IPC cross-process test
python tests/integration/test_gate2_5_cuda_ipc.py

Run selective sync NCCL test (4 GPU)
NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1 \
torchrun --nproc-per-node=4 tests/integration/test_gate2_5_selective_sync.py

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

Modified files:

  • nemo_rl/models/policy/workers/megatron_policy_worker.py — sender: build_latest_bucket_cache, promote_active_checkpoint, selective_sync_active_cache, setup_collective_group, destroy_collective_group
  • nemo_rl/models/generation/vllm/vllm_backend.py — receiver: all 6 methods with CUDA IPC + cpu_serialize branching
  • nemo_rl/models/generation/vllm/vllm_generation.py — Ray actor pass-throughs with ray.get(futures) barriers
  • nemo_rl/algorithms/grpo.py — named trajectory collector registration

Depends on rlix/pipeline/bucket_cache.py from the rlix PR above (imported by vllm_backend.py via from rlix.pipeline.bucket_cache import BucketRecord, unpack_bucket_record)

anwithk and others added 6 commits April 12, 2026 22:48
Signed-off-by: anwithk <anwithk@nvidia.com>
Co-authored-by: Terry Kong <terryk@nvidia.com>
Thin adapter methods for rlix Task 2 (CPU bucket cache integration):

- get_cpu_weight_shards(): exports model weights to CPU via existing
  megatron_bridge.export_hf_weights PP collective gather. Only cache
  owner (pp0/dp0/tp0) returns data; other ranks return empty dict
  (see the sketch after this commit message).
- promote_active_checkpoint(version): stores _rlix_active_checkpoint_version
  as a version marker; rlix BucketCacheLifecycle manages the lifecycle.

All scheduling logic, cache storage, dirty tracking, and transport
remain in rlix. This file only exposes the data.
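
A sketch of the adapter's shape, for reference. Only the export_hf_weights name and the rank gating come from the description above; the call signature and the is_cache_owner flag are assumptions.

```python
def get_cpu_weight_shards(bridge, is_cache_owner: bool):
    # Every rank participates in the existing PP collective gather
    # (export_hf_weights call shape assumed here).
    hf_weights = bridge.export_hf_weights()
    # Only the cache owner (pp0/dp0/tp0) keeps the result; others return {}.
    if not is_cache_owner:
        return {}
    return {name: t.detach().cpu() for name, t in hf_weights.items()}
```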
Replace buf[offset:offset+1].view(dtype).element_size() with
torch.empty(0, dtype=dtype).element_size() to match the canonical
unpack logic in bucket_cache.unpack_bucket_record(). The slice-view
approach is the bug documented in TASK2_IMPLEMENTATION.md:83-88.
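
Minimal illustration of the difference (variable names assumed):

```python
import torch

buf = torch.zeros(1024, dtype=torch.uint8)  # packed bucket buffer
dtype = torch.bfloat16

# Buggy form: a 1-byte uint8 slice cannot be view()ed as a 2-byte dtype,
# so this raises RuntimeError for any dtype wider than one byte:
# elem_size = buf[offset:offset + 1].view(dtype).element_size()

# Canonical form, matching bucket_cache.unpack_bucket_record():
elem_size = torch.empty(0, dtype=dtype).element_size()  # 2 for bfloat16
```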
…x review fixes

Feature 4 — CPU bucket cache:
- build_latest_bucket_cache: fail-fast when single tensor > bucket_size_bytes
- bucket_size_bytes explicit config — RuntimeError if not set (no 256 MB default)
- Host-RAM check in build_latest_bucket_cache using actual packed model size
- promote_active_checkpoint: accept version= keyword (not checkpoint_version=)

Feature 6 — Selective sync transport:
- selective_sync_active_cache: model_update_transport param; branches to
  cuda_ipc (CUDA IPC handle via get_handle_from_tensor) or cpu_serialize
- _cache_lock spans transport + sender-side NCCL teardown (spec lines 401-402); see the sketch after this commit message
- torch.cuda.synchronize() before destroy_collective_group
- Receiver NCCL teardown triggered by ModelUpdateService (Phase 4)
- update_parameter_in_bucket (vllm_backend.py):
  - rank mask uses self.rank not dist.get_rank()
  - cuda_ipc: zero-copy GPU reconstruction via rebuild_cuda_tensor (no CPU roundtrip)
  - cpu_serialize: pin_memory DMA for efficiency
- vllm_generation.py: all pass-through methods await sub-worker futures (phase barriers)
- is_lora: bool = False added to update_parameter_in_bucket + broadcast_parameter

grpo.py:
- trajectory_collector registered as named Ray actor (rlix:trajectory_collector:{id})
  so pipeline can resolve and call set_weight_version()
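
A sketch of the lock scope described in the Feature 6 notes. The attribute and helper names (_bucket_cache, _transport_bucket, _destroy_sender_group) are placeholders, not the real selective_sync_active_cache code.

```python
import torch

def selective_sync_active_cache(self, sync_id, **transport_kwargs):
    with self._cache_lock:
        # Cache lookup, all bucket transport, and sender-side NCCL teardown
        # all happen under the same lock (spec lines 401-402).
        buckets = self._bucket_cache[self._rlix_active_checkpoint_version]
        for bucket in buckets:
            self._transport_bucket(bucket, **transport_kwargs)
        torch.cuda.synchronize()             # flush pending transfers first
        self._destroy_sender_group(sync_id)  # then tear down the sender group
```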
@zhenyulincs zhenyulincs self-assigned this Apr 25, 2026
@zhenyulincs zhenyulincs changed the title feat(task2): CPU bucket cache + selective weight sync (F4, F6-transport, Gate 2.5) feat(rlix-task2): selective sync sender/receiver — CUDA IPC, NCCL broadcast, bucket cache (F4, F6) Apr 25, 2026
…y with pyproject.toml

63154570 → 15a851565 — setup.py was not updated when Megatron-LM submodule was bumped.
Fresh uv sync would fail for anyone cloning from scratch.
    group_name,
    bucket.param_names,
    bucket.dtypes,
    bucket.shapes,

This could deadlock. dist.broadcast() is an NCCL collective: once the sender calls it, the current Python thread blocks until every receiver has also called its matching dist.broadcast(). But the receivers' broadcast_parameter.remote() calls are only issued after the sender has already blocked, so the sender's thread is stuck, the .remote() calls are never submitted, and the receivers never enter the collective → deadlock.
Consider instead: dispatch .remote() to all receivers first, then call dist.broadcast(), then ray.get(recv_refs).
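
A sketch of the suggested ordering (names like receivers and bcast_args are placeholders, not the PR's code):

```python
import ray
import torch.distributed as dist

def broadcast_bucket(flat_tensor, receivers, bcast_args, src_rank, group):
    # 1. Dispatch the receiver-side collective calls first (non-blocking).
    recv_refs = [r.broadcast_parameter.remote(*bcast_args) for r in receivers]
    # 2. Only now enter the sender-side collective; receivers are already on
    #    their way into the matching dist.broadcast().
    dist.broadcast(flat_tensor, src=src_rank, group=group)
    # 3. Wait for every receiver to confirm completion.
    ray.get(recv_refs)
```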


local_rank = (
    torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
)

This should be changed: torch.distributed.get_rank() returns the global rank, while broadcast_local_ranks stores local worker ranks. As a result, the next step, "if local_rank not in broadcast_local_ranks: return", returns early and fails silently under TP>1 / multi-node.
Try this:
local_rank = getattr(self, "rank", None)
if local_rank is None:
    local_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

    comm_plan=comm_plan,
    mode=mode,
    timeout_s=timeout_s,
    run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],

In a multi-node setup this causes only rank 0 to run setup. For example, with TP=2, TP rank 1 never sets up the group, so it cannot broadcast later; the subsequent NCCL collective group requires every participating rank to join. Consider changing this to run_rank_0_only_axes=[].
Methods that likely need the same change:
setup_collective_group
update_parameter_in_bucket
broadcast_parameter
destroy_collective_group
verify_model
finalize_weight_update
