fix(pt_expt): Fix compiled ckpt load key mismatch#5468
Conversation
for more information, see https://pre-commit.ci
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
📝 WalkthroughWalkthroughThe PR fixes compiled model checkpoint serialization by adding attribute delegation to ChangesCompiled Model Checkpoint Serialization
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Pull request overview
This PR fixes restart failures when training with torch.compile enabled in the pt_expt backend by ensuring compiled-model wrappers don’t leak wrapper-prefixed keys into saved checkpoints, and by making the compiled wrapper transparently expose attributes/methods of the original model.
Changes:
- Add
_CompiledModel.__getattr__to delegate unknown attribute/method lookups to the wrappedoriginal_model. - Update checkpoint saving to temporarily unwrap
_CompiledModelinstances so savedstate_dictkeys match what a freshModelWrapperexpects. - Add regression/unit tests covering attribute delegation and restart-from-compiled-checkpoint behavior.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
deepmd/pt_expt/train/training.py |
Adds attribute delegation for _CompiledModel and unwraps compiled models during checkpoint serialization to prevent key mismatches on restart. |
source/tests/pt_expt/test_training.py |
Adds unit/regression tests for _CompiledModel delegation and compiled-checkpoint restart key compatibility. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| compiled_backup: dict[str, _CompiledModel] = {} | ||
| for task_key in list(wrapper.model.keys()): | ||
| m = wrapper.model[task_key] | ||
| if isinstance(m, _CompiledModel): | ||
| compiled_backup[task_key] = m |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5468 +/- ##
=======================================
Coverage 82.46% 82.46%
=======================================
Files 829 829
Lines 88763 88777 +14
Branches 4225 4225
=======================================
+ Hits 73197 73212 +15
- Misses 14274 14275 +1
+ Partials 1292 1290 -2 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
This resolves restart failure.
Summary by CodeRabbit
Bug Fixes
Tests