From 7e1b1f231013db1a72e65ff0436b33eb16fe22c4 Mon Sep 17 00:00:00 2001 From: Sander Vandenhaute Date: Sun, 21 Jul 2024 14:05:44 -0400 Subject: [PATCH] fix remdsort postprocess --- configs/threadpool.yaml | 1 + psiflow/execution.py | 9 +++++-- psiflow/sampling/server.py | 51 ++++++++++++++++++++++---------------- tests/test_sampling.py | 3 ++- 4 files changed, 40 insertions(+), 24 deletions(-) diff --git a/configs/threadpool.yaml b/configs/threadpool.yaml index 654e439..1b66bd0 100644 --- a/configs/threadpool.yaml +++ b/configs/threadpool.yaml @@ -9,6 +9,7 @@ ModelTraining: gpu: true use_threadpool: true max_training_time: 1 + max_workers: 1 # suppress assertion for multigpu training CP2K: cores_per_worker: 2 max_evaluation_time: 0.3 diff --git a/psiflow/execution.py b/psiflow/execution.py index 84ae4be..12cd96a 100644 --- a/psiflow/execution.py +++ b/psiflow/execution.py @@ -42,6 +42,7 @@ def __init__( cores_per_worker: int, use_threadpool: bool, worker_prepend: str, + max_workers: Optional[int] = None, ) -> None: self.parsl_provider = parsl_provider self.gpu = gpu @@ -49,6 +50,7 @@ def __init__( self.use_threadpool = use_threadpool self.worker_prepend = worker_prepend self.name = self.__class__.__name__ + self._max_workers = max_workers @property def cores_available(self): @@ -62,7 +64,10 @@ def cores_available(self): @property def max_workers(self): - return max(1, math.floor(self.cores_available / self.cores_per_worker)) + if self._max_workers is not None: + return self._max_workers + else: + return max(1, math.floor(self.cores_available / self.cores_per_worker)) @property def max_runtime(self): @@ -239,7 +244,7 @@ def __init__( assert max_training_time * 60 < self.max_runtime self.max_training_time = max_training_time if self.max_workers > 1: - message = ('the max_simulation_time keyword does not work ' + message = ('the max_training_time keyword does not work ' 'in combination with multi-gpu training. Adjust ' 'the maximum number of epochs to control the ' 'duration of training') diff --git a/psiflow/sampling/server.py b/psiflow/sampling/server.py index a61739b..a1fb1eb 100644 --- a/psiflow/sampling/server.py +++ b/psiflow/sampling/server.py @@ -1,7 +1,7 @@ import argparse +import os import ast import glob -import shutil import signal import xml.etree.ElementTree as ET from copy import deepcopy @@ -132,8 +132,14 @@ def remdsort(inputfile, prefix="SRT_"): } ) else: - filename = filename + "_" + padb + "." + o.format - ofilename = ofilename + "_" + padb + "." + o.format + # FIX + if o.format == 'ase': + extension = 'extxyz' + else: + extension = o.format + filename = filename + "_" + padb + "." + extension + ofilename = ofilename + "_" + padb + "." + extension + print(filename, ofilename) ntraj.append( { "filename": filename, @@ -380,10 +386,13 @@ def cleanup(args): if "remd" in content: remdsort("input.xml") for filepath in glob.glob("SRT_*"): + # does not use shutil because it is not instantaneous source = filepath target = filepath.replace("SRT_", "") - assert Path(target).exists() # should exist - shutil.copyfile(source, target) + os.remove(target) + assert not Path(target).exists() + os.rename(source, target) + assert Path(target).exists() i = 0 while i < len(states): # try all formattings of bead index (i-PI inconsistency) @@ -449,19 +458,19 @@ def main(): if not args.cleanup: start(args) else: - try: - cleanup(args) - except BaseException as e: # noqa: B036 - print(e) - print("i-PI cleanup failed!") - print("files in directory:") - for filepath in Path.cwd().glob("*"): - print(filepath) - print("") - - names = [p.name for p in Path.cwd().glob("*")] - if "output.checkpoint" in names: - with open("output.checkpoint", "r") as f: - print(f.read()) - else: - print("no output.checkpoint found") + #try: + cleanup(args) + #except BaseException as e: # noqa: B036 + # print(e) + # print("i-PI cleanup failed!") + # print("files in directory:") + # for filepath in Path.cwd().glob("*"): + # print(filepath) + # print("") + + # names = [p.name for p in Path.cwd().glob("*")] + # if "output.checkpoint" in names: + # with open("output.checkpoint", "r") as f: + # print(f.read()) + # else: + # print("no output.checkpoint found") diff --git a/tests/test_sampling.py b/tests/test_sampling.py index d1306d9..219d87d 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -425,7 +425,8 @@ def test_rex(dataset): assert len(partition(walkers)) == 1 assert len(partition(walkers)[0]) == 2 - _ = sample(walkers, steps=50, step=10) + outputs = sample(walkers, steps=50, step=10) + assert outputs[0].trajectory.length().result() == 6 swaps = np.loadtxt(walkers[0].coupling.swapfile.result().filepath) assert len(swaps) > 0 # at least some successful swaps