Skip to content

[Feature][Performance] NextObservationDelta env transform#3777

Merged
vmoens merged 6 commits into
pytorch:mainfrom
vmoens:worktree-next-obs-delta
May 21, 2026
Merged

[Feature][Performance] NextObservationDelta env transform#3777
vmoens merged 6 commits into
pytorch:mainfrom
vmoens:worktree-next-obs-delta

Conversation

@vmoens
Copy link
Copy Markdown
Collaborator

@vmoens vmoens commented May 18, 2026

Summary

  • Adds NextObservationDelta, a stateless env-side transform that stores ("next", obs) as a low-precision delta from the root obs for rollout memory savings on large continuous observations.
  • Wires up the previously-stubbed _post_step_mdp_hooks extension point in EnvBase.step_and_maybe_reset and threads it through Transform, Compose, and TransformedEnv. The hook receives both the post-step and post-step-mdp tensordicts so a transform can rehydrate the flowing td that the policy reads on the next iteration.
  • NextObservationDelta._step writes (next_obs - obs).to(delta_dtype) (default float16); _post_step_mdp_hooks reconstructs obs + delta in restore_dtype (default: match root). Stateless — no caching across steps.

Why this shape

The existing compact_obs collector flag + NextStateReconstructor RB transform attack the same problem by dropping ("next", obs) entirely and shifting at sample time. That is zero-storage but lossy at trajectory boundaries (which become NaN). The delta variant trades a small precision loss for boundary-preserving reconstruction and an env-side hook that does not need to know about collector internals.

The _post_step_mdp_hooks mechanism was already stubbed (commented out) in common.py, transforms/_base.py, and llm/chat.py. This PR enables it. The signature was changed from the original comment ((tensordict_,) -> tensordict_) to (tensordict, tensordict_) -> tensordict_ because rehydration needs read access to the post-step root obs. No caller existed before, so this is not a breaking change.

v1 limitations (documented on the class)

  • Lossy. Round-trip error scales with delta_dtype precision and observation magnitude.
  • Memory savings require non-pre-allocated stacked output. SyncDataCollector(use_buffers=False) or a lazy RB storage. Pre-allocated _final_rollout upcasts the write back to the original dtype and erases the saving.
  • Hook fires from step_and_maybe_reset only. env.rollout() is not wired in v1; direct rollout callers must rehydrate manually.
  • check_env_specs does not pass on the transformed env. observation_spec is shared between root and ("next", ...) in TorchRL; the transform does not fork it in v1 (a follow-up could). Tests use a reset+step smoke instead.
  • Batched-env composition. For SerialEnv/ParallelEnv, the transform belongs outside the batched env (i.e. TransformedEnv(ParallelEnv(...), NextObservationDelta())) — that path uses the outer step_and_maybe_reset and the hook fires. Putting the transform inside each worker is allowed and runs without error, but the outer batched env's step_and_maybe_reset does not currently propagate the hook so the stacked output upcasts.

Out of scope (potential follow-ups)

  • Forking observation_spec so pre-allocated _final_rollout benefits from the compression.
  • Wiring the hook in _rollout_stop_early and in batched_envs / async_envs / envpool step_and_maybe_reset.
  • A replay-buffer-side delta transform paired with this one.
  • Benchmark entry under benchmarks/.

Test plan

  • pytest test/transforms/test_observation_transforms.py::TestNextObservationDelta — 14 passed, 2 documented skips.
  • pytest --doctest-modules torchrl/envs/transforms/_observation.py -k NextObservationDelta — passes.
  • pytest test/envs/test_env_base.py — 47 passed, 4 skipped (no regressions from the hook wiring).
  • Manual smoke against GymEnv("Pendulum-v1") confirms ("next", "observation").dtype == torch.float16 post-step and torch.float32 on the flowing td, with bitwise-exact rehydration (max diff 0.0).
  • Compose(NextObservationDelta, RewardSum) works in both orderings.
  • Wider CI sweep (compose + env-transforms suites) — local disk filled before completing; relying on CI.

Adds a stateless env-side transform that stores `("next", obs)` as a
low-precision delta from the root `obs`, reducing the rollout-time
memory footprint of large continuous observations.

The transform compresses next observations in `_step` and rehydrates
the flowing tensordict's root observation in a new
`_post_step_mdp_hooks` extension point on `EnvBase`. The hook was
previously half-stubbed in `common.py` / `_base.py` / `llm/chat.py`;
it is now wired through `step_and_maybe_reset` and threaded into
`Transform`, `Compose`, and `TransformedEnv`.

Caveats documented on the class:

- The compression is lossy; round-trip error scales with delta dtype
  precision and observation magnitude.
- Memory savings only materialize against non-pre-allocated stacked
  output (e.g. `SyncDataCollector(use_buffers=False)` or a lazy RB
  storage). Pre-allocated buffers upcast the write.
- The hook fires from `step_and_maybe_reset`; direct `env.rollout()`
  callers must rehydrate manually.
- `check_env_specs` rejects the transformed env in v1 because the
  observation spec is shared between root and `("next", ...)` and we
  do not fork it.

Includes a `TestNextObservationDelta` test class with 16 cases
(14 passing, 2 documented skips) covering single-env, serial/parallel
batched envs (inner and outer wrapping), auto-inference skipping
non-floating dtypes, multi-key, reset semantics, Compose ordering,
and an end-to-end `SyncDataCollector(use_buffers=False)` check that
the stacked batch carries `float16` `("next", obs)`.
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 18, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3777

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9567594 with merge base 996387f (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 18, 2026
@github-actions github-actions Bot added Documentation Improvements or additions to documentation Transforms Feature New feature labels May 18, 2026
vmoens added 3 commits May 18, 2026 18:05
- Wire `_post_step_mdp_hooks` in `EnvBase._rollout_stop_early` so
  `env.rollout(..., break_when_any_done=True)` rehydrates the flowing
  td just like `step_and_maybe_reset` already did. The non-stop path
  already routed through `step_and_maybe_reset` and is unchanged.

- Add `Transform.transform_fake_tensordict(td)` hook (no-op default),
  iterated by `Compose`, called by a new `TransformedEnv.fake_tensordict`
  override. `NextObservationDelta` overrides it to cast the
  `("next", key)` leaves to `delta_dtype` in the spec-derived fake td.
  Pre-allocated `_final_rollout` in `SyncDataCollector(use_buffers=True)`
  now reserves storage at the compressed dtype rather than upcasting
  writes; the collector test covers both `use_buffers={True, False}`.

- Add `Transform._check_batched_worker_compat()` (no-op default).
  `NextObservationDelta` raises with a clear message pointing at the
  correct usage pattern. `BatchedEnvBase._get_metadata` builds a
  transient probe env and runs the validator via a new `env_validator`
  kwarg on `get_env_metadata`, so the inner-batched configuration
  fails loudly at construction time instead of silently upcasting at
  runtime.

The remaining v1 caveat in the docstring is that `check_env_specs`
still does not pass: it calls `observation_spec.contains(("next", obs))`
and TorchRL shares `observation_spec` between root and `("next", ...)`
leaves, so a compressed dtype is rejected. Working around this
properly requires forking the spec system, which is out of scope for
this PR. Tests use a reset+step smoke instead.
Subtracting in delta_dtype (float16 by default) risks catastrophic
cancellation when next_obs and obs are close. Doing the subtraction
in the operands' source dtype and casting the result once preserves
significand bits and is strictly more accurate on round-trip.

The stored root obs is unchanged, so there is no asymmetry to
preserve between the on-the-fly delta and the value reconstructed
from storage.
@vmoens
Copy link
Copy Markdown
Collaborator Author

vmoens commented May 19, 2026

@elin-bdai @theap06 maybe you could help review this one?
It reduces the size of ("next", "observation") fields by 50% when the transform is applied by using low-prec representation of the delta between o_t and o'_t
That looks like a good compromise between efficiency and correctness to carry the last obs of a trajectory/rollout.

Copy link
Copy Markdown
Contributor

@elin-bdai elin-bdai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this! I'm going to test this with our longer training jobs this week to make sure loss of precision is not a problem. Just a comment in terms of reducing confusion.

Comment thread torchrl/envs/transforms/_observation.py Outdated
>>> td_root = env.reset()
>>> _ = td_root.set("action", env.action_spec.rand())
>>> td, td_ = env.step_and_maybe_reset(td_root)
>>> td["next", "observation"].dtype
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding correctly, I think it's confusing here when using NextObservationDelta() that what's inside td["next", "observation"] is the delta, but the tensordict is indistinguishable from when you don't use NextObservationDelta, so you're not sure if it's the delta or not stored in there. It could lead to confusion when inspecting the outputs at different points.

Comment thread torchrl/envs/transforms/_observation.py Outdated
# operand to ``delta_dtype`` first and subtracting in low precision
# (which would risk catastrophic cancellation for nearby values).
delta = (next_obs - obs).to(self.delta_dtype)
next_tensordict.set(key, delta)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to change the key here to {key}_delta?

@theap06
Copy link
Copy Markdown
Contributor

theap06 commented May 20, 2026

@vmoens Thanks for building this out! @elin-bdai on your first point, agree that it's easy to miss. The dtype change is the observable signal, and the docstring example showing td["next", "observation"].dtype helps.(personally, dont think its necessary) On the {key}_delta suggestion, I also agree here as it makes the off-policy footgun explicit.

@elin-bdai
Copy link
Copy Markdown
Contributor

@vmoens Thanks for building this out! @elin-bdai on your first point, agree that it's easy to miss. The dtype change is the observable signal, and the docstring example showing td["next", "observation"].dtype helps.(personally, dont think its necessary) On the {key}_delta suggestion, I also agree here as it makes the off-policy footgun explicit.

Good point. Yeah we can make the tradeoff of whether it's worth it. Thanks!

@vmoens
Copy link
Copy Markdown
Collaborator Author

vmoens commented May 20, 2026

I agree i'm working on making a ("next", "delta", "observation") kind of thing such that we never have that hidden bug and things are explicit. Funny that we all converged on a similar idea!
Codex has been pushing back for this but I think it's much cleaner.

vmoens added 2 commits May 20, 2026 15:35
The in-place compression overwrote `("next", observation)` with a half-precision
delta, which left consumers with a key that looked like a real next obs but
silently held a low-precision delta. It also violated the spec contract
(`observation_spec` is shared between root and `("next", ...)` leaves) and
required `transform_fake_tensordict` gymnastics to keep collector buffers
honest.

The new contract:

- `NextObservationDelta._step` leaves `("next", k)` alone and writes the
  delta under a sibling sub-tensordict key `("next", "delta", k)`.
- `_post_step_mdp_hooks` drops the full `("next", k)` from the post-step
  tensordict after `step_mdp` has already promoted it to root in the
  flowing td. No rehydration math; `step_mdp` does the work.
- A new RB-side transform `NextObservationDeltaReconstructor` recomputes
  `("next", k) = data[k] + data[("next", "delta", k)]` at sample time
  and drops the delta key. Sample-time reconstruction is exact within
  the round-trip precision of `delta_dtype`, and unlike
  `NextStateReconstructor` it has no trajectory-boundary NaN edge case
  (the delta encodes the actual transition).

API change visible across this PR (the hook was new in this PR, no
external callers):

- `_post_step_mdp_hooks` now returns `(tensordict, tensordict_)` instead
  of just `tensordict_`. Honest about mutating both tds. Threaded
  through `EnvBase.step_and_maybe_reset`, `EnvBase._rollout_stop_early`,
  `Transform`, `Compose`, and `TransformedEnv`.
- `_rollout_stop_early` is restructured to run `step_mdp` + hooks
  *before* copying the post-step td into the stacked rollout, so the
  hook's mutations (dropping the full next-obs slot) survive into the
  stack.

`NextObservationDelta` constructor lost the `restore_dtype` kwarg
(reconstruction happens on the RB side now); `auto_skip` also went
away (no more in-place overwrite, so no idempotency concern).

Tests cover env-side compression, paired-RB reconstruction (round-trip
tolerance, boundary handling, `drop_delta` toggle, end-to-end env →
collector → RB → sample), `env.rollout()` on both done-paths, and the
loud failure for `SerialEnv` / `ParallelEnv` workers containing the
transform.
Match the established TorchRL convention where a single transform class
handles both the env side (via `_step`) and the RB side (via `forward`).
The compression is the env-side operation; the reconstruction is the
RB-side operation; they are two halves of the same transform, not two
separate classes.

- `NextObservationDelta.forward(td)` now reads root `obs` and
  `("next", "delta", obs)`, writes `("next", obs) = obs + delta` in
  `restore_dtype`, and drops the delta key when `drop_delta=True`.
- `NextObservationDelta.__init__` regains `restore_dtype` and
  `drop_delta` kwargs (used by `forward`). `_step` still computes the
  delta in source dtype and casts once.
- `NextObservationDeltaReconstructor` is removed from
  `torchrl/envs/transforms/rb_transforms.py` and from all export points.
- The previous `test_transform_rb` (which asserted NotImplementedError
  on `rb.sample()`) is rewritten to assert the round-trip:
  `rb.extend` stores the delta, `rb.sample` reconstructs full
  `("next", obs)`. The standalone test class is renamed to
  `TestNextObservationDeltaForward` and now uses the single class
  on both sides.

Usage: attach the same configured `NextObservationDelta(in_keys=[...])`
on the env (compression) and on the replay buffer (reconstruction);
TorchRL dispatches `_step` vs `forward` based on the context.
@vmoens vmoens merged commit 6e46542 into pytorch:main May 21, 2026
108 of 109 checks passed
@vmoens vmoens deleted the worktree-next-obs-delta branch May 21, 2026 14:37
vmoens added a commit to vmoens/rl that referenced this pull request May 21, 2026
Add a new section between Knob 2 and Knob 2.5 that describes the
lossy-delta variant of memory-efficient observation storage shipped
in pytorch#3777:

- env-side: write `(next_obs - obs).to(delta_dtype)` under
  `("next", "delta", obs)` and drop the full-precision next obs.
- RB-side: the same transform reconstructs `("next", obs)` as
  `obs + delta` at sample time.

The new knob trades a smaller memory saving (~25% vs ~50%) for
boundary-preserving reconstruction: no NaN at trajectory ends, so
losses that bootstrap on truncated transitions get the real next
obs instead of the `V(obs[t])` fallback used by the value-estimator
sanitizer. MultiStep is still incompatible.

Cross-references:
- "When not to rehydrate" now points at NextObservationDelta as the
  alternative for truncated-bootstrap-heavy losses.
- Conclusion bullets include the delta knob alongside the compact +
  reconstructor pair.

The runnable code path is unchanged; the new section uses a
`.. code-block:: python` (non-executed) snippet, so the tutorial does
not depend on pytorch#3777 being merged first.
vmoens added a commit to vmoens/rl that referenced this pull request May 21, 2026
…NextObservationDelta

The conflict in torchrl/collectors/_single.py was between two extensions
of the compact_obs docstring -- HEAD added the tutorial / NaN-sanitizer /
MultiStep cross-references, main added the new shifted='compact' GAE
pairing. Resolved by keeping both.

Now that NextObservationDelta (pytorch#3777) is in main, point at it from the
three places that already cross-reference the memory-efficient knobs:

- torchrl/collectors/_single.py compact_obs docstring -- 'lossy-precision
  alternative that *does* preserve boundary transitions'.
- torchrl/collectors/_multi_base.py compact_obs docstring -- same line.
- torchrl/envs/transforms/rb_transforms.py NextStateReconstructor seealso
  -- mention the delta variant for the NaN-at-boundary case.
- torchrl/objectives/value/advantages.py _sanitize_next_obs_nan seealso
  -- mention the delta variant as an alternative that avoids NaN.

No code changes; docs only.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Feature New feature Transforms

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants