Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename run_sim, save_snapshot, and load_snapshot to run, save and load. #28

Merged
merged 5 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MEDYANSimRunner"
uuid = "b58a3b99-22e3-44d1-b5ea-258f082a6fe8"
authors = ["nhz2 <[email protected]>"]
version = "0.4.2"
version = "0.5.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ These functions can also use the default random number generator, this will auto
At the end of `main.jl` there should be the lines:
```julia
if abspath(PROGRAM_FILE) == @__FILE__
MEDYANSimRunner.run_sim(ARGS; jobs, setup, loop, load_snapshot, save_snapshot, done)
MEDYANSimRunner.run(ARGS; jobs, setup, loop, load, save, done)
end
```

Expand All @@ -59,16 +59,16 @@ The `job` string is also used to seed the default RNG right before `setup` is ca

#### `setup(job::String; kwargs...) -> header_dict, state`
Return the header dictionary to be written as the `header.json` file in output trajectory.
Also return the state that gets passed on to `loop` and the state that gets passed to `save_snapshot` and `load_snapshot`.
Also return the state that gets passed on to `loop` and the state that gets passed to `save` and `load`.

`job::String`: The job. This is used for multi job simulations.

#### `save_snapshot(step::Int, state; kwargs...)-> group::SmallZarrGroups.ZGroup`
#### `save(step::Int, state; kwargs...)-> group::SmallZarrGroups.ZGroup`
Return the state of the system as a `SmallZarrGroups.ZGroup`
This function should not mutate `state`

#### `load_snapshot(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
Load the state saved by `save_snapshot`
#### `load(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
Load the state saved by `save`
This function can mutate `state`.
`state` may be the state returned from `setup` or the `state` returned by `loop`.
This function should return the same output if `state` is the state returned by `loop` or the
Expand All @@ -82,7 +82,7 @@ Also return the expected value of step when done will first be true, used for di
This function should not mutate `state`

#### `loop(step::Int, state; kwargs...) -> state`
Return the state that gets passed to `save_snapshot`
Return the state that gets passed to `save`


### Main loop pseudo code
Expand All @@ -95,13 +95,13 @@ Random.seed!(collect(reinterpret(UInt64, sha256(job))))
job_header, state = setup(job)
save job_header
step = 0
SmallZarrGroups.save_zip(snapshot_zip_file, save_snapshot(step, state))
state = load_snapshot(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
SmallZarrGroups.save_zip(snapshot_zip_file, save(step, state))
state = load(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
while true
state = loop(step, state)
step = step + 1
SmallZarrGroups.save_zip(snapshot_zip_file, save_snapshot(step, state))
state = load_snapshot(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
SmallZarrGroups.save_zip(snapshot_zip_file, save(step, state))
state = load(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
if done(step::Int, state)[1]
break
end
Expand Down
56 changes: 29 additions & 27 deletions src/run-sim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function get_version_string()
end

"""
run_sim(ARGS; setup, loop, load_snapshot, save_snapshot, done)
run(ARGS; setup, loop, load, save, done)

This function should be called at the end of a script to run a simulation.
It takes keyword arguments:
Expand All @@ -41,10 +41,10 @@ is called once at the beginning of the simulation.
- `loop(step::Int, state; kwargs...) -> state`
is called once per step of the simulation.

- `save_snapshot(step::Int, state; kwargs...) -> group::SmallZarrGroups.ZGroup`
- `save(step::Int, state; kwargs...) -> group::SmallZarrGroups.ZGroup`
is called to save a snapshot.

- `load_snapshot(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
- `load(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
is called to load a snapshot.

- `done(step::Int, state; kwargs...) -> done::Bool, expected_final_step::Int`
Expand All @@ -54,12 +54,12 @@ is called to check if the simulation is done.

$(CLI_HELP)
"""
function run_sim(cli_args;
function run(cli_args;
jobs::Vector{String},
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
kwargs...
)
Expand All @@ -78,16 +78,16 @@ function run_sim(cli_args;
continue_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
else
start_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
end
Expand All @@ -98,29 +98,31 @@ function run_sim(cli_args;
continue_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
else
start_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
end
end
return
end

@deprecate run_sim(cli_args;save_snapshot, load_snapshot, kwargs...) run(cli_args; save=save_snapshot, load=load_snapshot, kwargs...) false


function start_job(out_dir, job::String;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
basic_name_check.(String.(split(job, '/'; keepempty=true)))
Expand Down Expand Up @@ -148,15 +150,15 @@ function start_job(out_dir, job::String;
write_traj_file(traj, "header.json", codeunits(header_str))
local step::Int = 0

state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)
@info "Simulation started."
while true
copy!(Random.default_rng(), rng_state)
state = loop(step, state)
copy!(rng_state, Random.default_rng())

step += 1
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)

copy!(Random.default_rng(), rng_state)
isdone::Bool, expected_final_step::Int64 = done(step::Int, state)
Expand All @@ -177,8 +179,8 @@ end
function continue_job(out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
basic_name_check.(String.(split(job, '/'; keepempty=true)))
Expand Down Expand Up @@ -232,7 +234,7 @@ function continue_job(out_dir, job;
prev_sha256 = bytes2hex(sha256(header_str))
write_traj_file(traj, "header.json", codeunits(header_str))
step = 0
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)
@info "Simulation started."
else
@info "Continuing simulation from step $(step)."
Expand All @@ -241,7 +243,7 @@ function continue_job(out_dir, job;
reread_sub_snapshot_group = snapshot_group["snap"]
rng_state = str_2_rng(attrs(snapshot_group)["rng_state"])
copy!(Random.default_rng(), rng_state)
state = load_snapshot(step, reread_sub_snapshot_group, state)
state = load(step, reread_sub_snapshot_group, state)
copy!(rng_state, Random.default_rng())
prev_sha256 = bytes2hex(sha256(snapshot_data))
if step > 0
Expand All @@ -262,7 +264,7 @@ function continue_job(out_dir, job;
copy!(rng_state, Random.default_rng())

step += 1
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)

copy!(Random.default_rng(), rng_state)
isdone, expected_final_step = done(step::Int, state)
Expand All @@ -286,14 +288,14 @@ function save_load_state!(
step::Int,
state,
traj::String,
save_snapshot,
load_snapshot,
save,
load,
prev_sha256::String,
)
snapshot_group = ZGroup()

copy!(Random.default_rng(), rng_state)
sub_snapshot_group = save_snapshot(step, state)
sub_snapshot_group = save(step, state)
copy!(rng_state, Random.default_rng())

snapshot_group["snap"] = sub_snapshot_group
Expand All @@ -304,7 +306,7 @@ function save_load_state!(
reread_sub_snapshot_group = unzip_group(snapshot_data)["snap"]

copy!(Random.default_rng(), rng_state)
state = load_snapshot(step, reread_sub_snapshot_group, state)
state = load(step, reread_sub_snapshot_group, state)
copy!(rng_state, Random.default_rng())

write_traj_file(traj, SNAP_PREFIX*string(step)*SNAP_POSTFIX, snapshot_data)
Expand Down
6 changes: 3 additions & 3 deletions test/example/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ function setup(job::String; kwargs...)
header, state
end

function save_snapshot(step::Int, state; kwargs...)::ZGroup
function save(step::Int, state; kwargs...)::ZGroup
# @info "saving states" state
group = ZGroup()
group["states"] = state
group
end

function load_snapshot(step::Int, group, state; kwargs...)
function load(step::Int, group, state; kwargs...)
state .= collect(group["states"])
state
end
Expand All @@ -52,5 +52,5 @@ function loop(step::Int, state; kwargs...)
end

if abspath(PROGRAM_FILE) == @__FILE__
MEDYANSimRunner.run_sim(ARGS; jobs, setup, loop, load_snapshot, save_snapshot, done)
MEDYANSimRunner.run(ARGS; jobs, setup, loop, load, save, done)
end
33 changes: 27 additions & 6 deletions test/test_ref_out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ warn_only_logger = MinLevelLogger(current_logger(), Logging.Warn);
args = ["--out=$test_out","--batch=1"]
continue_sim && push!(args,"--continue")
with_logger(warn_only_logger) do
MEDYANSimRunner.run_sim(args;
MEDYANSimRunner.run(args;
UserCode.jobs,
UserCode.setup,
UserCode.save_snapshot,
UserCode.load_snapshot,
UserCode.save,
UserCode.load,
UserCode.loop,
UserCode.done,
)
Expand Down Expand Up @@ -77,11 +77,11 @@ end
args = ["--out=$test_out","--batch=1"]
continue_sim && push!(args,"--continue")
with_logger(warn_only_logger) do
MEDYANSimRunner.run_sim(args;
MEDYANSimRunner.run(args;
UserCode.jobs,
UserCode.setup,
UserCode.save_snapshot,
UserCode.load_snapshot,
UserCode.save,
UserCode.load,
UserCode.loop,
UserCode.done,
)
Expand All @@ -93,4 +93,25 @@ end
end
end
end

@testset "deprecated full run example" begin
test_out = joinpath(@__DIR__, "example/output/deprecated-full")
rm(test_out; force=true, recursive=true)
args = ["--out=$test_out","--batch=1"]
with_logger(warn_only_logger) do
MEDYANSimRunner.run_sim(args;
UserCode.jobs,
UserCode.setup,
save_snapshot=UserCode.save,
load_snapshot=UserCode.load,
UserCode.loop,
UserCode.done,
)
end
out_diff = sprint(MEDYANSimRunner.print_traj_diff, joinpath(ref_out,"a"), joinpath(test_out,"a"))
if !isempty(out_diff)
println(out_diff)
@test false
end
end
end
Loading