Fix: compile multitask #5457
for more information, see https://pre-commit.ci
📝 Walkthrough
This PR introduces task-structure-aware `torch.compile` caching for multi-task models. It extracts per-task fitting-net buffers, computes a shared-structure identity key, promotes those buffers into explicit FX symbolic inputs for graph reuse, updates checkpoint loading to skip task-buffer remnants, and converts training loss values to floats for display aggregation.
Changes
Task-structure-aware torch.compile optimization
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
Pull request overview
This PR adjusts the `pt_expt` `torch.compile` path for multi-task training to reduce redundant compiled graphs (and the associated memory/OOM issues). It promotes per-task fitting-net buffers to explicit compiled-graph inputs and reuses compiled graphs across tasks when the model structure is shared.
Changes:
- Promote task-specific fitting-net buffers (`bias_atom_e`, `case_embd`) into FX placeholders so one compiled graph can be reused with different per-task buffer values.
- Add per-structure caching in the compile pipeline to avoid recompiling the same shared structure for each task.
- Make training-time logging robust by converting tensor scalars to Python floats before formatting/aggregation.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `deepmd/pt_expt/train/training.py` | Adds task-buffer promotion + compiled-graph reuse caching for multi-task compile; adjusts logging scalar handling. |
| `deepmd/pt_expt/infer/deep_eval.py` | Updates `.pt` checkpoint loading to ignore newly introduced `_CompiledModel` per-task buffer copies. |
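As a rough illustration of the logging scalar handling mentioned for `training.py`, the sketch below normalizes mixed scalar values to Python floats before formatting. `FakeScalar`, `to_float`, and `format_losses` are hypothetical names; the only assumption about real tensors is that a 0-d `torch.Tensor` exposes `.item()`.

```python
from numbers import Real


class FakeScalar:
    """Hypothetical stand-in for a 0-d tensor; mimics only the .item() method."""

    def __init__(self, value: float) -> None:
        self._value = value

    def item(self) -> float:
        return self._value


def to_float(value) -> float:
    # Plain numbers pass through; tensor-like scalars are unwrapped via .item().
    if isinstance(value, Real):
        return float(value)
    return float(value.item())


def format_losses(more_loss: dict) -> str:
    # Sorting keeps the log line stable regardless of dict insertion order.
    return ", ".join(f"{k}={to_float(v):.4f}" for k, v in sorted(more_loss.items()))


line = format_losses({"rmse_f": FakeScalar(0.125), "rmse_e": 0.5})
```

Converting at the boundary means the formatting and aggregation code never has to care which kind of scalar it received.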
Comments suppressed due to low confidence (1)
deepmd/pt_expt/train/training.py:1072
- There are existing pt_expt tests covering multi-task + torch.compile, but the new compiled-graph reuse path should be covered by a test that exercises a config where only some components are shared (e.g., fitting_net shared via shared_dict, descriptor not shared). That case would validate the structure-key logic and prevent accidental graph reuse across non-identical forward_lower graphs.
    descriptor = model.get_descriptor()
    if isinstance(descriptor, DescrptDPA1DP):
        n_attn = descriptor.get_numb_attn_layer()
        if n_attn > 0:
            log.warning(
                "Compiling DPA1/se_atten_v2 with %d attention "
                "layer(s) (task=%s): the compiled forces/grads "
                "are slightly hardware-sensitive (multi-thread "
                "reduction order), and may not match the eager "
                "path bit-for-bit. Use 'enable_compile: false' "
                "or 'attn_layer: 0' for fully reproducible runs.",
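The suppressed comment about partially shared components could be covered by a test shaped like the sketch below. Everything here is hypothetical (the `Model` class, `structure_key`, and the choice of `id()` pairs as the key); it only shows the property such a test would check, namely that sharing the fitting net alone must not produce the same key.

```python
class Model:
    """Hypothetical model holding a descriptor and a fitting net."""

    def __init__(self, descriptor: object, fitting: object) -> None:
        self.descriptor = descriptor
        self.fitting = fitting


def structure_key(model: Model) -> tuple[int, int]:
    # Assumed key: identity of every component that shapes the traced graph.
    return (id(model.descriptor), id(model.fitting))


shared_fitting = object()
desc_a, desc_b = object(), object()
task_a = Model(desc_a, shared_fitting)
task_b = Model(desc_b, shared_fitting)  # fitting shared, descriptor not shared
task_c = Model(desc_a, shared_fitting)  # fully shared with task_a


def test_partial_sharing_does_not_reuse() -> None:
    # Different descriptors must yield different keys even with a shared fitting net.
    assert structure_key(task_a) != structure_key(task_b)
    # Fully shared structure yields the same key, so a graph could be reused.
    assert structure_key(task_a) == structure_key(task_c)
```

A key derived from only one component would pass the second assertion but fail the first, which is the accidental reuse the comment warns about.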
🧹 Nitpick comments (3)
deepmd/pt_expt/train/training.py (3)
319-319: ⚡ Quick win
Add explicit `strict=True` to zip call.
The `zip()` on line 319 iterates over `task_buf_order` and `task_buf_vals`, which are guaranteed to have the same length by construction (lines 290-293). Adding `strict=True` documents this invariant and provides a runtime assertion if the construction logic ever changes.

    - for name, val in zip(task_buf_order, task_buf_vals):
    + for name, val in zip(task_buf_order, task_buf_vals, strict=True):

As per coding guidelines, run `ruff check .` before committing to catch linting issues.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` at line 319, the zip over task_buf_order and task_buf_vals in the loop using "for name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length invariant by adding strict=True; update that zip call to zip(task_buf_order, task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then run ruff check . before committing to ensure linting passes.
314-334: ⚡ Quick win
Potential issue with buffer restoration logic.
Lines 320 and 334 save and restore buffer entries, but if `originals[name]` is `None` (buffer didn't exist), line 334 sets `_fitting._buffers[name] = None` instead of deleting the entry. This could leave `None` entries in the buffer registry that weren't present before patching.
Consider using conditional restoration:

    for name, orig in originals.items():
        if orig is not None:
            _fitting._buffers[name] = orig
        else:
            _fitting._buffers.pop(name, None)

However, if the buffers are guaranteed to exist (since `_get_task_buffers` only extracts existing buffers), this may not be an issue in practice.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 314 - 334, the restoration currently writes None back into _fitting._buffers for entries that did not exist before, so change the finally-block that iterates originals (the dict populated from task_buf_order/task_buf_vals) to restore by reassigning when orig is not None and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the dictionary named originals and the finally block that resets _fitting._buffers and replace the unconditional assignment with a conditional restore/remove to avoid leaving None entries after model.forward_lower returns.
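One way to express the conditional restore is a small context manager over a plain dict, sketched below. `patched_entries` and `_MISSING` are hypothetical names, and the sketch swaps the review's `None` check for a sentinel so that an entry whose original value really was `None` is restored rather than deleted.

```python
from contextlib import contextmanager
from typing import Any, Iterator

_MISSING = object()  # sentinel: distinguishes "absent" from "present but None"


@contextmanager
def patched_entries(registry: dict[str, Any], overrides: dict[str, Any]) -> Iterator[None]:
    # Remember the prior state of every key that is about to be overridden.
    originals = {name: registry.get(name, _MISSING) for name in overrides}
    registry.update(overrides)
    try:
        yield
    finally:
        for name, orig in originals.items():
            if orig is _MISSING:
                registry.pop(name, None)  # key did not exist before: remove it
            else:
                registry[name] = orig  # key existed: put the old value back


buffers = {"bias_atom_e": 1.0}
with patched_entries(buffers, {"bias_atom_e": 2.0, "case_embd": 3.0}):
    inside = dict(buffers)
```

After the `with` block, `buffers` holds exactly what it held before, with no leftover `case_embd` key.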
92-108: ⚡ Quick win
Clarify the child name check logic.
Line 103 compares child module names against `_TASK_SPECIFIC_BUFFER_NAMES`, which contains buffer names (`"bias_atom_e"`, `"case_embd"`). Child modules from `named_children()` typically have names like `"nets"`, `"layers"`, etc., not buffer names. This check will almost always be `True`, making it effectively a no-op.
If the intent is to skip the first child when it's task-specific, the logic may need adjustment. Otherwise, consider removing the check or adding a comment explaining why it's safe to use the first child's `id()` directly.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 92 - 108, the code in _get_model_structure_key uses named_children() names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names like "bias_atom_e"), but child module names come from named_children() and won't match buffer names, so the filter is effectively a no-op; fix by computing the set of task-specific buffer names from the fitting net (e.g., buffers = {n for n,_ in fitting.named_buffers()}) and then skip any child whose name appears in that buffer set (replace the current name check), or if the original intent was to just take the first non-task-specific child drop the faulty comparison and simply return id of the first child from fitting.named_children(); update _get_model_structure_key accordingly and keep reference to fitting, named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the code to change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bac28a64-1300-4bb0-9d04-cf61734721ce
📒 Files selected for processing (2)
deepmd/pt_expt/infer/deep_eval.py
deepmd/pt_expt/train/training.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files

    @@           Coverage Diff           @@
    ##           master    #5457   +/-   ##
    =======================================
      Coverage   82.46%   82.46%
    =======================================
      Files         829      829
      Lines       88763    88853   +90
      Branches     4225     4226    +1
    =======================================
    + Hits        73197    73272   +75
    - Misses      14274    14289   +15
      Partials     1292     1292
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 491-500: If self._task_buf_order is set but accessing buffers
fails, don't silently set task_buf_vals = (); instead, in the except
AttributeError branch for original_model.get_fitting_net()/getattr(...) raise a
clear RuntimeError (or ValueError) that mentions the missing fitting net or
buffer names and refers to self._task_buf_order so callers know why
compiled_forward_lower would fail; keep the existing else path that sets
task_buf_vals = () only when _task_buf_order is empty, and ensure the raised
message names the expected buffer attributes and the method
original_model.get_fitting_net to help locate the root cause.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: add25e51-694f-4a58-8e05-9d3e0d5909e9
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
    if self._task_buf_order:
        try:
            _fitting = self.original_model.get_fitting_net()
            task_buf_vals: tuple = tuple(
                getattr(_fitting, name) for name in self._task_buf_order
            )
        except AttributeError:
            task_buf_vals = ()
    else:
        task_buf_vals = ()
Silent fallback may cause confusing downstream errors.
If _task_buf_order is non-empty (compiled graph expects buffer arguments) but the AttributeError catch triggers, the empty task_buf_vals passed to compiled_forward_lower will cause an argument count mismatch—a confusing error that hides the real issue.
Consider raising an informative error when buffers are expected but unavailable:
Proposed fix
     if self._task_buf_order:
         try:
             _fitting = self.original_model.get_fitting_net()
             task_buf_vals: tuple = tuple(
                 getattr(_fitting, name) for name in self._task_buf_order
             )
         except AttributeError:
    -        task_buf_vals = ()
    +        raise RuntimeError(
    +            f"Compiled graph expects task buffers {self._task_buf_order} "
    +            "but fitting net or buffer attributes are unavailable."
    +        )
     else:
         task_buf_vals = ()
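The fail-fast behavior suggested here can be tried in isolation with the sketch below. `FittingNet` and `collect_task_buffers` are hypothetical stand-ins, not names from the PR. The sketch also chains the original exception with `from err`, which the proposed fix does not do, so that the underlying `AttributeError` stays visible in the traceback.

```python
class FittingNet:
    """Hypothetical fitting net that lacks the 'case_embd' attribute."""

    bias_atom_e = 0.5


def collect_task_buffers(fitting: object, order: tuple[str, ...]) -> tuple:
    # Nothing is expected by the compiled graph: return an empty tuple.
    if not order:
        return ()
    try:
        return tuple(getattr(fitting, name) for name in order)
    except AttributeError as err:
        # Fail here with the expected names instead of passing too few
        # arguments to the compiled callable later.
        raise RuntimeError(
            f"Compiled graph expects task buffers {order}, "
            "but the fitting net does not provide all of them."
        ) from err


message = ""
cause = None
try:
    collect_task_buffers(FittingNet(), ("bias_atom_e", "case_embd"))
except RuntimeError as exc:
    message = str(exc)
    cause = exc.__cause__
```

The error names the buffers the graph was built for, which points at the real problem rather than at an argument-count mismatch further down the call chain.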
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 343-345: The zip call iterating over task_buf_order and
task_buf_vals should be made strict to satisfy Ruff B905: change
zip(task_buf_order, task_buf_vals) to zip(task_buf_order, task_buf_vals,
strict=True) in the block that checks name.startswith(_AM_PREFIX). Also remove
or rename the unused variable model_pred (found as model_pred in this file) to
_model_pred (or delete it) to resolve RUF059 so no unused binding remains.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 62caa6e7-1434-49b6-8f30-bfab6dd88904
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
    if task_buf_order:
        for name, val in zip(task_buf_order, task_buf_vals):
            if name.startswith(_AM_PREFIX):
🧩 Analysis chain
🏁 Script executed:

    #!/bin/bash
    ruff check .

Repository: deepmodeling/deepmd-kit
Length of output: 50381
🏁 Script executed:

    #!/bin/bash
    set -euo pipefail
    echo "## ruff check (target file only)"
    ruff check deepmd/pt_expt/train/training.py
    echo "## ruff format --check (target file only)"
    ruff format --check deepmd/pt_expt/train/training.py

Repository: deepmodeling/deepmd-kit
Length of output: 1469
🏁 Script executed:

    ruff check .

Repository: deepmodeling/deepmd-kit
Length of output: 50381
🏁 Script executed:

    ruff check .
    ruff format .

Repository: deepmodeling/deepmd-kit
Length of output: 50381
Fix Ruff B905 (zip strict) and RUF059 (unused `model_pred`) in `deepmd/pt_expt/train/training.py`.
- B905 (line 344): `zip(task_buf_order, task_buf_vals)` needs explicit `strict=` to keep both sequences aligned.
- RUF059 (line 1354): `model_pred` is never used; remove it or rename to `_model_pred`.
Suggested fix

    - for name, val in zip(task_buf_order, task_buf_vals):
    + for name, val in zip(task_buf_order, task_buf_vals, strict=True):

    - model_pred, loss, more_loss = self.wrapper(
    + _model_pred, loss, more_loss = self.wrapper(

🧰 Tools
🪛 Ruff (0.15.14)
[warning] 344-344: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
Make the dataset embedding and energy bias inputs rather than buffers for compilation. This allows multi-task training to share one compiled model, which resolves the OOM and NCCL timeout issues. Since the `empty_cache` and `del` calls are removed, there are no GC complaints.