add option to select backends TF/PT (#1545)
Reopen PR #1541, as its original branch was deleted.

Add a new key to the `param.json` file:

```
"train_backend": "pytorch"/"tensorflow",
```
Related to issue #1462.
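For illustration, a minimal sketch of how the new key would be consumed (the `param.json` contents here are hypothetical; the fallback mirrors `jdata.get("train_backend", "tensorflow")` in the diff below):

```
import json

# Hypothetical param.json excerpt using the new key:
jdata = json.loads('{"train_backend": "pytorch", "numb_models": 4}')

# When the key is omitted, the backend falls back to "tensorflow".
backend = jdata.get("train_backend", "tensorflow")
print(backend)  # pytorch
```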

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
  - Improved model management by dynamically generating model suffixes based on the selected backend, enhancing compatibility.

- **Enhancements**
  - Updated model-related functions to incorporate backend-specific model suffixes for accurate file handling during training processes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: C. Thang Nguyen <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
3 people authored May 11, 2024
1 parent e13c186 commit 9d29459
Showing 3 changed files with 87 additions and 32 deletions.
10 changes: 10 additions & 0 deletions dpgen/generator/arginfo.py
@@ -87,6 +87,9 @@ def training_args() -> list[Argument]:
list[dargs.Argument]
List of training arguments.
"""
+doc_train_backend = (
+    "The backend of the training. Currently only supports tensorflow and pytorch."
+)
doc_numb_models = "Number of models to be trained in 00.train. 4 is recommended."
doc_training_iter0_model_path = "The model used to init the first iter training. Number of elements should be equal to numb_models."
doc_training_init_model = "Iteration > 0, the model parameters will be initialized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path."
@@ -123,6 +126,13 @@ def training_args() -> list[Argument]:
doc_training_finetune_model = "At iteration 0, finetune the model parameters from the given frozen models. Number of elements should be equal to numb_models."

return [
+Argument(
+    "train_backend",
+    str,
+    optional=True,
+    default="tensorflow",
+    doc=doc_train_backend,
+),
Argument("numb_models", int, optional=False, doc=doc_numb_models),
Argument(
"training_iter0_model_path",
104 changes: 73 additions & 31 deletions dpgen/generator/run.py
@@ -125,6 +125,19 @@
run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py")


+def _get_model_suffix(jdata) -> str:
+    """Return the model suffix based on the backend."""
+    suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
+    backend = jdata.get("train_backend", "tensorflow")
+    if backend in suffix_map:
+        suffix = suffix_map[backend]
+    else:
+        raise ValueError(
+            f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
+        )
+    return suffix
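
For context, a brief illustration of how this helper resolves suffixes (the `jdata` dicts below are hypothetical):

```
# Illustrative calls, not part of the diff:
_get_model_suffix({"train_backend": "tensorflow"})  # -> ".pb"
_get_model_suffix({"train_backend": "pytorch"})     # -> ".pth"
_get_model_suffix({})                               # -> ".pb" (default backend)
# _get_model_suffix({"train_backend": "jax"})      # would raise ValueError
```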


def get_job_names(jdata):
jobkeys = []
for ii in jdata.keys():
@@ -172,7 +185,7 @@ def _check_empty_iter(iter_index, max_v=0):
return all(empty_sys)


-def copy_model(numb_model, prv_iter_index, cur_iter_index):
+def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
cwd = os.getcwd()
prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name)
cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name)
@@ -184,7 +197,8 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index):
os.chdir(cur_train_path)
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
os.symlink(
-    os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii
+    os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
+    "graph.%03d%s" % (ii, suffix),
)
os.chdir(cwd)
with open(os.path.join(cur_train_path, "copied"), "w") as fp:
@@ -315,18 +329,19 @@ def make_train(iter_index, jdata, mdata):
number_old_frames = 0
number_new_frames = 0

+suffix = _get_model_suffix(jdata)
model_devi_engine = jdata.get("model_devi_engine", "lammps")
if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
log_task("prev data is empty, copy prev model")
-copy_model(numb_models, iter_index - 1, iter_index)
+copy_model(numb_models, iter_index - 1, iter_index, suffix)
return
elif (
model_devi_engine != "calypso"
and iter_index > 0
and _check_skip_train(model_devi_jobs[iter_index - 1])
):
log_task("skip training at step %d " % (iter_index - 1))
-copy_model(numb_models, iter_index - 1, iter_index)
+copy_model(numb_models, iter_index - 1, iter_index, suffix)
return
else:
iter_name = make_iter_name(iter_index)
@@ -647,7 +662,9 @@ def make_train(iter_index, jdata, mdata):
)
if copied_models is not None:
for ii in range(len(copied_models)):
-_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb")
+_link_old_models(
+    work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
+)
# Copy user defined forward files
symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
# HDF5 format for training data
@@ -699,6 +716,7 @@ def run_train(iter_index, jdata, mdata):
# print("debug:run_train:mdata", mdata)
# load json param
numb_models = jdata["numb_models"]
+suffix = _get_model_suffix(jdata)
# train_param = jdata['train_param']
train_input_file = default_train_input_file
training_reuse_iter = jdata.get("training_reuse_iter")
@@ -730,7 +748,11 @@ def run_train(iter_index, jdata, mdata):
"training_init_model, training_init_frozen_model, and training_finetune_model are mutually exclusive."
)

-train_command = mdata.get("train_command", "dp")
+train_command = mdata.get("train_command", "dp").strip()
+# assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command
+if suffix == ".pth":
+    train_command += " --pt"

train_resources = mdata["train_resources"]

# paths
@@ -761,9 +783,9 @@ def run_train(iter_index, jdata, mdata):
if training_init_model:
init_flag = " --init-model old/model.ckpt"
elif training_init_frozen_model is not None:
init_flag = " --init-frz-model old/init.pb"
init_flag = f" --init-frz-model old/init{suffix}"
elif training_finetune_model is not None:
init_flag = " --finetune old/init.pb"
init_flag = f" --finetune old/init{suffix}"
command = f"{train_command} train {train_input_file}{extra_flags}"
command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
command = f"/bin/sh -c {shlex.quote(command)}"
@@ -792,23 +814,35 @@ def run_train(iter_index, jdata, mdata):
if "srtab_file_path" in jdata.keys():
forward_files.append(zbl_file)
if training_init_model:
-forward_files += [
-    os.path.join("old", "model.ckpt.meta"),
-    os.path.join("old", "model.ckpt.index"),
-    os.path.join("old", "model.ckpt.data-00000-of-00001"),
-]
+if suffix == ".pb":
+    forward_files += [
+        os.path.join("old", "model.ckpt.meta"),
+        os.path.join("old", "model.ckpt.index"),
+        os.path.join("old", "model.ckpt.data-00000-of-00001"),
+    ]
+elif suffix == ".pth":
+    forward_files += [os.path.join("old", "model.ckpt.pt")]
elif training_init_frozen_model is not None or training_finetune_model is not None:
-    forward_files.append(os.path.join("old", "init.pb"))
+    forward_files.append(os.path.join("old", f"init{suffix}"))

backward_files = ["frozen_model.pb", "lcurve.out", "train.log"]
backward_files += [
"model.ckpt.meta",
"model.ckpt.index",
"model.ckpt.data-00000-of-00001",
backward_files = [
f"frozen_model{suffix}",
"lcurve.out",
"train.log",
"checkpoint",
]
if jdata.get("dp_compress", False):
backward_files.append("frozen_model_compressed.pb")
backward_files.append(f"frozen_model_compressed{suffix}")

+if suffix == ".pb":
+    backward_files += [
+        "model.ckpt.meta",
+        "model.ckpt.index",
+        "model.ckpt.data-00000-of-00001",
+    ]
+elif suffix == ".pth":
+    backward_files += ["model.ckpt.pt"]

if not jdata.get("one_h5", False):
init_data_sys_ = jdata["init_data_sys"]
init_data_sys = []
@@ -879,13 +913,14 @@ def post_train(iter_index, jdata, mdata):
log_task("copied model, do not post train")
return
# symlink models
+suffix = _get_model_suffix(jdata)
for ii in range(numb_models):
-if not jdata.get("dp_compress", False):
-    model_name = "frozen_model.pb"
-else:
-    model_name = "frozen_model_compressed.pb"
+model_name = f"frozen_model{suffix}"
+if jdata.get("dp_compress", False):
+    model_name = f"frozen_model_compressed{suffix}"
+
+ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))
task_file = os.path.join(train_task_fmt % ii, model_name)
-ofile = os.path.join(work_path, "graph.%03d.pb" % ii)
if os.path.isfile(ofile):
os.remove(ofile)
os.symlink(task_file, ofile)
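
For clarity, the `"graph.%03d%s"` naming used above expands as in this small sketch (backend and indices are hypothetical):

```
suffix = ".pth"  # pytorch backend
for ii in range(2):
    print("graph.%03d%s" % (ii, suffix))
# graph.000.pth
# graph.001.pth
```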
@@ -1124,7 +1159,8 @@ def make_model_devi(iter_index, jdata, mdata):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
-models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
+suffix = _get_model_suffix(jdata)
+models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
work_path = os.path.join(iter_name, model_devi_name)
create_path(work_path)
if model_devi_engine == "calypso":
@@ -1305,7 +1341,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
-models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
+suffix = _get_model_suffix(jdata)
+models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1502,7 +1539,8 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
-models = glob.glob(os.path.join(train_path, "graph*pb"))
+suffix = _get_model_suffix(jdata)
+models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1644,7 +1682,8 @@ def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems):
iter_name = make_iter_name(iter_index)
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
-models = glob.glob(os.path.join(train_path, "graph*pb"))
+suffix = _get_model_suffix(jdata)
+models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1827,7 +1866,8 @@ def _make_model_devi_amber(
.replace("@qm_theory@", jdata["low_level"])
.replace("@rcut@", str(jdata["cutoff"]))
)
-models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb")))
+suffix = _get_model_suffix(jdata)
+models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))
task_model_list = []
for ii in models:
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1935,7 +1975,9 @@ def run_md_model_devi(iter_index, jdata, mdata):
run_tasks = [os.path.basename(ii) for ii in run_tasks_]
# dlog.info("all_task is ", all_task)
# dlog.info("run_tasks in run_model_deviation",run_tasks_)
-all_models = glob.glob(os.path.join(work_path, "graph*pb"))

+suffix = _get_model_suffix(jdata)
+all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))
model_names = [os.path.basename(ii) for ii in all_models]

model_devi_engine = jdata.get("model_devi_engine", "lammps")
5 changes: 4 additions & 1 deletion dpgen/simplify/simplify.py
@@ -31,6 +31,7 @@
record_iter,
)
from dpgen.generator.run import (
+_get_model_suffix,
data_system_fmt,
fp_name,
fp_task_fmt,
@@ -186,7 +187,9 @@ def make_model_devi(iter_index, jdata, mdata):
# link the model
train_path = os.path.join(iter_name, train_name)
train_path = os.path.abspath(train_path)
-models = glob.glob(os.path.join(train_path, "graph*pb"))
+suffix = _get_model_suffix(jdata)
+models = glob.glob(os.path.join(train_path, f"graph*{suffix}"))

for mm in models:
model_name = os.path.basename(mm)
os.symlink(mm, os.path.join(work_path, model_name))
