diff --git a/dpgen/generator/arginfo.py b/dpgen/generator/arginfo.py
index 9ed6ba887..92097af89 100644
--- a/dpgen/generator/arginfo.py
+++ b/dpgen/generator/arginfo.py
@@ -87,6 +87,9 @@ def training_args() -> list[Argument]:
     list[dargs.Argument]
         List of training arguments.
     """
+    doc_train_backend = (
+        "The backend of the training. Currently only supports tensorflow and pytorch."
+    )
     doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
     doc_training_iter0_model_path = "The model used to init the first iter training. Number of element should be equal to numb_models."
     doc_training_init_model = "Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path."
@@ -123,6 +126,13 @@ def training_args() -> list[Argument]:
     doc_training_finetune_model = "At interation 0, finetune the model parameters from the given frozen models. Number of element should be equal to numb_models."
 
     return [
+        Argument(
+            "train_backend",
+            str,
+            optional=True,
+            default="tensorflow",
+            doc=doc_train_backend,
+        ),
         Argument("numb_models", int, optional=False, doc=doc_numb_models),
         Argument(
             "training_iter0_model_path",
diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py
index 4e03a471b..dbc387c05 100644
--- a/dpgen/generator/run.py
+++ b/dpgen/generator/run.py
@@ -125,6 +125,19 @@
 run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py")
 
 
+def _get_model_suffix(jdata) -> str:
+    """Return the model suffix based on the backend."""
+    suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
+    backend = jdata.get("train_backend", "tensorflow")
+    if backend in suffix_map:
+        suffix = suffix_map[backend]
+    else:
+        raise ValueError(
+            f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
+        )
+    return suffix
+
+
 def get_job_names(jdata):
     jobkeys = []
     for ii in jdata.keys():
@@ -172,7 +185,7 @@
     return all(empty_sys)
 
 
-def copy_model(numb_model, prv_iter_index, cur_iter_index):
+def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
     cwd = os.getcwd()
     prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name)
     cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name)
@@ -184,7 +197,8 @@
         os.chdir(cur_train_path)
         os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
         os.symlink(
-            os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii
+            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
+            "graph.%03d%s" % (ii, suffix),
         )
     os.chdir(cwd)
     with open(os.path.join(cur_train_path, "copied"), "w") as fp:
@@ -315,10 +329,11 @@
     number_old_frames = 0
     number_new_frames = 0
 
+    suffix = _get_model_suffix(jdata)
     model_devi_engine = jdata.get("model_devi_engine", "lammps")
     if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
         log_task("prev data is empty, copy prev model")
-        copy_model(numb_models, iter_index - 1, iter_index)
+        copy_model(numb_models, iter_index - 1, iter_index, suffix)
         return
     elif (
         model_devi_engine != "calypso"
@@ -326,7 +341,7 @@
         and _check_skip_train(model_devi_jobs[iter_index - 1])
     ):
         log_task("skip training at step %d " % (iter_index - 1))
-        copy_model(numb_models, iter_index - 1, iter_index)
+        copy_model(numb_models, iter_index - 1, iter_index, suffix)
         return
     else:
         iter_name = make_iter_name(iter_index)
@@ -647,7 +662,9 @@
         )
     if copied_models is not None:
         for ii in range(len(copied_models)):
-            _link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb")
+            _link_old_models(
+                work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
+            )
     # Copy user defined forward files
     symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
     # HDF5 format for training data
@@ -699,6 +716,7 @@ def run_train(iter_index, jdata, mdata):
     # print("debug:run_train:mdata", mdata)
     # load json param
     numb_models = jdata["numb_models"]
+    suffix = _get_model_suffix(jdata)
     # train_param = jdata['train_param']
     train_input_file = default_train_input_file
     training_reuse_iter = jdata.get("training_reuse_iter")
@@ -730,7 +748,11 @@
             "training_init_model, training_init_frozen_model, and training_finetune_model are mutually exclusive."
         )
 
-    train_command = mdata.get("train_command", "dp")
+    train_command = mdata.get("train_command", "dp").strip()
+    # assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command
+    if suffix == ".pth":
+        train_command += " --pt"
+
     train_resources = mdata["train_resources"]
 
     # paths
@@ -761,9 +783,9 @@
     if training_init_model:
         init_flag = " --init-model old/model.ckpt"
     elif training_init_frozen_model is not None:
-        init_flag = " --init-frz-model old/init.pb"
+        init_flag = f" --init-frz-model old/init{suffix}"
    elif training_finetune_model is not None:
-        init_flag = " --finetune old/init.pb"
+        init_flag = f" --finetune old/init{suffix}"
     command = f"{train_command} train {train_input_file}{extra_flags}"
     command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
     command = f"/bin/sh -c {shlex.quote(command)}"
@@ -792,23 +814,35 @@ def run_train(iter_index, jdata, mdata):
     if "srtab_file_path" in jdata.keys():
         forward_files.append(zbl_file)
     if training_init_model:
-        forward_files += [
-            os.path.join("old", "model.ckpt.meta"),
-            os.path.join("old", "model.ckpt.index"),
-            os.path.join("old", "model.ckpt.data-00000-of-00001"),
-        ]
+        if suffix == ".pb":
+            forward_files += [
+                os.path.join("old", "model.ckpt.meta"),
+                os.path.join("old", "model.ckpt.index"),
+                os.path.join("old", "model.ckpt.data-00000-of-00001"),
+            ]
+        elif suffix == ".pth":
+            forward_files += [os.path.join("old", "model.ckpt.pt")]
     elif training_init_frozen_model is not None or training_finetune_model is not None:
-        forward_files.append(os.path.join("old", "init.pb"))
+        forward_files.append(os.path.join("old", f"init{suffix}"))
 
-    backward_files = ["frozen_model.pb", "lcurve.out", "train.log"]
-    backward_files += [
-        "model.ckpt.meta",
-        "model.ckpt.index",
-        "model.ckpt.data-00000-of-00001",
+    backward_files = [
+        f"frozen_model{suffix}",
+        "lcurve.out",
+        "train.log",
         "checkpoint",
     ]
     if jdata.get("dp_compress", False):
-        backward_files.append("frozen_model_compressed.pb")
+        backward_files.append(f"frozen_model_compressed{suffix}")
+
+    if suffix == ".pb":
+        backward_files += [
+            "model.ckpt.meta",
+            "model.ckpt.index",
+            "model.ckpt.data-00000-of-00001",
+        ]
+    elif suffix == ".pth":
+        backward_files += ["model.ckpt.pt"]
+
     if not jdata.get("one_h5", False):
         init_data_sys_ = jdata["init_data_sys"]
         init_data_sys = []
@@ -879,13 +913,14 @@ def post_train(iter_index, jdata, mdata):
         log_task("copied model, do not post train")
         return
     # symlink models
+    suffix = _get_model_suffix(jdata)
     for ii in range(numb_models):
-        if not jdata.get("dp_compress", False):
-            model_name = "frozen_model.pb"
-        else:
-            model_name = "frozen_model_compressed.pb"
+        model_name = f"frozen_model{suffix}"
+        if jdata.get("dp_compress", False):
+            model_name = f"frozen_model_compressed{suffix}"
+
+        ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))
         task_file = os.path.join(train_task_fmt % ii, model_name)
-        ofile = os.path.join(work_path, "graph.%03d.pb" % ii)
         if os.path.isfile(ofile):
             os.remove(ofile)
         os.symlink(task_file, ofile)
@@ -1124,7 +1159,8 @@ def make_model_devi(iter_index, jdata, mdata):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     work_path = os.path.join(iter_name, model_devi_name)
     create_path(work_path)
     if model_devi_engine == "calypso":
@@ -1305,7 +1341,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1502,7 +1539,8 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = glob.glob(os.path.join(train_path, "graph*pb"))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1644,7 +1682,8 @@ def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = glob.glob(os.path.join(train_path, "graph*pb"))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1827,7 +1866,8 @@ def _make_model_devi_amber(
         .replace("@qm_theory@", jdata["low_level"])
         .replace("@rcut@", str(jdata["cutoff"]))
     )
-    models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb")))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1935,7 +1975,9 @@ def run_md_model_devi(iter_index, jdata, mdata):
     run_tasks = [os.path.basename(ii) for ii in run_tasks_]
     # dlog.info("all_task is ", all_task)
     # dlog.info("run_tasks in run_model_deviation",run_tasks_)
-    all_models = glob.glob(os.path.join(work_path, "graph*pb"))
+
+    suffix = _get_model_suffix(jdata)
+    all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))
     model_names = [os.path.basename(ii) for ii in all_models]
 
     model_devi_engine = jdata.get("model_devi_engine", "lammps")
diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py
index 7ad08dc77..0ff09d87e 100644
--- a/dpgen/simplify/simplify.py
+++ b/dpgen/simplify/simplify.py
@@ -31,6 +31,7 @@
     record_iter,
 )
 from dpgen.generator.run import (
+    _get_model_suffix,
     data_system_fmt,
     fp_name,
     fp_task_fmt,
@@ -186,7 +187,9 @@ def make_model_devi(iter_index, jdata, mdata):
     # link the model
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = glob.glob(os.path.join(train_path, "graph*pb"))
+    suffix = _get_model_suffix(jdata)
+    models = glob.glob(os.path.join(train_path, f"graph*{suffix}"))
+
     for mm in models:
         model_name = os.path.basename(mm)
         os.symlink(mm, os.path.join(work_path, model_name))
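
Usage sketch (reviewer's note, not part of the patch): the new "train_backend" key is read from the param file (jdata), and _get_model_suffix resolves the file suffix from which every graph/frozen-model path is built. A minimal check, assuming the patched dpgen is importable:

from dpgen.generator.run import _get_model_suffix

# "pytorch" selects the .pth suffix, so training runs "dp --pt train"
# and the frozen models are collected as graph.000.pth, graph.001.pth, ...
assert _get_model_suffix({"train_backend": "pytorch"}) == ".pth"
# the key is optional and defaults to the TensorFlow backend (.pb)
assert _get_model_suffix({}) == ".pb"
# any other backend value raises ValueError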