
[Draft] Add On-Policy Distillation (OPSD) Trainer in DeepSpeed#8027

Open
PKUWZP wants to merge 4 commits into deepspeedai:master from PKUWZP:zhipwang_opd_pr

Collaborator

@PKUWZP PKUWZP commented May 26, 2026

Summary

Adds a DeepSpeed-native on-policy distillation trainer under examples/opsd/.

On-policy distillation: a small student generates rollouts, a frozen large teacher scores them, and the student is updated by a per-token divergence (forward-KL / reverse-KL / JSD) between the two distributions on the student's own samples. Each step has three phases — student rollout → teacher forward + CPU logit cache → student forward + streamed divergence + backward — so the full [B, T, V] teacher tensor never co-resides with the student logits on the training device.
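The three divergences named above can be sketched per token as follows. This is a minimal stdlib-only illustration, not the code in opsd/losses.py: the function names are hypothetical, "forward KL" is assumed to mean KL(teacher || student), "reverse KL" is assumed to mean KL(student || teacher), and the JSD is assumed to use an equal-weight mixture.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl(log_p, log_q):
    """KL(p || q) for one token position, given log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def token_divergence(student_logits, teacher_logits, kind="forward_kl"):
    """Divergence between student and teacher at one token position."""
    log_s = log_softmax(student_logits)
    log_t = log_softmax(teacher_logits)
    if kind == "forward_kl":  # assumed convention: KL(teacher || student)
        return kl(log_t, log_s)
    if kind == "reverse_kl":  # assumed convention: KL(student || teacher)
        return kl(log_s, log_t)
    if kind == "jsd":  # assumed equal-weight mixture m = (s + t) / 2
        log_m = [math.log(0.5 * (math.exp(a) + math.exp(b)))
                 for a, b in zip(log_s, log_t)]
        return 0.5 * kl(log_s, log_m) + 0.5 * kl(log_t, log_m)
    raise ValueError(f"unknown divergence kind: {kind}")


def masked_mean_divergence(student, teacher, mask, kind="forward_kl"):
    """Average the per-token divergence over positions where mask is 1."""
    total, count = 0.0, 0
    for s_row, t_row, m in zip(student, teacher, mask):
        if m:
            total += token_divergence(s_row, t_row, kind)
            count += 1
    return total / max(count, 1)
```

The mask plays the role of the response mask: prompt and padding positions contribute nothing, so only tokens sampled by the student are scored.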

Modules (~3.2k LOC across 32 files):

  • opsd/losses.py — chunked / streamed forward-KL, reverse-KL, JSD with sequence-axis chunking
  • opsd/teacher.py — frozen teacher wrapper + TeacherLogitCache (host-resident, chunk fetch)
  • opsd/rollout/{base,hybrid_engine,vllm}.py — backend-agnostic RolloutEngine ABC + two backends
  • opsd/weight_bridge/{base,qwen2,qwen3}.py — per-arch ParallelKind + TP slicer for vLLM weight sync
  • opsd/trainer.py + main.py — three-phase loop + deepspeed.initialize entry point
  • configs/ + scripts/ — production and smoke configs for both rollout backends
  • tests/ — 87 CPU-only tests covering loss math, cache, rollout interface, weight-bridge slicing, vLLM stitch

Validated end-to-end on 2× H200 with Qwen2.5-0.5B-Instruct student + Qwen2.5-1.5B-Instruct teacher via the hybrid-engine path; loss finite for 5 steps. See README for the smoke recipe.

Known follow-up documented in README + opsd/rollout/vllm.py docstring: vLLM rollout under the standard deepspeed --num_gpus N launcher hits a torch.distributed.new_group deadlock (vLLM's worker calls it on the global PG, but only rank 0 participates). The fix is the TRL/OpenRLHF separate-server pattern; this PR lands the scaffolding (rollout class, weight bridges, configs, unit tests for the testable pieces) so that work can begin against a green codebase.

Test plan

  • cd examples/opsd && python -m pytest tests/ -v → 87/87 passing on CPU
  • deepspeed --num_gpus 2 main.py --config configs/smoke_hybrid.json end-to-end on 2× H200 → 5 finite-loss steps
  • pre-commit run --files <all changed files> → green (yapf, flake8, check-torchdist, check-license, check-torchcuda, codespell)
  • vLLM rollout end-to-end (blocked, see above)
  • Larger-scale training run (out of scope for the initial PR)

Zhipeng Wang added 4 commits May 26, 2026 07:30
First slice of the on-policy distillation example app under examples/opsd/.
This commit lands the framework-agnostic foundation: the OPSDConfig dataclass
hierarchy, chunked / streamed forward-KL / reverse-KL / JSD losses with
sequence-axis chunking to bound peak memory, response-mask + shift helpers,
and a 24-case CPU-only test suite covering identity, masking, chunk
equivalence, gradient flow, and numerical edge cases.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>

Adds the two-phase teacher path:

  * TeacherWrapper loads a HuggingFace causal LM, freezes it, and runs
    forward-only. Two modes: load + pin on GPU (offload_to_cpu=false), or
    wrap with deepspeed.initialize using a ZeRO-3 + offload_param=cpu
    config (offload_to_cpu=true). Avoids deepspeed.zero.Init() around
    from_pretrained because HF's loader partitions params to zero-width
    shards before the checkpoint can fill them.

  * TeacherLogitCache stages the [B, T, V] teacher logits to (pinned) host
    memory in bf16, and exposes chunk_to_device() so the student-side loss
    can pull sequence slices back on demand. This is the memory-economising
    half of the two-phase update.

CPU-only tests cover the cache shape / dtype / round-trip / chunk-bounds
behaviour and verify the streamed-via-cache loss matches the direct
chunked loss bit-for-bit.
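The chunk-fetch idea can be illustrated with a small stdlib-only sketch. It is not the TeacherLogitCache implementation: HostLogitCache, chunk() and streamed_forward_kl() are hypothetical names, a single sequence stands in for the [B, T, V] tensor, plain Python lists stand in for pinned bf16 host memory, and forward KL is assumed to mean KL(teacher || student).

```python
import math


class HostLogitCache:
    """Hypothetical stand-in for a host-resident cache of teacher logits."""

    def __init__(self, logits):
        self._rows = [list(row) for row in logits]  # shape [T][V]

    def __len__(self):
        return len(self._rows)

    def chunk(self, start, end):
        """Return rows start..end-1 along the sequence axis, bounds-checked."""
        if not 0 <= start < end <= len(self._rows):
            raise IndexError(f"chunk [{start}, {end}) out of range")
        return self._rows[start:end]


def _log_softmax(row):
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def _forward_kl(teacher_row, student_row):
    log_t, log_s = _log_softmax(teacher_row), _log_softmax(student_row)
    return sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_t, log_s))


def streamed_forward_kl(cache, student_logits, mask, chunk_size):
    """Masked mean KL, pulling teacher rows one sequence chunk at a time."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    total, count = 0.0, 0
    for start in range(0, len(cache), chunk_size):
        end = min(start + chunk_size, len(cache))
        teacher_chunk = cache.chunk(start, end)  # only this slice is "on device"
        rows = zip(teacher_chunk, student_logits[start:end], mask[start:end])
        for t_row, s_row, m in rows:
            if m:
                total += _forward_kl(t_row, s_row)
                count += 1
    return total / max(count, 1)
```

Because tokens are accumulated in the same order whatever the chunk size, this sketch returns the same value for every chunking; whether a tensor implementation matches exactly depends on its reduction order.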

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>

Lands the fully-runnable hybrid-engine training path: a backend-agnostic
RolloutEngine ABC with RolloutRequest / RolloutBatch / SamplingConfig
dataclasses, a HybridEngineRollout implementation that uses DeepSpeed's
accelerated decode when an inference policy exists and otherwise falls
back to GatheredParameters + the raw HF generate (covers Qwen-family and
other models not in DeepSpeed's inference container list), a left-padded
prompt dataset + collator, a three-phase trainer loop (rollout -> teacher
forward + cache -> student forward + streamed KL + backward + step), the
argparse + deepspeed.initialize entry point, base DeepSpeed ZeRO-3 +
hybrid_engine JSON configs, a 5-step smoke config and launcher script,
and a 20-prompt math toy dataset for the smoke run.

Smoke-validated end-to-end on 2x H200 with Qwen2.5-0.5B-Instruct student
and Qwen2.5-1.5B-Instruct teacher; loss finite for 5 steps. Rollout
interface contract is covered by tests/test_rollout_interface.py.
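For readers unfamiliar with left-padded prompt batches, the sketch below shows the idea in plain Python. It is not the collator shipped in this PR: left_pad_collate and its return keys are illustrative assumptions. Left padding aligns the last prompt token of every row at the final position, so generation continues directly after each prompt.

```python
def left_pad_collate(prompts, pad_token_id):
    """Pad variable-length token-id lists on the left to a common length.

    Returns a dict with 'input_ids' and 'attention_mask' as nested lists;
    a real collator would typically return tensors instead.
    """
    if not prompts:
        raise ValueError("prompts must be non-empty")
    max_len = max(len(p) for p in prompts)
    input_ids, attention_mask = [], []
    for p in prompts:
        pad = max_len - len(p)
        input_ids.append([pad_token_id] * pad + list(p))
        attention_mask.append([0] * pad + [1] * len(p))
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```

For example, `left_pad_collate([[5, 6, 7], [8]], pad_token_id=0)` pads the second prompt to `[0, 0, 8]` with mask `[0, 0, 1]`.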

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>

Lands the second-stage rollout path, weight-sync infrastructure, and the
example app's README. Includes:

  * VLLMRollout that constructs vllm.LLM on training rank 0 and broadcasts
    generated token ids to peer ranks, with disjoint-GPU (subprocess) and
    shared (in-process) topology paths. Weight sync gathers ZeRO-3 params
    cooperatively then pushes to vLLM via LLM.collective_rpc("load_weights").

  * WeightBridge ABC with COLUMN / ROW / VOCAB / REPLICATED parallel kinds
    and an even-slice per-rank slicer; Qwen2WeightBridge with the full
    per-parameter table for Qwen2 / Qwen2.5; Qwen3WeightBridge adding the
    per-head q_norm / k_norm tensors as REPLICATED.

  * vLLM-side prompt+response stitching factored into stitch_rollout() so
    its index math is unit-testable without a live vLLM.

  * CPU-only tests: tests/test_weight_bridge.py covers parallel-kind
    dispatch, per-rank shape/gather round-trips across tp_size in {1,2,4},
    indivisibility / invalid-rank guards, and the registry;
    tests/test_vllm_stitch.py covers prompt/response stitching for the
    common shapes including variable response lengths and left-padded
    prompts.

  * configs + launch scripts for both production and smoke vLLM runs.
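The even-slice per-rank slicer described above can be sketched as follows. This is an illustration under stated assumptions, not the WeightBridge code: slice_for_rank and gather_shards are hypothetical helpers, weights are 2-D nested lists laid out as [out_features][in_features], and COLUMN / VOCAB are assumed to split the first axis while ROW splits the second.

```python
from enum import Enum


class ParallelKind(Enum):
    COLUMN = "column"
    ROW = "row"
    VOCAB = "vocab"
    REPLICATED = "replicated"


def slice_for_rank(weight, kind, tp_rank, tp_size):
    """Return the shard of a 2-D weight owned by tp_rank."""
    if tp_size < 1 or not 0 <= tp_rank < tp_size:
        raise ValueError(f"invalid rank {tp_rank} for tp_size {tp_size}")
    if kind is ParallelKind.REPLICATED:
        return [list(row) for row in weight]  # every rank keeps a full copy
    rows, cols = len(weight), len(weight[0])
    if kind in (ParallelKind.COLUMN, ParallelKind.VOCAB):
        if rows % tp_size:
            raise ValueError(f"{rows} rows not divisible by tp_size {tp_size}")
        step = rows // tp_size
        return [list(r) for r in weight[tp_rank * step:(tp_rank + 1) * step]]
    # ROW: assumed to split the input (second) axis
    if cols % tp_size:
        raise ValueError(f"{cols} cols not divisible by tp_size {tp_size}")
    step = cols // tp_size
    return [list(r[tp_rank * step:(tp_rank + 1) * step]) for r in weight]


def gather_shards(shards, kind):
    """Inverse of slice_for_rank, used here only to check round-trips."""
    if kind is ParallelKind.REPLICATED:
        return [list(row) for row in shards[0]]
    if kind in (ParallelKind.COLUMN, ParallelKind.VOCAB):
        return [list(row) for shard in shards for row in shard]
    return [sum((shard[i] for shard in shards), []) for i in range(len(shards[0]))]
```

Slicing for every rank and gathering the shards should reproduce the original tensor for any tp_size that divides the split axis; indivisible shapes and out-of-range ranks raise ValueError.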

**Known blocker called out in README and module docstring:** vLLM's worker
init calls new_group() on the global process group, which deadlocks when
launched under the standard `deepspeed --num_gpus N` launcher (rank 0
calls vLLM, other ranks never participate in vLLM's collective). The
documented fix is the TRL/OpenRLHF separate-server pattern; this PR lands
the scaffolding so that work can begin against a green codebase.

Signed-off-by: Zhipeng Wang <zhipengbayern@gmail.com>
@PKUWZP PKUWZP requested a review from tohtana as a code owner May 26, 2026 07:31
@PKUWZP PKUWZP changed the title from "Add On-Policy Distillation (OPSD) example app" to "[Draft] Add On-Policy Distillation (OPSD) Trainer in DeepSpeed" May 26, 2026

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6384396b48

NUM_TRAIN_GPUS="${NUM_TRAIN_GPUS:-6}"
INCLUDE_GPUS="${INCLUDE_GPUS:-0,1,2,3,4,5}"

deepspeed --num_gpus "${NUM_TRAIN_GPUS}" --include "localhost:${INCLUDE_GPUS}" \

P1: Remove conflicting DeepSpeed launch flags

This launch command passes --num_gpus together with --include, which DeepSpeed rejects (deepspeed/launcher/runner.py raises ValueError("Cannot specify num_nodes/gpus with include/exclude")). As written, the documented/default vLLM training script fails before main.py starts, so users cannot run the disjoint-GPU workflow at all.
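One possible shape of the fix, sketched in Python rather than in the launch script itself: pass only --include and let the GPU count follow from the device list. build_launch_cmd is a hypothetical helper, and the script and config names in the usage line are placeholders, not paths from this PR.

```python
def build_launch_cmd(script, include_gpus, script_args=()):
    """Build a deepspeed launcher argv that selects GPUs via --include only.

    --num_gpus is deliberately omitted: the number of workers is implied by
    the device list, which avoids combining it with --include.
    """
    if not include_gpus:
        raise ValueError("include_gpus must list at least one GPU index")
    include = "localhost:" + ",".join(str(g) for g in include_gpus)
    return ["deepspeed", "--include", include, script, *script_args]


# Usage with placeholder arguments (hypothetical config path):
cmd = build_launch_cmd("main.py", [0, 1, 2, 3, 4, 5], ("--config", "example.json"))
```

The resulting list can be handed to `subprocess.run(cmd)`; the same change in the shell script amounts to dropping the `--num_gpus` argument.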


Comment thread examples/opsd/main.py
Comment on lines +112 to +116
loader = DataLoader(
    dataset,
    batch_size=cfg.training.micro_batch_size_per_gpu,
    shuffle=cfg.data.shuffle,
    collate_fn=collator,

P1: Shard prompts across data-parallel ranks

The dataloader is created without a DistributedSampler, so in multi-rank DeepSpeed runs every rank iterates the full dataset instead of a disjoint shard. That means prompts are repeatedly processed across ranks each epoch, reducing effective data diversity and distorting expected global-batch behavior for OPSD training.
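To make the sharding concern concrete, here is a stdlib-only sketch of a strided per-rank split, in the same spirit as what a distributed sampler provides. shard_indices is a hypothetical helper, shuffling is left out, and in the actual trainer the usual route would be to pass torch.utils.data.distributed.DistributedSampler to the DataLoader rather than hand-rolling this.

```python
def shard_indices(num_samples, rank, world_size, drop_last=False):
    """Return the dataset indices that one data-parallel rank should visit.

    Without drop_last, the index list is padded by wrapping around so that
    every rank receives the same number of samples.
    """
    if world_size < 1 or not 0 <= rank < world_size:
        raise ValueError(f"invalid rank {rank} for world_size {world_size}")
    indices = list(range(num_samples))
    if not indices:
        return []
    if drop_last:
        usable = (num_samples // world_size) * world_size
        indices = indices[:usable]
    else:
        pad = (-num_samples) % world_size
        indices += [indices[i % num_samples] for i in range(pad)]
    return indices[rank::world_size]  # strided split: rank, rank + W, ...
```

With 10 prompts and 2 ranks, rank 0 sees the even indices and rank 1 the odd ones, so each prompt is processed once per epoch instead of once per rank.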


Comment thread examples/opsd/main.py
Comment on lines +80 to +83
ds_config = _load_ds_config(cfg.deepspeed_config)
ds_config["train_micro_batch_size_per_gpu"] = cfg.training.micro_batch_size_per_gpu
ds_config["train_batch_size"] = cfg.training.train_batch_size
ds_config["gradient_accumulation_steps"] = cfg.training.gradient_accumulation_steps

P2: Propagate training LR and warmup into DS config

Only batch-size fields are copied from OPSDConfig.training into the DeepSpeed config; optimizer/scheduler knobs like learning_rate, weight_decay, and warmup_steps are ignored. As a result, changing those values in opsd_*.json silently has no effect, and runs will keep using whatever is hardcoded in the referenced DeepSpeed JSON file.
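A sketch of what propagating these knobs could look like, assuming the training settings are available as a plain dict with the field names mentioned above. apply_training_overrides is a hypothetical helper; the AdamW and WarmupLR defaults used when the DeepSpeed JSON has no optimizer or scheduler section are assumptions of this sketch, not choices made in the PR.

```python
import copy


def apply_training_overrides(ds_config, training):
    """Return a copy of ds_config with batch, optimizer and warmup settings
    taken from the training section (a dict in this sketch)."""
    cfg = copy.deepcopy(ds_config)  # leave the caller's config untouched
    cfg["train_micro_batch_size_per_gpu"] = training["micro_batch_size_per_gpu"]
    cfg["train_batch_size"] = training["train_batch_size"]
    cfg["gradient_accumulation_steps"] = training["gradient_accumulation_steps"]

    # Assumed default optimizer type if the JSON does not define one.
    optimizer = cfg.setdefault("optimizer", {"type": "AdamW"})
    opt_params = optimizer.setdefault("params", {})
    opt_params["lr"] = training["learning_rate"]
    opt_params["weight_decay"] = training["weight_decay"]

    # Assumed default scheduler type if the JSON does not define one.
    scheduler = cfg.setdefault("scheduler", {"type": "WarmupLR"})
    sched_params = scheduler.setdefault("params", {})
    sched_params["warmup_num_steps"] = training["warmup_steps"]
    sched_params["warmup_max_lr"] = training["learning_rate"]
    return cfg
```

An alternative design is to fail loudly when both files set a learning rate to different values, which makes the precedence explicit instead of silently overriding one side.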

