From 7091267ec8aa583bcff1d163f28d6eb1347dc81e Mon Sep 17 00:00:00 2001 From: Jiabin Qin Date: Mon, 20 Apr 2026 13:36:09 +0800 Subject: [PATCH] Normalize served prev_action_chunk + reserved-key sample_kwargs passthrough MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two related fixes to served Policy.infer: 1. RTC prev_action_chunk is now normalized to model space before reaching sample_actions. Pi0RTC.sample_actions() consumes prev_action_chunk in model space (post-Normalize), but Policy.infer was forwarding obs["prev_action_chunk"] from the wire raw. Agilex inference clients send a raw deploy-space slice of their execution buffer, so the guidance term was operating on un-normalized inputs — a silent train-deploy contract break (masked because Agilex action norm-stats are close to unit-variance, so the magnitude error is small). The fix adds a _normalize_and_pad_prev_chunk helper that delegates to the same transforms.Normalize the serving pipeline uses (so use_quantile_norm is honored), pads to action_horizon, and is wired from the loaded checkpoint via three new optional Policy params (norm_stats, use_quantile_norm, action_horizon). policy_config wires them automatically — call sites unchanged. Also guards against silent d=0 cheap-path activation when a client sends prev_action_chunk without inference_delay. 2. Reserved-key obs["_sample_kwargs"] allowlist for transport-layer sample_kwargs overrides (currently: noise). Previously the websocket protocol dropped the noise= kwarg — making deterministic served eval impossible. The reserved-key namespace (leading underscore) avoids collision with future models that legitimately use observation field names like "noise". Explicit noise= kwarg (in-process callers) takes precedence. Both fixes are backward-compatible: existing callers see no behavior change. 
Existing RTC clients that previously sent raw deploy-space chunks will now have those chunks normalized into model space server-side before the model consumes them — this is the bug fix. --- src/openpi/policies/policy.py | 115 ++++++++++++++++++++++++--- src/openpi/policies/policy_config.py | 5 ++ 2 files changed, 108 insertions(+), 12 deletions(-) diff --git a/src/openpi/policies/policy.py b/src/openpi/policies/policy.py index e6ac080..d34d2b9 100755 --- a/src/openpi/policies/policy.py +++ b/src/openpi/policies/policy.py @@ -18,8 +18,53 @@ from openpi.shared import array_typing as at from openpi.shared import nnx_utils +logger = logging.getLogger(__name__) + BasePolicy: TypeAlias = _base_policy.BasePolicy +# Reserved transport-layer key in the observation dict for served clients to override +# sample_kwargs (e.g. pass a deterministic noise sample). Leading underscore signals +# "transport-layer field, not a model observation input" — avoids collisions with future +# models that legitimately use observation field names like "noise". +_RESERVED_SAMPLE_KWARGS_KEY = "_sample_kwargs" +_ALLOWED_TRANSPORT_SAMPLE_KWARGS = frozenset({"noise"}) + + +def _normalize_and_pad_prev_chunk( + raw: np.ndarray, + *, + norm_stats: dict[str, _transforms.NormStats], + use_quantile_norm: bool, + action_horizon: int, +) -> np.ndarray: + """Normalize a client-supplied ``prev_action_chunk`` into model space and pad to ``action_horizon``. + + The model's RTC ``sample_actions`` consumes ``prev_action_chunk`` in **model space** + (post-Normalize), but websocket clients send a raw ``(d, state_dim)`` slice of their + deploy-space execution buffer. Without this helper the guidance term operates on + un-normalized inputs — a silent train-deploy contract break. + + Delegates to the same ``transforms.Normalize`` instance the serving pipeline uses so the + formula (z-score vs quantile) cannot drift. Pads the chunk to the model's + ``action_horizon`` because the JAX/PyTorch RTC implementations require that shape. 
+ """ + state_dim = raw.shape[-1] + action_stats = norm_stats["actions"] + if state_dim > action_stats.mean.shape[-1]: + raise ValueError( + f"prev_action_chunk state_dim={state_dim} exceeds norm_stats['actions'] width " + f"{action_stats.mean.shape[-1]}; client is sending more joints than the checkpoint knows about." + ) + normalizer = _transforms.Normalize({"actions": action_stats}, use_quantiles=use_quantile_norm) + normalized = normalizer({"actions": raw})["actions"] + d = normalized.shape[0] + if d < action_horizon: + pad = np.zeros((action_horizon - d, state_dim), dtype=np.float32) + normalized = np.concatenate([normalized, pad], axis=0) + elif d > action_horizon: + normalized = normalized[:action_horizon] + return normalized.astype(np.float32, copy=False) + class Policy(BasePolicy): def __init__( @@ -33,6 +78,9 @@ def __init__( metadata: dict[str, Any] | None = None, pytorch_device: str = "cpu", is_pytorch: bool = False, + norm_stats: dict[str, _transforms.NormStats] | None = None, + use_quantile_norm: bool = False, + action_horizon: int | None = None, ): """Initialize the Policy. @@ -54,6 +102,15 @@ def __init__( self._metadata = metadata or {} self._is_pytorch_model = is_pytorch self._pytorch_device = pytorch_device + if norm_stats is not None and action_horizon is None: + raise ValueError( + "Policy(norm_stats=...) requires action_horizon to also be provided; " + "without it, server-side prev_action_chunk normalization cannot pad to the model's horizon." 
+ ) + self._norm_stats = norm_stats + self._use_quantile_norm = use_quantile_norm + self._action_horizon = action_horizon + self._rtc_log_emitted = False if self._is_pytorch_model: self._model = self._model.to(pytorch_device) @@ -81,21 +138,55 @@ def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: # Prepare kwargs for sample_actions sample_kwargs = dict(self._sample_kwargs) - # TODO: For RTC passthrough: allow client to provide delay/prev chunk/horizon for realtime_action-capable models - if "prev_action_chunk" in obs: - sample_kwargs["prev_action_chunk"] = obs["prev_action_chunk"] - if "inference_delay" in obs: + # RTC cheap-path guidance. Client sends a raw (d, state_dim) slice of its blended queue head + # along with inference_delay; we normalize with the same stats+mode as the serving Normalize + # transform and pad horizon to the model's action_horizon before forwarding. Both fields must + # be present together — forwarding prev_action_chunk without inference_delay would silently + # trip the cheap-path gate with d=0, running the eager loop with no prefix conditioning. + has_prev = "prev_action_chunk" in obs + has_delay = "inference_delay" in obs + if has_prev and not has_delay: + logger.warning( + "[rtc_cheap_path] obs has prev_action_chunk but not inference_delay; skipping cheap-path " + "forwarding to avoid silent d=0 activation. Client must send both fields together." 
+ ) + elif has_prev and has_delay: + raw_prev = np.asarray(obs["prev_action_chunk"], dtype=np.float32) + if self._norm_stats is not None and self._action_horizon is not None: + prev_chunk = _normalize_and_pad_prev_chunk( + raw_prev, + norm_stats=self._norm_stats, + use_quantile_norm=self._use_quantile_norm, + action_horizon=self._action_horizon, + ) + sample_kwargs["prev_action_chunk"] = prev_chunk + log_fn = logger.info if not self._rtc_log_emitted else logger.debug + log_fn( + "[rtc] forwarded prev_action_chunk d=%d ah=%d quantile=%s", + raw_prev.shape[0], self._action_horizon, self._use_quantile_norm, + ) + self._rtc_log_emitted = True + else: + sample_kwargs["prev_action_chunk"] = raw_prev sample_kwargs["inference_delay"] = obs["inference_delay"] if "execute_horizon" in obs: sample_kwargs["execute_horizon"] = obs["execute_horizon"] - # if "enable_rtc" in obs: - # sample_kwargs["enable_rtc"] = obs["enable_rtc"] - # if "mask_prefix_delay" in obs: - # sample_kwargs["mask_prefix_delay"] = obs["mask_prefix_delay"] - # if "prefix_attention_schedule" in obs: - # sample_kwargs["prefix_attention_schedule"] = obs["prefix_attention_schedule"] - # if "max_guidance_weight" in obs: - # sample_kwargs["max_guidance_weight"] = obs["max_guidance_weight"] + # Reserved-key transport for sample_kwargs overrides (currently: noise). + # Explicit `noise=` kwarg (in-process callers) takes precedence over obs-supplied noise. 
+ sample_kwargs_override = obs.get(_RESERVED_SAMPLE_KWARGS_KEY) or {} + if not isinstance(sample_kwargs_override, dict): + raise TypeError( + f"obs[{_RESERVED_SAMPLE_KWARGS_KEY!r}] must be a dict, " + f"got {type(sample_kwargs_override).__name__}" + ) + unknown = set(sample_kwargs_override) - _ALLOWED_TRANSPORT_SAMPLE_KWARGS + if unknown: + raise ValueError( + f"obs[{_RESERVED_SAMPLE_KWARGS_KEY!r}] contains unsupported keys: {sorted(unknown)}; " + f"allowlist: {sorted(_ALLOWED_TRANSPORT_SAMPLE_KWARGS)}" + ) + if "noise" in sample_kwargs_override and noise is None: + noise = np.asarray(sample_kwargs_override["noise"]) if noise is not None: noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise) diff --git a/src/openpi/policies/policy_config.py b/src/openpi/policies/policy_config.py index 6570df0..dca1f63 100755 --- a/src/openpi/policies/policy_config.py +++ b/src/openpi/policies/policy_config.py @@ -91,4 +91,9 @@ def create_trained_policy( metadata=train_config.policy_metadata, is_pytorch=is_pytorch, pytorch_device=pytorch_device if is_pytorch else None, + # Wire RTC normalization params from the loaded checkpoint so served prev_action_chunk + # is normalized into model space before reaching Pi0RTC.sample_actions(). + norm_stats=norm_stats, + use_quantile_norm=data_config.use_quantile_norm, + action_horizon=train_config.model.action_horizon, )