[Draft] Add On-Policy Distillation (OPSD) Trainer in DeepSpeed#8027
[Draft] Add On-Policy Distillation (OPSD) Trainer in DeepSpeed#8027PKUWZP wants to merge 4 commits into
Conversation
First slice of the on-policy distillation example app under examples/opsd/. This commit lands the framework-agnostic foundation: the OPSDConfig dataclass hierarchy, chunked / streamed forward-KL / reverse-KL / JSD losses with sequence-axis chunking to bound peak memory, response-mask + shift helpers, and a 24-case CPU-only test suite covering identity, masking, chunk equivalence, gradient flow, and numerical edge cases. Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Adds the two-phase teacher path:
* TeacherWrapper loads a HuggingFace causal LM, freezes it, and runs
forward-only. Two modes: load + pin on GPU (offload_to_cpu=false), or
wrap with deepspeed.initialize using a ZeRO-3 + offload_param=cpu
config (offload_to_cpu=true). Avoids deepspeed.zero.Init() around
from_pretrained because HF's loader partitions params to zero-width
shards before the checkpoint can fill them.
* TeacherLogitCache stages the [B, T, V] teacher logits to (pinned) host
memory in bf16, and exposes chunk_to_device() so the student-side loss
can pull sequence slices back on demand. This is the memory-economising
half of the two-phase update.
CPU-only tests cover the cache shape / dtype / round-trip / chunk-bounds
behaviour and verify the streamed-via-cache loss matches the direct
chunked loss bit-for-bit.
Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Lands the fully-runnable hybrid-engine training path: a backend-agnostic RolloutEngine ABC with RolloutRequest / RolloutBatch / SamplingConfig dataclasses, a HybridEngineRollout implementation that uses DeepSpeed's accelerated decode when an inference policy exists and otherwise falls back to GatheredParameters + the raw HF generate (covers Qwen-family and other models not in DeepSpeed's inference container list), a left-padded prompt dataset + collator, a three-phase trainer loop (rollout -> teacher forward + cache -> student forward + streamed KL + backward + step), the argparse + deepspeed.initialize entry point, base DeepSpeed ZeRO-3 + hybrid_engine JSON configs, a 5-step smoke config and launcher script, and a 20-prompt math toy dataset for the smoke run. Smoke-validated end-to-end on 2x H200 with Qwen2.5-0.5B-Instruct student and Qwen2.5-1.5B-Instruct teacher; loss finite for 5 steps. Rollout interface contract is covered by tests/test_rollout_interface.py. Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
Lands the second-stage rollout path, weight-sync infrastructure, and the
example app's README. Includes:
* VLLMRollout that constructs vllm.LLM on training rank 0 and broadcasts
generated token ids to peer ranks, with disjoint-GPU (subprocess) and
shared (in-process) topology paths. Weight sync gathers ZeRO-3 params
cooperatively then pushes to vLLM via LLM.collective_rpc("load_weights").
* WeightBridge ABC with COLUMN / ROW / VOCAB / REPLICATED parallel kinds
and an even-slice per-rank slicer; Qwen2WeightBridge with the full
per-parameter table for Qwen2 / Qwen2.5; Qwen3WeightBridge adding the
per-head q_norm / k_norm tensors as REPLICATED.
* vLLM-side prompt+response stitching factored into stitch_rollout() so
its index math is unit-testable without a live vLLM.
* CPU-only tests: tests/test_weight_bridge.py covers parallel-kind
dispatch, per-rank shape/gather round-trips across tp_size in {1,2,4},
indivisibility / invalid-rank guards, and the registry;
tests/test_vllm_stitch.py covers prompt/response stitching for the
common shapes including variable response lengths and left-padded
prompts.
* configs + launch scripts for both production and smoke vLLM runs.
**Known blocker called out in README and module docstring:** vLLM's worker
init calls new_group() on the global process group, which deadlocks when
launched under the standard `deepspeed --num_gpus N` launcher (rank 0
calls vLLM, other ranks never participate in vLLM's collective). The
documented fix is the TRL/OpenRLHF separate-server pattern; this PR lands
the scaffolding so that work can begin against a green codebase.
Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6384396b48
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| NUM_TRAIN_GPUS="${NUM_TRAIN_GPUS:-6}" | ||
| INCLUDE_GPUS="${INCLUDE_GPUS:-0,1,2,3,4,5}" | ||
|
|
||
| deepspeed --num_gpus "${NUM_TRAIN_GPUS}" --include "localhost:${INCLUDE_GPUS}" \ |
There was a problem hiding this comment.
Remove conflicting DeepSpeed launch flags
This launch command passes --num_gpus together with --include, which DeepSpeed rejects (deepspeed/launcher/runner.py raises ValueError("Cannot specify num_nodes/gpus with include/exclude")). As written, the documented/default vLLM training script fails before main.py starts, so users cannot run the disjoint-GPU workflow at all.
Useful? React with 👍 / 👎.
| loader = DataLoader( | ||
| dataset, | ||
| batch_size=cfg.training.micro_batch_size_per_gpu, | ||
| shuffle=cfg.data.shuffle, | ||
| collate_fn=collator, |
There was a problem hiding this comment.
Shard prompts across data-parallel ranks
The dataloader is created without a DistributedSampler, so in multi-rank DeepSpeed runs every rank iterates the full dataset instead of a disjoint shard. That means prompts are repeatedly processed across ranks each epoch, reducing effective data diversity and distorting expected global-batch behavior for OPSD training.
Useful? React with 👍 / 👎.
| ds_config = _load_ds_config(cfg.deepspeed_config) | ||
| ds_config["train_micro_batch_size_per_gpu"] = cfg.training.micro_batch_size_per_gpu | ||
| ds_config["train_batch_size"] = cfg.training.train_batch_size | ||
| ds_config["gradient_accumulation_steps"] = cfg.training.gradient_accumulation_steps |
There was a problem hiding this comment.
Propagate training LR and warmup into DS config
Only batch-size fields are copied from OPSDConfig.training into the DeepSpeed config; optimizer/scheduler knobs like learning_rate, weight_decay, and warmup_steps are ignored. As a result, changing those values in opsd_*.json silently has no effect, and runs will keep using whatever is hardcoded in the referenced DeepSpeed JSON file.
Useful? React with 👍 / 👎.
Summary
Adds a DeepSpeed-native on-policy distillation trainer under
examples/opsd/.On-policy distillation: a small student generates rollouts, a frozen large teacher scores them, and the student is updated by a per-token divergence (forward-KL / reverse-KL / JSD) between the two distributions on the student's own samples. Each step has three phases — student rollout → teacher forward + CPU logit cache → student forward + streamed divergence + backward — so the full
[B, T, V]teacher tensor never co-resides with the student logits on the training device.Modules (~3.2k LOC across 32 files):
opsd/losses.py— chunked / streamed forward-KL, reverse-KL, JSD with sequence-axis chunkingopsd/teacher.py— frozen teacher wrapper +TeacherLogitCache(host-resident, chunk fetch)opsd/rollout/{base,hybrid_engine,vllm}.py— backend-agnosticRolloutEngineABC + two backendsopsd/weight_bridge/{base,qwen2,qwen3}.py— per-archParallelKind+ TP slicer for vLLM weight syncopsd/trainer.py+main.py— three-phase loop +deepspeed.initializeentry pointconfigs/+scripts/— production and smoke configs for both rollout backendstests/— 87 CPU-only tests covering loss math, cache, rollout interface, weight-bridge slicing, vLLM stitchValidated end-to-end on 2× H200 with Qwen2.5-0.5B-Instruct student + Qwen2.5-1.5B-Instruct teacher via the hybrid-engine path; loss finite for 5 steps. See README for the smoke recipe.
Known follow-up documented in README +
opsd/rollout/vllm.pydocstring: vLLM rollout under the standarddeepspeed --num_gpus Nlauncher hits atorch.distributed.new_groupdeadlock (vLLM's worker calls it on the global PG, but only rank 0 participates). The fix is the TRL/OpenRLHF separate-server pattern; this PR lands the scaffolding (rollout class, weight bridges, configs, unit tests for the testable pieces) so that work can begin against a green codebase.Test plan
cd examples/opsd && python -m pytest tests/ -v→ 87/87 passing on CPUdeepspeed --num_gpus 2 main.py --config configs/smoke_hybrid.jsonend-to-end on 2× H200 → 5 finite-loss stepspre-commit run --files <all changed files>→ green (yapf, flake8, check-torchdist, check-license, check-torchcuda, codespell)