Skip to content

Commit

Permalink
fix remdsort postprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Jul 21, 2024
1 parent c5ae100 commit 7e1b1f2
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 24 deletions.
1 change: 1 addition & 0 deletions configs/threadpool.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ModelTraining:
gpu: true
use_threadpool: true
max_training_time: 1
max_workers: 1 # suppress assertion for multigpu training
CP2K:
cores_per_worker: 2
max_evaluation_time: 0.3
Expand Down
9 changes: 7 additions & 2 deletions psiflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ def __init__(
cores_per_worker: int,
use_threadpool: bool,
worker_prepend: str,
max_workers: Optional[int] = None,
) -> None:
self.parsl_provider = parsl_provider
self.gpu = gpu
self.cores_per_worker = cores_per_worker
self.use_threadpool = use_threadpool
self.worker_prepend = worker_prepend
self.name = self.__class__.__name__
self._max_workers = max_workers

@property
def cores_available(self):
Expand All @@ -62,7 +64,10 @@ def cores_available(self):

@property
def max_workers(self):
return max(1, math.floor(self.cores_available / self.cores_per_worker))
if self._max_workers is not None:
return self._max_workers
else:
return max(1, math.floor(self.cores_available / self.cores_per_worker))

@property
def max_runtime(self):
Expand Down Expand Up @@ -239,7 +244,7 @@ def __init__(
assert max_training_time * 60 < self.max_runtime
self.max_training_time = max_training_time
if self.max_workers > 1:
message = ('the max_simulation_time keyword does not work '
message = ('the max_training_time keyword does not work '
'in combination with multi-gpu training. Adjust '
'the maximum number of epochs to control the '
'duration of training')
Expand Down
51 changes: 30 additions & 21 deletions psiflow/sampling/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os
import ast
import glob
import shutil
import signal
import xml.etree.ElementTree as ET
from copy import deepcopy
Expand Down Expand Up @@ -132,8 +132,14 @@ def remdsort(inputfile, prefix="SRT_"):
}
)
else:
filename = filename + "_" + padb + "." + o.format
ofilename = ofilename + "_" + padb + "." + o.format
# FIX
if o.format == 'ase':
extension = 'extxyz'
else:
extension = o.format
filename = filename + "_" + padb + "." + extension
ofilename = ofilename + "_" + padb + "." + extension
print(filename, ofilename)
ntraj.append(
{
"filename": filename,
Expand Down Expand Up @@ -380,10 +386,13 @@ def cleanup(args):
if "remd" in content:
remdsort("input.xml")
for filepath in glob.glob("SRT_*"):
# does not use shutil because it is not instantaneous
source = filepath
target = filepath.replace("SRT_", "")
assert Path(target).exists() # should exist
shutil.copyfile(source, target)
os.remove(target)
assert not Path(target).exists()
os.rename(source, target)
assert Path(target).exists()
i = 0
while i < len(states):
# try all formattings of bead index (i-PI inconsistency)
Expand Down Expand Up @@ -449,19 +458,19 @@ def main():
if not args.cleanup:
start(args)
else:
try:
cleanup(args)
except BaseException as e: # noqa: B036
print(e)
print("i-PI cleanup failed!")
print("files in directory:")
for filepath in Path.cwd().glob("*"):
print(filepath)
print("")

names = [p.name for p in Path.cwd().glob("*")]
if "output.checkpoint" in names:
with open("output.checkpoint", "r") as f:
print(f.read())
else:
print("no output.checkpoint found")
#try:
cleanup(args)
#except BaseException as e: # noqa: B036
# print(e)
# print("i-PI cleanup failed!")
# print("files in directory:")
# for filepath in Path.cwd().glob("*"):
# print(filepath)
# print("")

# names = [p.name for p in Path.cwd().glob("*")]
# if "output.checkpoint" in names:
# with open("output.checkpoint", "r") as f:
# print(f.read())
# else:
# print("no output.checkpoint found")
3 changes: 2 additions & 1 deletion tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,8 @@ def test_rex(dataset):
assert len(partition(walkers)) == 1
assert len(partition(walkers)[0]) == 2

_ = sample(walkers, steps=50, step=10)
outputs = sample(walkers, steps=50, step=10)
assert outputs[0].trajectory.length().result() == 6

swaps = np.loadtxt(walkers[0].coupling.swapfile.result().filepath)
assert len(swaps) > 0 # at least some successful swaps
Expand Down

0 comments on commit 7e1b1f2

Please sign in to comment.