Skip to content

Fix: compile multitask#5457

Open
anyangml wants to merge 6 commits into
deepmodeling:masterfrom
anyangml:fix/compile-multitask
Open

Fix: compile multitask#5457
anyangml wants to merge 6 commits into
deepmodeling:masterfrom
anyangml:fix/compile-multitask

Conversation

@anyangml
Copy link
Copy Markdown
Collaborator

@anyangml anyangml commented May 26, 2026

make dataset embedding and energy bias as input not buffer for compile, this allows multitask training share compiled model thus resolve OOM and NCCL timeout issue. Since the empty_cache and del are removed, no GC complaints.

Summary by CodeRabbit

  • New Features

    • Multi-task training now caches and reuses compiled computation graphs for models that share fitting-net structure, speeding repeated-task training.
  • Bug Fixes

    • Checkpoint loading now ignores extraneous per-task buffer entries so restored models contain only original parameters.
    • Training reports now convert tensor-like loss/metric values to floats for correct aggregation and display.

Review Change Stack

Copilot AI review requested due to automatic review settings May 26, 2026 02:03
@dosubot dosubot Bot added the bug label May 26, 2026
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 26, 2026

📝 Walkthrough

Walkthrough

This PR introduces task-structure-aware torch.compile caching for multi-task models. It extracts per-task fitting-net buffers, computes a shared-structure identity key, promotes those buffers into explicit FX symbolic inputs for graph reuse, updates checkpoint loading to skip task-buffer remnants, and converts training loss values to floats for display aggregation.

Changes

Task-structure-aware torch.compile optimization

Layer / File(s) Summary
Task buffer extraction and structure key helpers
deepmd/pt_expt/train/training.py
_TASK_SPECIFIC_BUFFER_NAMES list, _get_task_buffers() to extract cloned per-task buffers, and _get_model_structure_key() to compute fitting-net identity for reuse detection.
_trace_and_compile buffer promotion and FX symbolic input handling
deepmd/pt_expt/train/training.py
Signature and docstring extended to accept per-task task_buffers. Computes task_buf_order and fitting submodule, temporarily patches fitting-net _buffers during tracing, includes promoted buffers as extra symbolic inputs to make_fx, and returns (compiled_module, task_buf_order).
_CompiledModel buffer storage and forward passing
deepmd/pt_expt/train/training.py
Constructor accepts and stores task_buf_order; forward() gathers current-task fitting-net and atomic-model tensors in order and passes them as variadic args into compiled forward_lower.
Structure-aware compilation caching and _compile_model orchestration
deepmd/pt_expt/train/training.py
Adds _compiled_by_structure cache keyed by structure identity, computes per-task task_bufs, reuses cached compiled graphs when structure matches, otherwise traces with task buffers and caches results; wraps task branches with _CompiledModel(..., task_buf_order).
Training loss value float conversion for display and aggregation
deepmd/pt_expt/train/training.py
Adds _to_float() helper and applies it to convert tensor-like loss/metric values to Python floats in single-task and multi-task training and validation aggregation.
Skip task buffer entries during checkpoint state_dict cleanup
deepmd/pt_expt/infer/deep_eval.py
DeepEval._load_pt filters out keys containing the ._task_ marker in addition to compiled-forward-lower keys to avoid loading task-buffer copy entries into the inference model.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5423: Both PRs modify deepmd/pt_expt/infer/deep_eval.py—specifically DeepEval._load_pt's state-dict cleanup/filtered-key loading logic—so the main PR's new "._task_" key omission is directly related to the retrieved PR's .pt checkpoint loading and compiled-key dropping.

Suggested reviewers

  • njzjz
  • iProzd
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title 'Fix: compile multitask' is vague and generic. It uses non-descriptive terms that don't clearly convey what specific issue is being fixed or what the multitask compilation change entails. Use a more descriptive title that clarifies the specific fix, such as 'Fix multitask model sharing by promoting per-task buffers to inputs for torch.compile' or 'Support compiled model reuse across tasks by treating per-task buffers as inputs'.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Fixed
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adjusts the pt_expt torch.compile path for multi-task training to reduce redundant compiled graphs (and associated memory/oom issues) by promoting per-task fitting-net buffers to explicit compiled-graph inputs and reusing compiled graphs across tasks when the model structure is shared.

Changes:

  • Promote task-specific fitting-net buffers (bias_atom_e, case_embd) into FX placeholders so one compiled graph can be reused with different per-task buffer values.
  • Add per-structure caching in the compile pipeline to avoid recompiling the same shared structure for each task.
  • Make training-time logging robust by converting tensor scalars to Python floats before formatting/aggregation.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
deepmd/pt_expt/train/training.py Adds task-buffer promotion + compiled-graph reuse caching for multi-task compile; adjusts logging scalar handling.
deepmd/pt_expt/infer/deep_eval.py Updates .pt checkpoint loading to ignore newly introduced _CompiledModel per-task buffer copies.
Comments suppressed due to low confidence (1)

deepmd/pt_expt/train/training.py:1072

  • There are existing pt_expt tests covering multi-task + torch.compile, but the new compiled-graph reuse path should be covered by a test that exercises a config where only some components are shared (e.g., fitting_net shared via shared_dict, descriptor not shared). That case would validate the structure-key logic and prevent accidental graph reuse across non-identical forward_lower graphs.
            descriptor = model.get_descriptor()
            if isinstance(descriptor, DescrptDPA1DP):
                n_attn = descriptor.get_numb_attn_layer()
                if n_attn > 0:
                    log.warning(
                        "Compiling DPA1/se_atten_v2 with %d attention "
                        "layer(s) (task=%s): the compiled forces/grads "
                        "are slightly hardware-sensitive (multi-thread "
                        "reduction order), and may not match the eager "
                        "path bit-for-bit.  Use 'enable_compile: false' "
                        "or 'attn_layer: 0' for fully reproducible runs.",

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread deepmd/pt_expt/train/training.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
deepmd/pt_expt/train/training.py (3)

319-319: ⚡ Quick win

Add explicit strict=True to zip call.

The zip() on line 319 iterates over task_buf_order and task_buf_vals, which are guaranteed to have the same length by construction (lines 290-293). Adding strict=True documents this invariant and provides a runtime assertion if the construction logic ever changes.

-            for name, val in zip(task_buf_order, task_buf_vals):
+            for name, val in zip(task_buf_order, task_buf_vals, strict=True):

As per coding guidelines, run ruff check . before committing to catch linting issues.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` at line 319, The zip over task_buf_order
and task_buf_vals in the loop using "for name, val in zip(task_buf_order,
task_buf_vals):" should assert the equal-length invariant by adding strict=True;
update that zip call to zip(task_buf_order, task_buf_vals, strict=True) so a
runtime error surfaces if lengths diverge, then run ruff check . before
committing to ensure linting passes.

314-334: ⚡ Quick win

Potential issue with buffer restoration logic.

Lines 320 and 334 save and restore buffer entries, but if originals[name] is None (buffer didn't exist), line 334 sets _fitting._buffers[name] = None instead of deleting the entry. This could leave None entries in the buffer registry that weren't present before patching.

Consider using conditional restoration:

for name, orig in originals.items():
    if orig is not None:
        _fitting._buffers[name] = orig
    else:
        _fitting._buffers.pop(name, None)

However, if the buffers are guaranteed to exist (since _get_task_buffers only extracts existing buffers), this may not be an issue in practice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 314 - 334, The restoration
currently writes None back into _fitting._buffers for entries that did not exist
before, so change the finally-block that iterates originals (the dict populated
from task_buf_order/task_buf_vals) to restore by reassigning when orig is not
None and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.

92-108: ⚡ Quick win

Clarify the child name check logic.

Line 103 compares child module names against _TASK_SPECIFIC_BUFFER_NAMES, which contains buffer names ("bias_atom_e", "case_embd"). Child modules from named_children() typically have names like "nets", "layers", etc., not buffer names. This check will almost always be True, making it effectively a no-op.

If the intent is to skip the first child when it's task-specific, the logic may need adjustment. Otherwise, consider removing the check or adding a comment explaining why it's safe to use the first child's id() directly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 92 - 108, The code in
_get_model_structure_key uses named_children() names compared against
_TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names like "bias_atom_e"), but
child module names come from named_children() and won't match buffer names, so
the filter is effectively a no-op; fix by computing the set of task-specific
buffer names from the fitting net (e.g., buffers = {n for n,_ in
fitting.named_buffers()}) and then skip any child whose name appears in that
buffer set (replace the current name check), or if the original intent was to
just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 319: The zip over task_buf_order and task_buf_vals in the loop using "for
name, val in zip(task_buf_order, task_buf_vals):" should assert the equal-length
invariant by adding strict=True; update that zip call to zip(task_buf_order,
task_buf_vals, strict=True) so a runtime error surfaces if lengths diverge, then
run ruff check . before committing to ensure linting passes.
- Around line 314-334: The restoration currently writes None back into
_fitting._buffers for entries that did not exist before, so change the
finally-block that iterates originals (the dict populated from
task_buf_order/task_buf_vals) to restore by reassigning when orig is not None
and otherwise remove the key (e.g., pop) from _fitting._buffers; locate the
dictionary named originals and the finally block that resets _fitting._buffers
and replace the unconditional assignment with a conditional restore/remove to
avoid leaving None entries after model.forward_lower returns.
- Around line 92-108: The code in _get_model_structure_key uses named_children()
names compared against _TASK_SPECIFIC_BUFFER_NAMES (which lists buffer names
like "bias_atom_e"), but child module names come from named_children() and won't
match buffer names, so the filter is effectively a no-op; fix by computing the
set of task-specific buffer names from the fitting net (e.g., buffers = {n for
n,_ in fitting.named_buffers()}) and then skip any child whose name appears in
that buffer set (replace the current name check), or if the original intent was
to just take the first non-task-specific child drop the faulty comparison and
simply return id of the first child from fitting.named_children(); update
_get_model_structure_key accordingly and keep reference to fitting,
named_children(), named_buffers(), and _TASK_SPECIFIC_BUFFER_NAMES to locate the
code to change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bac28a64-1300-4bb0-9d04-cf61734721ce

📥 Commits

Reviewing files that changed from the base of the PR and between f39a081 and 4cee0bf.

📒 Files selected for processing (2)
  • deepmd/pt_expt/infer/deep_eval.py
  • deepmd/pt_expt/train/training.py

@codecov
Copy link
Copy Markdown

codecov Bot commented May 26, 2026

Codecov Report

❌ Patch coverage is 85.84071% with 16 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.46%. Comparing base (f39a081) to head (9ce8d3e).

Files with missing lines Patch % Lines
deepmd/pt_expt/train/training.py 86.36% 15 Missing ⚠️
deepmd/pt_expt/infer/deep_eval.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5457   +/-   ##
=======================================
  Coverage   82.46%   82.46%           
=======================================
  Files         829      829           
  Lines       88763    88853   +90     
  Branches     4225     4226    +1     
=======================================
+ Hits        73197    73272   +75     
- Misses      14274    14289   +15     
  Partials     1292     1292           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 491-500: If self._task_buf_order is set but accessing buffers
fails, don't silently set task_buf_vals = (); instead, in the except
AttributeError branch for original_model.get_fitting_net()/getattr(...) raise a
clear RuntimeError (or ValueError) that mentions the missing fitting net or
buffer names and refers to self._task_buf_order so callers know why
compiled_forward_lower would fail; keep the existing else path that sets
task_buf_vals = () only when _task_buf_order is empty, and ensure the raised
message names the expected buffer attributes and the method
original_model.get_fitting_net to help locate the root cause.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: add25e51-694f-4a58-8e05-9d3e0d5909e9

📥 Commits

Reviewing files that changed from the base of the PR and between 4cee0bf and f3e29fe.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

Comment on lines +491 to +500
if self._task_buf_order:
try:
_fitting = self.original_model.get_fitting_net()
task_buf_vals: tuple = tuple(
getattr(_fitting, name) for name in self._task_buf_order
)
except AttributeError:
task_buf_vals = ()
else:
task_buf_vals = ()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Silent fallback may cause confusing downstream errors.

If _task_buf_order is non-empty (compiled graph expects buffer arguments) but the AttributeError catch triggers, the empty task_buf_vals passed to compiled_forward_lower will cause an argument count mismatch—a confusing error that hides the real issue.

Consider raising an informative error when buffers are expected but unavailable:

Proposed fix
         if self._task_buf_order:
             try:
                 _fitting = self.original_model.get_fitting_net()
                 task_buf_vals: tuple = tuple(
                     getattr(_fitting, name) for name in self._task_buf_order
                 )
             except AttributeError:
-                task_buf_vals = ()
+                raise RuntimeError(
+                    f"Compiled graph expects task buffers {self._task_buf_order} "
+                    "but fitting net or buffer attributes are unavailable."
+                )
         else:
             task_buf_vals = ()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 491 - 500, If
self._task_buf_order is set but accessing buffers fails, don't silently set
task_buf_vals = (); instead, in the except AttributeError branch for
original_model.get_fitting_net()/getattr(...) raise a clear RuntimeError (or
ValueError) that mentions the missing fitting net or buffer names and refers to
self._task_buf_order so callers know why compiled_forward_lower would fail; keep
the existing else path that sets task_buf_vals = () only when _task_buf_order is
empty, and ensure the raised message names the expected buffer attributes and
the method original_model.get_fitting_net to help locate the root cause.

Comment thread deepmd/pt_expt/train/training.py Dismissed
Comment thread deepmd/pt_expt/train/training.py Dismissed
@anyangml anyangml requested review from njzjz and wanghan-iapcm May 26, 2026 09:27
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 343-345: The zip call iterating over task_buf_order and
task_buf_vals should be made strict to satisfy Ruff B905: change
zip(task_buf_order, task_buf_vals) to zip(task_buf_order, task_buf_vals,
strict=True) in the block that checks name.startswith(_AM_PREFIX). Also remove
or rename the unused variable model_pred (found as model_pred in this file) to
_model_pred (or delete it) to resolve RUF059 so no unused binding remains.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 62caa6e7-1434-49b6-8f30-bfab6dd88904

📥 Commits

Reviewing files that changed from the base of the PR and between f3e29fe and 9ce8d3e.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

Comment on lines +343 to +345
if task_buf_order:
for name, val in zip(task_buf_order, task_buf_vals):
if name.startswith(_AM_PREFIX):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
ruff check .

Repository: deepmodeling/deepmd-kit

Length of output: 50381


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "## ruff check (target file only)"
ruff check deepmd/pt_expt/train/training.py

echo "## ruff format --check (target file only)"
ruff format --check deepmd/pt_expt/train/training.py

Repository: deepmodeling/deepmd-kit

Length of output: 1469


🏁 Script executed:

ruff check .

Repository: deepmodeling/deepmd-kit

Length of output: 50381


🏁 Script executed:

ruff check .
ruff format .

Repository: deepmodeling/deepmd-kit

Length of output: 50381


Fix Ruff B905 (zip strict) and RUF059 (unused model_pred) in deepmd/pt_expt/train/training.py.

  • B905 (line 344): zip(task_buf_order, task_buf_vals) needs explicit strict= to keep both sequences aligned.
  • RUF059 (line 1354): model_pred is never used—remove it or rename to _model_pred.
Suggested fix
-            for name, val in zip(task_buf_order, task_buf_vals):
+            for name, val in zip(task_buf_order, task_buf_vals, strict=True):
-            model_pred, loss, more_loss = self.wrapper(
+            _model_pred, loss, more_loss = self.wrapper(
🧰 Tools
🪛 Ruff (0.15.14)

[warning] 344-344: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/train/training.py` around lines 343 - 345, The zip call
iterating over task_buf_order and task_buf_vals should be made strict to satisfy
Ruff B905: change zip(task_buf_order, task_buf_vals) to zip(task_buf_order,
task_buf_vals, strict=True) in the block that checks
name.startswith(_AM_PREFIX). Also remove or rename the unused variable
model_pred (found as model_pred in this file) to _model_pred (or delete it) to
resolve RUF059 so no unused binding remains.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants