fix(pt_expt): Fix compiled ckpt load key mismatch #5468

Open: anyangml wants to merge 2 commits into deepmodeling:master from anyangml:fix/pt-expt-compile-restart

@anyangml (Collaborator) commented May 27, 2026

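This PR resolves a restart failure in the pt_expt backend: checkpoints saved while the model was compiled had state_dict keys that did not match what a fresh, uncompiled wrapper expects on load.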

Summary by CodeRabbit

  • Bug Fixes

    • Fixed checkpoint compatibility when using compiled models, ensuring proper state dictionary alignment during training resumption.
    • Improved compiled model wrapper attribute delegation to preserve access to all model methods and properties.
  • Tests

    • Added comprehensive test coverage for compiled model functionality and checkpoint restart scenarios.

Copilot AI review requested due to automatic review settings May 27, 2026 09:15
@dosubot dosubot Bot added the bug label May 27, 2026
@coderabbitai (Bot, Contributor) commented May 27, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 5791face-c8e2-4b1b-b1ba-17e0faacc335

📥 Commits

Reviewing files that changed from the base of the PR and between 4e64f8b and 22e79fa.

📒 Files selected for processing (2)
  • deepmd/pt_expt/train/training.py
  • source/tests/pt_expt/test_training.py

📝 Walkthrough

The PR fixes compiled model checkpoint serialization by adding attribute delegation to _CompiledModel and updating Trainer.save_checkpoint() to temporarily unwrap compiled models before saving. Unit and integration tests verify delegation behavior and successful restart cycles from compiled checkpoints.

Changes

Compiled Model Checkpoint Serialization

| Layer / File(s) | Summary |
| --- | --- |
| _CompiledModel attribute delegation (deepmd/pt_expt/train/training.py, source/tests/pt_expt/test_training.py) | _CompiledModel.__getattr__ delegates unknown attribute access to original_model. Unit tests verify method and property delegation, correct error handling for missing attributes, and that wrapper-owned submodules are not delegated. |
| Checkpoint serialization and restart (deepmd/pt_expt/train/training.py, source/tests/pt_expt/test_training.py) | Trainer.save_checkpoint() swaps per-task _CompiledModel instances with their original_model before serialization, then restores them. The integration test trains with compile enabled, verifies that the checkpoint state-dict format matches uncompiled expectations, loads the weights into a fresh uncompiled wrapper, and restarts from the compiled checkpoint with step restoration and successful training. |
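
The delegation described in the first row can be illustrated with a minimal plain-Python sketch. It is not the PR's code: `OriginalModel`, `CompiledWrapper`, `rcut`, `get_type_map`, and `compiled_forward` are hypothetical stand-ins, and no torch is involved. In a real `torch.nn.Module`, `__getattr__` is already overridden to look up parameters, buffers, and submodules, so a wrapper would presumably try `super().__getattr__(name)` first and delegate only when that fails.

```python
class OriginalModel:
    """Hypothetical stand-in for the wrapped (uncompiled) model."""

    def __init__(self):
        self.rcut = 6.0

    def get_type_map(self):
        return ["O", "H"]


class CompiledWrapper:
    """Hypothetical stand-in for a compiled-model wrapper."""

    def __init__(self, original_model):
        self.original_model = original_model
        # Wrapper-owned attribute: found by normal lookup, never delegated.
        self.compiled_forward = lambda x: x

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal attribute lookup fails.
        if name == "original_model":
            # Avoid infinite recursion if original_model is not set yet.
            raise AttributeError(name)
        # Delegate; a missing attribute still raises AttributeError.
        return getattr(self.original_model, name)


wrapped = CompiledWrapper(OriginalModel())
```

With this shape, `wrapped.get_type_map()` and `wrapped.rcut` resolve on the wrapped object, while `wrapped.compiled_forward` stays on the wrapper.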

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 47.37%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title directly addresses the main issue being fixed: compiled checkpoint loading key mismatch during restart, which is the core purpose of the changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Copilot AI (Contributor) left a comment

Pull request overview

This PR fixes restart failures when training with torch.compile enabled in the pt_expt backend by ensuring compiled-model wrappers don’t leak wrapper-prefixed keys into saved checkpoints, and by making the compiled wrapper transparently expose attributes/methods of the original model.

Changes:

  • Add _CompiledModel.__getattr__ to delegate unknown attribute/method lookups to the wrapped original_model.
  • Update checkpoint saving to temporarily unwrap _CompiledModel instances so saved state_dict keys match what a fresh ModelWrapper expects.
  • Add regression/unit tests covering attribute delegation and restart-from-compiled-checkpoint behavior.
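
To make the key mismatch concrete, here is a small sketch of how wrapper-prefixed keys would fail to line up with the keys a fresh wrapper expects. The key names and the `original_model.` prefix are assumptions for illustration; the actual checkpoint keys are not shown on this page. `diff_keys` is a hypothetical helper, not a deepmd or PyTorch function.

```python
def diff_keys(saved_keys, expected_keys):
    """Return (missing, unexpected) keys, in the spirit of a strict state-dict load."""
    saved, expected = set(saved_keys), set(expected_keys)
    return sorted(expected - saved), sorted(saved - expected)


# Hypothetical keys that a fresh, uncompiled wrapper would expect.
expected = ["model.Default.fitting.weight", "model.Default.fitting.bias"]

# Assumed effect of saving through the compiled wrapper: one extra
# "original_model." level in every key.
leaked = [
    k.replace("model.Default.", "model.Default.original_model.") for k in expected
]

missing, unexpected = diff_keys(leaked, expected)
```

Under these assumptions every expected key is reported missing and every saved key is unexpected, which is the kind of failure a strict load would raise on restart.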

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| deepmd/pt_expt/train/training.py | Adds attribute delegation for _CompiledModel and unwraps compiled models during checkpoint serialization to prevent key mismatches on restart. |
| source/tests/pt_expt/test_training.py | Adds unit/regression tests for _CompiledModel delegation and compiled-checkpoint restart key compatibility. |

Comment on lines +1140 to +1144
    compiled_backup: dict[str, _CompiledModel] = {}
    for task_key in list(wrapper.model.keys()):
        m = wrapper.model[task_key]
        if isinstance(m, _CompiledModel):
            compiled_backup[task_key] = m
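
The excerpt above shows only the backup half of the swap. One way to pair it with a guaranteed restore is a context manager with `try`/`finally`, sketched below in plain Python. Whether the PR restores the wrappers this way is not visible on this page; `CompiledWrapper` and `unwrapped` are hypothetical names, and a plain dict stands in for `wrapper.model`.

```python
from contextlib import contextmanager


class CompiledWrapper:
    """Hypothetical stand-in for _CompiledModel."""

    def __init__(self, original_model):
        self.original_model = original_model


@contextmanager
def unwrapped(models):
    """Temporarily replace wrappers in `models` with their original models."""
    backup = {}
    for key in list(models.keys()):
        m = models[key]
        if isinstance(m, CompiledWrapper):
            backup[key] = m
            models[key] = m.original_model
    try:
        yield models
    finally:
        # Put the wrappers back even if serialization raises.
        models.update(backup)
```

A caller would serialize inside `with unwrapped(models): ...`, so the saved keys come from the original models and training continues with the compiled wrappers afterwards.
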
@anyangml anyangml requested a review from wanghan-iapcm May 27, 2026 09:23
@codecov (Bot) commented May 27, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.46%. Comparing base (4e64f8b) to head (22e79fa).

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5468   +/-   ##
=======================================
  Coverage   82.46%   82.46%           
=======================================
  Files         829      829           
  Lines       88763    88777   +14     
  Branches     4225     4225           
=======================================
+ Hits        73197    73212   +15     
- Misses      14274    14275    +1     
+ Partials     1292     1290    -2     
