Skip to content

feat(tf): support training stat_file#5551

Open
njzjz-bot wants to merge 3 commits into
deepmodeling:masterfrom
njzjz:openclaw/fix-tf-stat-file-4017
Open

feat(tf): support training stat_file#5551
njzjz-bot wants to merge 3 commits into
deepmodeling:masterfrom
njzjz:openclaw/fix-tf-stat-file-4017

Conversation

@njzjz-bot

@njzjz-bot njzjz-bot commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Problem

Change

  • Port the TF stat_file plumbing onto current master: create/open DPPath, pass it through DPTrainer.build() and Model.data_stat(), and save/load energy statistics under the PyTorch-compatible type-map subdirectory.
  • Keep TensorFlow's internal bias_atom_e as the historical 1-D vector while storing stat files in the cross-backend (ntypes, 1) format.
  • Add TF and TF/PT consistency coverage derived from feat(tf): add support for stat_file parameter #4926.

Notes

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)

Summary by CodeRabbit

  • New Features

    • Enabled TensorFlow training to honor training.stat_file end-to-end: TensorFlow now prepares the target location, passes it through the training/stat APIs, and uses it to compute/restore energy statistics (including observed type handling).
  • Bug Fixes

    • Improved stat-file loading behavior by restoring requested bias/std values when available, and falling back to computing them from training data otherwise.
  • Tests

    • Added TensorFlow unit and integration tests for training.stat_file.
    • Added cross-backend consistency tests to compare TensorFlow vs PyTorch stat outputs.

Allow TensorFlow training to accept training/stat_file and reuse saved energy statistics in the same type-map directory layout as PyTorch. This ports the useful part of PR deepmodeling#4926 onto current master and keeps TensorFlow's 1-D fitting bias shape internally.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
@dosubot dosubot Bot added the new feature label Jun 17, 2026
@coderabbitai

coderabbitai Bot commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 09a17d8e-751b-47bc-94db-32b839edcb55

📥 Commits

Reviewing files that changed from the base of the PR and between 650ed47 and 3eb2243.

📒 Files selected for processing (4)
  • deepmd/tf/entrypoints/train.py
  • deepmd/tf/model/ener.py
  • deepmd/tf/utils/stat.py
  • source/tests/tf/test_stat_file.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/tf/entrypoints/train.py
  • source/tests/tf/test_stat_file.py
  • deepmd/tf/model/ener.py

📝 Walkthrough

Walkthrough

Adds stat_file support to the TensorFlow backend. A new compute_output_stats function in deepmd/tf/utils/stat.py handles filesystem-first restore and compute/save of per-key atomic bias and std arrays. The stat_file_path parameter is propagated through the entire model hierarchy's data_stat signatures and wired through DPTrainer.build and the training entrypoint. The training.stat_file argcheck documentation is updated to remove the PyTorch-only marker.

Changes

TF stat_file Feature

Layer / File(s) Summary
compute_output_stats utility
deepmd/tf/utils/stat.py
Adds _restore_from_file, _save_to_file, _post_process_stat, collect_observed_types_from_stat, save_observed_types_to_file, and compute_output_stats to compute, persist, and restore per-key atomic bias and std arrays compatible with the PyTorch stat-file format.
Model hierarchy data_stat signature propagation
deepmd/tf/model/model.py, deepmd/tf/model/dos.py, deepmd/tf/model/tensor.py, deepmd/tf/model/frozen.py, deepmd/tf/model/linear.py, deepmd/tf/model/pairwise_dprc.py
Adds optional stat_file_path: DPPath | None = None to data_stat across the model base class and all concrete model subclasses; LinearModel and PairwiseDPRc forward the parameter to their sub-model calls.
EnerModel integration with compute_output_stats
deepmd/tf/model/ener.py
Extends EnerModel.data_stat and _compute_output_stat to accept stat_file_path; when provided, namespaces path by type_map, saves observed types, constructs assigned_bias from fitting atomic energies, calls compute_output_stats with keys=["energy"] and optional rcond, and updates fitting.bias_atom_e from the returned bias.
Trainer and entrypoint wiring
deepmd/tf/train/trainer.py, deepmd/tf/entrypoints/train.py
DPTrainer.build gains stat_file_path parameter and forwards it to model.data_stat; the training entrypoint reads training.stat_file, creates an empty HDF5 file (for .h5/.hdf5) or directory on the chief rank, wraps it as DPPath(..., "a"), and passes it into model.build.
argcheck doc fix and tests
deepmd/utils/argcheck.py, source/tests/tf/test_stat_file.py, source/tests/tf/test_stat_file_integration.py, source/tests/consistent/test_stat_file.py
Removes the PyTorch-only documentation prefix from training.stat_file in argcheck; adds TF unit and integration tests verifying stat_file directory creation and config flow, plus a cross-backend consistency test comparing TF and PT stat file outputs by directory structure and .npy file values.

Sequence Diagram(s)

sequenceDiagram
    participant Entrypoint as train.py entrypoint
    participant DPTrainer as DPTrainer.build
    participant EnerModel as EnerModel.data_stat
    participant StatUtil as compute_output_stats
    participant Filesystem as DPPath / Filesystem

    Entrypoint->>Filesystem: extract training.stat_file
    Entrypoint->>Filesystem: create empty HDF5 file or directory on chief rank
    Entrypoint->>Entrypoint: wrap as DPPath(stat_file, "a")
    Entrypoint->>DPTrainer: model.build(..., stat_file_path=DPPath)
    DPTrainer->>EnerModel: model.data_stat(data, stat_file_path=DPPath)
    EnerModel->>StatUtil: compute_output_stats(all_stat, ntypes, stat_file_path=DPPath, ...)
    StatUtil->>Filesystem: _restore_from_file: check bias/std files
    alt files exist
        Filesystem-->>StatUtil: loaded bias_out, std_out
    else files missing
        StatUtil->>StatUtil: compute via compute_stats_from_redu
        StatUtil->>Filesystem: _save_to_file: write bias/std files
    end
    StatUtil-->>EnerModel: bias_out, std_out
    EnerModel->>EnerModel: fitting.bias_atom_e updated
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5269: Both PRs implement/propagate observed-type persistence via stat_file: the main PR adds TF-side save_observed_types_to_file/stat_file_path plumbing into compute_output_stats, while the retrieved PR adds dpmodel stat-file helpers to restore/save observed_type and uses them in PT stat computation.

  • deepmodeling/deepmd-kit#5270: Both PRs implement support for an optional training.stat_file by creating the target stat location (HDF5 vs directory) in the backend training entrypoint and wiring it into the training/stat computation pipeline (TF: deepmd/tf/entrypoints/train.pymodel.build(... stat_file_path=...); PT: deepmd/pt_expt/entrypoints/main.pyget_trainer prepares the same stat path).

Suggested reviewers

  • iProzd
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 55.56% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat(tf): support training stat_file' directly and clearly describes the main change: adding stat_file support to the TensorFlow backend.
Linked Issues check ✅ Passed The PR fully implements the feature request #4017 to support training.stat_file in TensorFlow backend, including stat file creation/management, cross-backend format compatibility, and comprehensive test coverage.
Out of Scope Changes check ✅ Passed All changes are focused on implementing stat_file support for TensorFlow backend as specified in the objectives, with no unrelated modifications detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/tf/entrypoints/train.py`:
- Around line 247-252: The code at lines 247-252 fails for nested paths because
parent directories are not created before attempting to initialize the stat file
target. Before the conditional block that checks if stat_file_raw exists, ensure
all parent directories are created using
Path(stat_file_raw).parent.mkdir(parents=True, exist_ok=True). This will create
the full directory hierarchy needed before attempting either the h5py.File()
call for HDF5 files or the Path(stat_file_raw).mkdir() call for directory
creation. The parents=True parameter ensures intermediate directories are
created, and exist_ok=True prevents errors if directories already exist.

In `@deepmd/tf/model/linear.py`:
- Around line 98-103: The data_stat method passes the same stat_file_path to all
child models in the loop, which causes later submodels to incorrectly reuse
previously saved statistics from earlier submodels. Modify the data_stat method
to namespace the stat_file_path for each model by appending a unique identifier
(such as the model index or model name) to the original path before passing it
to model.data_stat(). This ensures each model in self.models writes to and reads
from its own separate stat file.

In `@deepmd/tf/model/pairwise_dprc.py`:
- Around line 325-327: The data_stat method in the pairwise_dprc.py file passes
the same stat_file_path to both qm_model.data_stat() and qmmm_model.data_stat()
calls, which can cause one model to load the other's statistics instead of
computing its own. Modify the stat_file_path parameter for each call to use
separate namespaces or identifiers that distinguish between QM and QMMM models,
ensuring each model loads or computes its own statistics independently.

In `@deepmd/tf/utils/stat.py`:
- Around line 16-19: The functions `_restore_from_file` and
`compute_output_stats` use a mutable default argument `keys=["energy"]` which
triggers Ruff B006 violations. Replace the mutable default list with `keys=None`
in both function signatures, then at the beginning of each function body, add a
check to initialize `keys` to `["energy"]` if it is `None` (e.g., using `if keys
is None: keys = ["energy"]`). Since the parameter is only read and never
mutated, this change is safe and will resolve the Ruff violations.

In `@source/tests/consistent/test_stat_file.py`:
- Around line 115-117: The subprocess.run() call in the test does not include a
timeout parameter, which can cause the test to hang indefinitely if the CLI
command stalls. Add a timeout parameter to the subprocess.run() call to ensure
the test completes within a reasonable timeframe and fails gracefully if the
subprocess takes too long to complete.
- Around line 212-220: The tearDown method is deleting files from the current
working directory without proper scoping, which can remove artifacts created by
other tests or processes. Create a test-specific temporary directory in the
setUp method that is unique to this test instance, configure the test to write
its outputs to this directory, and modify the tearDown method to only clean up
files within this test-specific directory. This ensures that the cleanup in
tearDown is isolated and does not interfere with shared cwd artifacts or other
concurrent tests.

In `@source/tests/tf/test_stat_file_integration.py`:
- Around line 99-103: The current assertion in the stat_path validation block is
conditional on stat_path.exists(), which means if the stat_file is never
created, the assertion is skipped entirely and the test passes without detecting
the failure. Remove the if stat_path.exists() condition and instead assert
unconditionally that the stat_path both exists and is a directory. This ensures
that when the training.stat_file creation is broken, the test will properly fail
rather than silently passing.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 23d2b3aa-2fce-44a0-8c17-e0bda3504fd0

📥 Commits

Reviewing files that changed from the base of the PR and between a4c5592 and 650ed47.

📒 Files selected for processing (14)
  • deepmd/tf/entrypoints/train.py
  • deepmd/tf/model/dos.py
  • deepmd/tf/model/ener.py
  • deepmd/tf/model/frozen.py
  • deepmd/tf/model/linear.py
  • deepmd/tf/model/model.py
  • deepmd/tf/model/pairwise_dprc.py
  • deepmd/tf/model/tensor.py
  • deepmd/tf/train/trainer.py
  • deepmd/tf/utils/stat.py
  • deepmd/utils/argcheck.py
  • source/tests/consistent/test_stat_file.py
  • source/tests/tf/test_stat_file.py
  • source/tests/tf/test_stat_file_integration.py

Comment thread deepmd/tf/entrypoints/train.py Outdated
Comment thread deepmd/tf/model/linear.py
Comment on lines +98 to 103
def data_stat(
self, data: DeepmdDataSystem, stat_file_path: DPPath | None = None
) -> None:
for model in self.models:
model.data_stat(data)
model.data_stat(data, stat_file_path=stat_file_path)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Namespace stat_file_path per submodel to avoid stat collisions.

Line 102 passes the same path to every child model. Because EnerModel restores from existing files before recomputing (see deepmd/tf/model/ener.py, Line 187), later submodels can incorrectly reuse earlier submodels’ saved stats.

Suggested fix
 def data_stat(
     self, data: DeepmdDataSystem, stat_file_path: DPPath | None = None
 ) -> None:
-    for model in self.models:
-        model.data_stat(data, stat_file_path=stat_file_path)
+    for ii, model in enumerate(self.models):
+        model_stat_path = (
+            None if stat_file_path is None else stat_file_path / f"model{ii}"
+        )
+        model.data_stat(data, stat_file_path=model_stat_path)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/tf/model/linear.py` around lines 98 - 103, The data_stat method passes
the same stat_file_path to all child models in the loop, which causes later
submodels to incorrectly reuse previously saved statistics from earlier
submodels. Modify the data_stat method to namespace the stat_file_path for each
model by appending a unique identifier (such as the model index or model name)
to the original path before passing it to model.data_stat(). This ensures each
model in self.models writes to and reads from its own separate stat file.

Comment on lines +325 to +327
def data_stat(self, data: dict, stat_file_path: DPPath | None = None) -> None:
self.qm_model.data_stat(data, stat_file_path=stat_file_path)
self.qmmm_model.data_stat(data, stat_file_path=stat_file_path)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Use separate stat-file namespaces for QM and QMMM models.

Line 326 and Line 327 forward the same stat_file_path to both branches. With restore-first behavior in EnerModel (Line 187 in deepmd/tf/model/ener.py), one branch can load the other branch’s stats instead of computing its own.

Suggested fix
 def data_stat(self, data: dict, stat_file_path: DPPath | None = None) -> None:
-    self.qm_model.data_stat(data, stat_file_path=stat_file_path)
-    self.qmmm_model.data_stat(data, stat_file_path=stat_file_path)
+    qm_stat_path = None if stat_file_path is None else stat_file_path / "qm"
+    qmmm_stat_path = None if stat_file_path is None else stat_file_path / "qmmm"
+    self.qm_model.data_stat(data, stat_file_path=qm_stat_path)
+    self.qmmm_model.data_stat(data, stat_file_path=qmmm_stat_path)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/tf/model/pairwise_dprc.py` around lines 325 - 327, The data_stat
method in the pairwise_dprc.py file passes the same stat_file_path to both
qm_model.data_stat() and qmmm_model.data_stat() calls, which can cause one model
to load the other's statistics instead of computing its own. Modify the
stat_file_path parameter for each call to use separate namespaces or identifiers
that distinguish between QM and QMMM models, ensuring each model loads or
computes its own statistics independently.

Comment thread deepmd/tf/utils/stat.py
Comment on lines +115 to +117
result = subprocess.run(
cmd, cwd=temp_dir, capture_output=True, text=True, env=env
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add a timeout to the training subprocess.

Without a timeout, this test can hang the whole suite indefinitely on stalled CLI runs.

Suggested fix
-        result = subprocess.run(
-            cmd, cwd=temp_dir, capture_output=True, text=True, env=env
-        )
+        try:
+            result = subprocess.run(
+                cmd,
+                cwd=temp_dir,
+                capture_output=True,
+                text=True,
+                env=env,
+                timeout=300,
+            )
+        except subprocess.TimeoutExpired as exc:
+            self.fail(f"Training timed out for {backend} backend: {exc}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
result = subprocess.run(
cmd, cwd=temp_dir, capture_output=True, text=True, env=env
)
try:
result = subprocess.run(
cmd,
cwd=temp_dir,
capture_output=True,
text=True,
env=env,
timeout=300,
)
except subprocess.TimeoutExpired as exc:
self.fail(f"Training timed out for {backend} backend: {exc}")
🧰 Tools
🪛 Ruff (0.15.17)

[error] 115-115: subprocess call: check for execution of untrusted input

(S603)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/consistent/test_stat_file.py` around lines 115 - 117, The
subprocess.run() call in the test does not include a timeout parameter, which
can cause the test to hang indefinitely if the CLI command stalls. Add a timeout
parameter to the subprocess.run() call to ensure the test completes within a
reasonable timeframe and fails gracefully if the subprocess takes too long to
complete.

Comment on lines +212 to +220
def tearDown(self) -> None:
"""Clean up any temporary files."""
# Clean up any leftover files
for path in ["checkpoint", "lcurve.out", "model.ckpt"]:
if os.path.exists(path):
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Avoid deleting shared cwd artifacts in tearDown.

This cleanup is not scoped to files created by this test and can remove unrelated outputs from other tests/processes.

Suggested fix
-    def tearDown(self) -> None:
-        """Clean up any temporary files."""
-        # Clean up any leftover files
-        for path in ["checkpoint", "lcurve.out", "model.ckpt"]:
-            if os.path.exists(path):
-                if os.path.isdir(path):
-                    shutil.rmtree(path)
-                else:
-                    os.remove(path)
+    # No teardown cleanup needed: all artifacts are written under TemporaryDirectory.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/consistent/test_stat_file.py` around lines 212 - 220, The
tearDown method is deleting files from the current working directory without
proper scoping, which can remove artifacts created by other tests or processes.
Create a test-specific temporary directory in the setUp method that is unique to
this test instance, configure the test to write its outputs to this directory,
and modify the tearDown method to only clean up files within this test-specific
directory. This ensures that the cleanup in tearDown is isolated and does not
interfere with shared cwd artifacts or other concurrent tests.

Comment on lines +99 to +103
stat_path = Path(stat_file_path)
if stat_path.exists():
self.assertTrue(
stat_path.is_dir(), "Stat file path should be a directory"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Assert stat_file creation unconditionally.

This assertion currently passes even when nothing is created, so the test can miss a broken training.stat_file path.

Suggested fix
             stat_path = Path(stat_file_path)
-            if stat_path.exists():
-                self.assertTrue(
-                    stat_path.is_dir(), "Stat file path should be a directory"
-                )
+            self.assertTrue(
+                stat_path.exists(), "Stat file path should be created"
+            )
+            self.assertTrue(
+                stat_path.is_dir(), "Stat file path should be a directory"
+            )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/tf/test_stat_file_integration.py` around lines 99 - 103, The
current assertion in the stat_path validation block is conditional on
stat_path.exists(), which means if the stat_file is never created, the assertion
is skipped entirely and the test passes without detecting the failure. Remove
the if stat_path.exists() condition and instead assert unconditionally that the
stat_path both exists and is a directory. This ensures that when the
training.stat_file creation is broken, the test will properly fail rather than
silently passing.

@codecov

codecov Bot commented Jun 17, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 82.00000% with 27 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.15%. Comparing base (4a552e3) to head (3eb2243).
⚠️ Report is 4 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/tf/utils/stat.py 77.55% 22 Missing ⚠️
deepmd/tf/entrypoints/train.py 87.50% 2 Missing ⚠️
deepmd/tf/model/ener.py 89.47% 2 Missing ⚠️
deepmd/tf/model/linear.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5551      +/-   ##
==========================================
- Coverage   82.23%   82.15%   -0.08%     
==========================================
  Files         894      897       +3     
  Lines      102002   102777     +775     
  Branches     4276     4344      +68     
==========================================
+ Hits        83877    84435     +558     
- Misses      16823    17005     +182     
- Partials     1302     1337      +35     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Persist observed_type for TensorFlow stat files and normalize the
stat-file test input before calling the lower-level training helper.
Also broadcast the global output std to match the shared statistic logic.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] Support stat_file in TF

1 participant