diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index 24205fda3..30b3472ac 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -221,7 +221,9 @@ def run_model_devi(iter_index, jdata, mdata): commands = [] run_tasks = ["."] # get models - models = glob.glob(os.path.join(work_path, "graph*pb")) + suffix = _get_model_suffix(jdata) + models = glob.glob(os.path.join(work_path, f"graph*{suffix}")) + assert len(models) > 0, "No model file found." model_names = [os.path.basename(ii) for ii in models] task_model_list = [] for ii in model_names: diff --git a/tests/simplify/test_run_model_devi.py b/tests/simplify/test_run_model_devi.py index e928afa8e..28d5732e5 100644 --- a/tests/simplify/test_run_model_devi.py +++ b/tests/simplify/test_run_model_devi.py @@ -17,6 +17,9 @@ class TestOneH5(unittest.TestCase): def setUp(self): work_path = Path("iter.000000") / "01.model_devi" work_path.mkdir(parents=True, exist_ok=True) + # fake models + for ii in range(4): + (work_path / f"graph.{ii:03d}.pb").touch() with tempfile.TemporaryDirectory() as tmpdir: with open(Path(tmpdir) / "test.xyz", "w") as f: f.write(