feat(tf): support training stat_file#5551
Conversation
Allow TensorFlow training to accept training/stat_file and reuse saved energy statistics in the same type-map directory layout as PyTorch. This ports the useful part of PR deepmodeling#4926 onto current master and keeps TensorFlow's 1-D fitting bias shape internally. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
for more information, see https://pre-commit.ci
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (4)
🚧 Files skipped from review as they are similar to previous changes (3)
📝 WalkthroughWalkthroughAdds ChangesTF stat_file Feature
Sequence Diagram(s)sequenceDiagram
participant Entrypoint as train.py entrypoint
participant DPTrainer as DPTrainer.build
participant EnerModel as EnerModel.data_stat
participant StatUtil as compute_output_stats
participant Filesystem as DPPath / Filesystem
Entrypoint->>Filesystem: extract training.stat_file
Entrypoint->>Filesystem: create empty HDF5 file or directory on chief rank
Entrypoint->>Entrypoint: wrap as DPPath(stat_file, "a")
Entrypoint->>DPTrainer: model.build(..., stat_file_path=DPPath)
DPTrainer->>EnerModel: model.data_stat(data, stat_file_path=DPPath)
EnerModel->>StatUtil: compute_output_stats(all_stat, ntypes, stat_file_path=DPPath, ...)
StatUtil->>Filesystem: _restore_from_file: check bias/std files
alt files exist
Filesystem-->>StatUtil: loaded bias_out, std_out
else files missing
StatUtil->>StatUtil: compute via compute_stats_from_redu
StatUtil->>Filesystem: _save_to_file: write bias/std files
end
StatUtil-->>EnerModel: bias_out, std_out
EnerModel->>EnerModel: fitting.bias_atom_e updated
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 7
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/tf/entrypoints/train.py`:
- Around line 247-252: The code at lines 247-252 fails for nested paths because
parent directories are not created before attempting to initialize the stat file
target. Before the conditional block that checks if stat_file_raw exists, ensure
all parent directories are created using
Path(stat_file_raw).parent.mkdir(parents=True, exist_ok=True). This will create
the full directory hierarchy needed before attempting either the h5py.File()
call for HDF5 files or the Path(stat_file_raw).mkdir() call for directory
creation. The parents=True parameter ensures intermediate directories are
created, and exist_ok=True prevents errors if directories already exist.
In `@deepmd/tf/model/linear.py`:
- Around line 98-103: The data_stat method passes the same stat_file_path to all
child models in the loop, which causes later submodels to incorrectly reuse
previously saved statistics from earlier submodels. Modify the data_stat method
to namespace the stat_file_path for each model by appending a unique identifier
(such as the model index or model name) to the original path before passing it
to model.data_stat(). This ensures each model in self.models writes to and reads
from its own separate stat file.
In `@deepmd/tf/model/pairwise_dprc.py`:
- Around line 325-327: The data_stat method in the pairwise_dprc.py file passes
the same stat_file_path to both qm_model.data_stat() and qmmm_model.data_stat()
calls, which can cause one model to load the other's statistics instead of
computing its own. Modify the stat_file_path parameter for each call to use
separate namespaces or identifiers that distinguish between QM and QMMM models,
ensuring each model loads or computes its own statistics independently.
In `@deepmd/tf/utils/stat.py`:
- Around line 16-19: The functions `_restore_from_file` and
`compute_output_stats` use a mutable default argument `keys=["energy"]` which
triggers Ruff B006 violations. Replace the mutable default list with `keys=None`
in both function signatures, then at the beginning of each function body, add a
check to initialize `keys` to `["energy"]` if it is `None` (e.g., using `if keys
is None: keys = ["energy"]`). Since the parameter is only read and never
mutated, this change is safe and will resolve the Ruff violations.
In `@source/tests/consistent/test_stat_file.py`:
- Around line 115-117: The subprocess.run() call in the test does not include a
timeout parameter, which can cause the test to hang indefinitely if the CLI
command stalls. Add a timeout parameter to the subprocess.run() call to ensure
the test completes within a reasonable timeframe and fails gracefully if the
subprocess takes too long to complete.
- Around line 212-220: The tearDown method is deleting files from the current
working directory without proper scoping, which can remove artifacts created by
other tests or processes. Create a test-specific temporary directory in the
setUp method that is unique to this test instance, configure the test to write
its outputs to this directory, and modify the tearDown method to only clean up
files within this test-specific directory. This ensures that the cleanup in
tearDown is isolated and does not interfere with shared cwd artifacts or other
concurrent tests.
In `@source/tests/tf/test_stat_file_integration.py`:
- Around line 99-103: The current assertion in the stat_path validation block is
conditional on stat_path.exists(), which means if the stat_file is never
created, the assertion is skipped entirely and the test passes without detecting
the failure. Remove the if stat_path.exists() condition and instead assert
unconditionally that the stat_path both exists and is a directory. This ensures
that when the training.stat_file creation is broken, the test will properly fail
rather than silently passing.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 23d2b3aa-2fce-44a0-8c17-e0bda3504fd0
📒 Files selected for processing (14)
deepmd/tf/entrypoints/train.pydeepmd/tf/model/dos.pydeepmd/tf/model/ener.pydeepmd/tf/model/frozen.pydeepmd/tf/model/linear.pydeepmd/tf/model/model.pydeepmd/tf/model/pairwise_dprc.pydeepmd/tf/model/tensor.pydeepmd/tf/train/trainer.pydeepmd/tf/utils/stat.pydeepmd/utils/argcheck.pysource/tests/consistent/test_stat_file.pysource/tests/tf/test_stat_file.pysource/tests/tf/test_stat_file_integration.py
| def data_stat( | ||
| self, data: DeepmdDataSystem, stat_file_path: DPPath | None = None | ||
| ) -> None: | ||
| for model in self.models: | ||
| model.data_stat(data) | ||
| model.data_stat(data, stat_file_path=stat_file_path) | ||
|
|
There was a problem hiding this comment.
Namespace stat_file_path per submodel to avoid stat collisions.
Line 102 passes the same path to every child model. Because EnerModel restores from existing files before recomputing (see deepmd/tf/model/ener.py, Line 187), later submodels can incorrectly reuse earlier submodels’ saved stats.
Suggested fix
def data_stat(
self, data: DeepmdDataSystem, stat_file_path: DPPath | None = None
) -> None:
- for model in self.models:
- model.data_stat(data, stat_file_path=stat_file_path)
+ for ii, model in enumerate(self.models):
+ model_stat_path = (
+ None if stat_file_path is None else stat_file_path / f"model{ii}"
+ )
+ model.data_stat(data, stat_file_path=model_stat_path)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/tf/model/linear.py` around lines 98 - 103, The data_stat method passes
the same stat_file_path to all child models in the loop, which causes later
submodels to incorrectly reuse previously saved statistics from earlier
submodels. Modify the data_stat method to namespace the stat_file_path for each
model by appending a unique identifier (such as the model index or model name)
to the original path before passing it to model.data_stat(). This ensures each
model in self.models writes to and reads from its own separate stat file.
| def data_stat(self, data: dict, stat_file_path: DPPath | None = None) -> None: | ||
| self.qm_model.data_stat(data, stat_file_path=stat_file_path) | ||
| self.qmmm_model.data_stat(data, stat_file_path=stat_file_path) |
There was a problem hiding this comment.
Use separate stat-file namespaces for QM and QMMM models.
Line 326 and Line 327 forward the same stat_file_path to both branches. With restore-first behavior in EnerModel (Line 187 in deepmd/tf/model/ener.py), one branch can load the other branch’s stats instead of computing its own.
Suggested fix
def data_stat(self, data: dict, stat_file_path: DPPath | None = None) -> None:
- self.qm_model.data_stat(data, stat_file_path=stat_file_path)
- self.qmmm_model.data_stat(data, stat_file_path=stat_file_path)
+ qm_stat_path = None if stat_file_path is None else stat_file_path / "qm"
+ qmmm_stat_path = None if stat_file_path is None else stat_file_path / "qmmm"
+ self.qm_model.data_stat(data, stat_file_path=qm_stat_path)
+ self.qmmm_model.data_stat(data, stat_file_path=qmmm_stat_path)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/tf/model/pairwise_dprc.py` around lines 325 - 327, The data_stat
method in the pairwise_dprc.py file passes the same stat_file_path to both
qm_model.data_stat() and qmmm_model.data_stat() calls, which can cause one model
to load the other's statistics instead of computing its own. Modify the
stat_file_path parameter for each call to use separate namespaces or identifiers
that distinguish between QM and QMMM models, ensuring each model loads or
computes its own statistics independently.
| result = subprocess.run( | ||
| cmd, cwd=temp_dir, capture_output=True, text=True, env=env | ||
| ) |
There was a problem hiding this comment.
Add a timeout to the training subprocess.
Without a timeout, this test can hang the whole suite indefinitely on stalled CLI runs.
Suggested fix
- result = subprocess.run(
- cmd, cwd=temp_dir, capture_output=True, text=True, env=env
- )
+ try:
+ result = subprocess.run(
+ cmd,
+ cwd=temp_dir,
+ capture_output=True,
+ text=True,
+ env=env,
+ timeout=300,
+ )
+ except subprocess.TimeoutExpired as exc:
+ self.fail(f"Training timed out for {backend} backend: {exc}")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| result = subprocess.run( | |
| cmd, cwd=temp_dir, capture_output=True, text=True, env=env | |
| ) | |
| try: | |
| result = subprocess.run( | |
| cmd, | |
| cwd=temp_dir, | |
| capture_output=True, | |
| text=True, | |
| env=env, | |
| timeout=300, | |
| ) | |
| except subprocess.TimeoutExpired as exc: | |
| self.fail(f"Training timed out for {backend} backend: {exc}") |
🧰 Tools
🪛 Ruff (0.15.17)
[error] 115-115: subprocess call: check for execution of untrusted input
(S603)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@source/tests/consistent/test_stat_file.py` around lines 115 - 117, The
subprocess.run() call in the test does not include a timeout parameter, which
can cause the test to hang indefinitely if the CLI command stalls. Add a timeout
parameter to the subprocess.run() call to ensure the test completes within a
reasonable timeframe and fails gracefully if the subprocess takes too long to
complete.
| def tearDown(self) -> None: | ||
| """Clean up any temporary files.""" | ||
| # Clean up any leftover files | ||
| for path in ["checkpoint", "lcurve.out", "model.ckpt"]: | ||
| if os.path.exists(path): | ||
| if os.path.isdir(path): | ||
| shutil.rmtree(path) | ||
| else: | ||
| os.remove(path) |
There was a problem hiding this comment.
Avoid deleting shared cwd artifacts in tearDown.
This cleanup is not scoped to files created by this test and can remove unrelated outputs from other tests/processes.
Suggested fix
- def tearDown(self) -> None:
- """Clean up any temporary files."""
- # Clean up any leftover files
- for path in ["checkpoint", "lcurve.out", "model.ckpt"]:
- if os.path.exists(path):
- if os.path.isdir(path):
- shutil.rmtree(path)
- else:
- os.remove(path)
+ # No teardown cleanup needed: all artifacts are written under TemporaryDirectory.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@source/tests/consistent/test_stat_file.py` around lines 212 - 220, The
tearDown method is deleting files from the current working directory without
proper scoping, which can remove artifacts created by other tests or processes.
Create a test-specific temporary directory in the setUp method that is unique to
this test instance, configure the test to write its outputs to this directory,
and modify the tearDown method to only clean up files within this test-specific
directory. This ensures that the cleanup in tearDown is isolated and does not
interfere with shared cwd artifacts or other concurrent tests.
| stat_path = Path(stat_file_path) | ||
| if stat_path.exists(): | ||
| self.assertTrue( | ||
| stat_path.is_dir(), "Stat file path should be a directory" | ||
| ) |
There was a problem hiding this comment.
Assert stat_file creation unconditionally.
This assertion currently passes even when nothing is created, so the test can miss a broken training.stat_file path.
Suggested fix
stat_path = Path(stat_file_path)
- if stat_path.exists():
- self.assertTrue(
- stat_path.is_dir(), "Stat file path should be a directory"
- )
+ self.assertTrue(
+ stat_path.exists(), "Stat file path should be created"
+ )
+ self.assertTrue(
+ stat_path.is_dir(), "Stat file path should be a directory"
+ )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@source/tests/tf/test_stat_file_integration.py` around lines 99 - 103, The
current assertion in the stat_path validation block is conditional on
stat_path.exists(), which means if the stat_file is never created, the assertion
is skipped entirely and the test passes without detecting the failure. Remove
the if stat_path.exists() condition and instead assert unconditionally that the
stat_path both exists and is a directory. This ensures that when the
training.stat_file creation is broken, the test will properly fail rather than
silently passing.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5551 +/- ##
==========================================
- Coverage 82.23% 82.15% -0.08%
==========================================
Files 894 897 +3
Lines 102002 102777 +775
Branches 4276 4344 +68
==========================================
+ Hits 83877 84435 +558
- Misses 16823 17005 +182
- Partials 1302 1337 +35 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Persist observed_type for TensorFlow stat files and normalize the stat-file test input before calling the lower-level training helper. Also broadcast the global output std to match the shared statistic logic. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Problem
training/stat_fileis accepted by the shared input schema but was effectively only wired for non-TF backends.master.Change
stat_fileplumbing onto currentmaster: create/openDPPath, pass it throughDPTrainer.build()andModel.data_stat(), and save/load energy statistics under the PyTorch-compatible type-map subdirectory.bias_atom_eas the historical 1-D vector while storing stat files in the cross-backend(ntypes, 1)format.Notes
stat_filein TF #4017.uvx ruff checkon touched files,uvx ruff format --checkon touched files, andpython3 -m py_compileon touched files passed.pytestand the unbuilt checkout lacks compileddeepmd.lib.Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Summary by CodeRabbit
New Features
training.stat_fileend-to-end: TensorFlow now prepares the target location, passes it through the training/stat APIs, and uses it to compute/restore energy statistics (including observed type handling).Bug Fixes
Tests
training.stat_file.