diff --git a/Project.toml b/Project.toml index 3fe1d09..bcc9449 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MEDYANSimRunner" uuid = "b58a3b99-22e3-44d1-b5ea-258f082a6fe8" authors = ["nhz2 "] -version = "0.4.2" +version = "0.5.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/README.md b/README.md index 3144da4..2ebaadc 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ These functions can also use the default random number generator, this will auto At the end of `main.jl` there should be the lines: ```julia if abspath(PROGRAM_FILE) == @__FILE__ - MEDYANSimRunner.run_sim(ARGS; jobs, setup, loop, load_snapshot, save_snapshot, done) + MEDYANSimRunner.run(ARGS; jobs, setup, loop, load, save, done) end ``` @@ -59,16 +59,16 @@ The `job` string is also used to seed the default RNG right before `setup` is ca #### `setup(job::String; kwargs...) -> header_dict, state` Return the header dictionary to be written as the `header.json` file in output trajectory. -Also return the state that gets passed on to `loop` and the state that gets passed to `save_snapshot` and `load_snapshot`. +Also return the state that gets passed on to `loop` and the state that gets passed to `save` and `load`. `job::String`: The job. This is used for multi job simulations. -#### `save_snapshot(step::Int, state; kwargs...)-> group::SmallZarrGroups.ZGroup` +#### `save(step::Int, state; kwargs...)-> group::SmallZarrGroups.ZGroup` Return the state of the system as a `SmallZarrGroups.ZGroup` This function should not mutate `state` -#### `load_snapshot(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state` -Load the state saved by `save_snapshot` +#### `load(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state` +Load the state saved by `save` This function can mutate `state`. `state` may be the state returned from `setup` or the `state` returned by `loop`. This function should return the same output if `state` is the state returned by `loop` or the @@ -82,7 +82,7 @@ Also return the expected value of step when done will first be true, used for di This function should not mutate `state` #### `loop(step::Int, state; kwargs...) -> state` -Return the state that gets passed to `save_snapshot` +Return the state that gets passed to `save` ### Main loop pseudo code @@ -95,13 +95,13 @@ Random.seed!(collect(reinterpret(UInt64, sha256(job)))) job_header, state = setup(job) save job_header step = 0 -SmallZarrGroups.save_zip(snapshot_zip_file, save_snapshot(step, state)) -state = load_snapshot(step, SmallZarrGroups.load_zip(snapshot_zip_file), state) +SmallZarrGroups.save_zip(snapshot_zip_file, save(step, state)) +state = load(step, SmallZarrGroups.load_zip(snapshot_zip_file), state) while true state = loop(step, state) step = step + 1 - SmallZarrGroups.save_zip(snapshot_zip_file, save_snapshot(step, state)) - state = load_snapshot(step, SmallZarrGroups.load_zip(snapshot_zip_file), state) + SmallZarrGroups.save_zip(snapshot_zip_file, save(step, state)) + state = load(step, SmallZarrGroups.load_zip(snapshot_zip_file), state) if done(step::Int, state)[1] break end diff --git a/src/run-sim.jl b/src/run-sim.jl index 3c5b31f..9836bd3 100644 --- a/src/run-sim.jl +++ b/src/run-sim.jl @@ -25,7 +25,7 @@ function get_version_string() end """ - run_sim(ARGS; setup, loop, load_snapshot, save_snapshot, done) + run(ARGS; setup, loop, load, save, done) This function should be called at the end of a script to run a simulation. It takes keyword arguments: @@ -41,10 +41,10 @@ is called once at the beginning of the simulation. - `loop(step::Int, state; kwargs...) -> state` is called once per step of the simulation. -- `save_snapshot(step::Int, state; kwargs...) -> group::SmallZarrGroups.ZGroup` +- `save(step::Int, state; kwargs...) -> group::SmallZarrGroups.ZGroup` is called to save a snapshot. - - `load_snapshot(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state` + - `load(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state` is called to load a snapshot. - `done(step::Int, state; kwargs...) -> done::Bool, expected_final_step::Int` @@ -54,12 +54,12 @@ is called to check if the simulation is done. $(CLI_HELP) """ -function run_sim(cli_args; +function run(cli_args; jobs::Vector{String}, setup, loop, - save_snapshot, - load_snapshot, + save, + load, done, kwargs... ) @@ -78,16 +78,16 @@ function run_sim(cli_args; continue_job(options.out_dir, job; setup, loop, - save_snapshot, - load_snapshot, + save, + load, done, ) else start_job(options.out_dir, job; setup, loop, - save_snapshot, - load_snapshot, + save, + load, done, ) end @@ -98,16 +98,16 @@ function run_sim(cli_args; continue_job(options.out_dir, job; setup, loop, - save_snapshot, - load_snapshot, + save, + load, done, ) else start_job(options.out_dir, job; setup, loop, - save_snapshot, - load_snapshot, + save, + load, done, ) end @@ -115,12 +115,14 @@ function run_sim(cli_args; return end +@deprecate run_sim(cli_args;save_snapshot, load_snapshot, kwargs...) run(cli_args; save=save_snapshot, load=load_snapshot, kwargs...) false + function start_job(out_dir, job::String; setup, loop, - save_snapshot, - load_snapshot, + save, + load, done, ) basic_name_check.(String.(split(job, '/'; keepempty=true))) @@ -148,7 +150,7 @@ function start_job(out_dir, job::String; write_traj_file(traj, "header.json", codeunits(header_str)) local step::Int = 0 - state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256) + state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256) @info "Simulation started." while true copy!(Random.default_rng(), rng_state) @@ -156,7 +158,7 @@ function start_job(out_dir, job::String; copy!(rng_state, Random.default_rng()) step += 1 - state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256) + state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256) copy!(Random.default_rng(), rng_state) isdone::Bool, expected_final_step::Int64 = done(step::Int, state) @@ -177,8 +179,8 @@ end function continue_job(out_dir, job; setup, loop, - save_snapshot, - load_snapshot, + save, + load, done, ) basic_name_check.(String.(split(job, '/'; keepempty=true))) @@ -232,7 +234,7 @@ function continue_job(out_dir, job; prev_sha256 = bytes2hex(sha256(header_str)) write_traj_file(traj, "header.json", codeunits(header_str)) step = 0 - state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256) + state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256) @info "Simulation started." else @info "Continuing simulation from step $(step)." @@ -241,7 +243,7 @@ function continue_job(out_dir, job; reread_sub_snapshot_group = snapshot_group["snap"] rng_state = str_2_rng(attrs(snapshot_group)["rng_state"]) copy!(Random.default_rng(), rng_state) - state = load_snapshot(step, reread_sub_snapshot_group, state) + state = load(step, reread_sub_snapshot_group, state) copy!(rng_state, Random.default_rng()) prev_sha256 = bytes2hex(sha256(snapshot_data)) if step > 0 @@ -262,7 +264,7 @@ function continue_job(out_dir, job; copy!(rng_state, Random.default_rng()) step += 1 - state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256) + state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256) copy!(Random.default_rng(), rng_state) isdone, expected_final_step = done(step::Int, state) @@ -286,14 +288,14 @@ function save_load_state!( step::Int, state, traj::String, - save_snapshot, - load_snapshot, + save, + load, prev_sha256::String, ) snapshot_group = ZGroup() copy!(Random.default_rng(), rng_state) - sub_snapshot_group = save_snapshot(step, state) + sub_snapshot_group = save(step, state) copy!(rng_state, Random.default_rng()) snapshot_group["snap"] = sub_snapshot_group @@ -304,7 +306,7 @@ function save_load_state!( reread_sub_snapshot_group = unzip_group(snapshot_data)["snap"] copy!(Random.default_rng(), rng_state) - state = load_snapshot(step, reread_sub_snapshot_group, state) + state = load(step, reread_sub_snapshot_group, state) copy!(rng_state, Random.default_rng()) write_traj_file(traj, SNAP_PREFIX*string(step)*SNAP_POSTFIX, snapshot_data) diff --git a/test/example/main.jl b/test/example/main.jl index 81dd87b..20186f0 100644 --- a/test/example/main.jl +++ b/test/example/main.jl @@ -28,14 +28,14 @@ function setup(job::String; kwargs...) header, state end -function save_snapshot(step::Int, state; kwargs...)::ZGroup +function save(step::Int, state; kwargs...)::ZGroup # @info "saving states" state group = ZGroup() group["states"] = state group end -function load_snapshot(step::Int, group, state; kwargs...) +function load(step::Int, group, state; kwargs...) state .= collect(group["states"]) state end @@ -52,5 +52,5 @@ function loop(step::Int, state; kwargs...) end if abspath(PROGRAM_FILE) == @__FILE__ - MEDYANSimRunner.run_sim(ARGS; jobs, setup, loop, load_snapshot, save_snapshot, done) + MEDYANSimRunner.run(ARGS; jobs, setup, loop, load, save, done) end \ No newline at end of file diff --git a/test/test_ref_out.jl b/test/test_ref_out.jl index e837b55..593c47c 100644 --- a/test/test_ref_out.jl +++ b/test/test_ref_out.jl @@ -27,11 +27,11 @@ warn_only_logger = MinLevelLogger(current_logger(), Logging.Warn); args = ["--out=$test_out","--batch=1"] continue_sim && push!(args,"--continue") with_logger(warn_only_logger) do - MEDYANSimRunner.run_sim(args; + MEDYANSimRunner.run(args; UserCode.jobs, UserCode.setup, - UserCode.save_snapshot, - UserCode.load_snapshot, + UserCode.save, + UserCode.load, UserCode.loop, UserCode.done, ) @@ -77,11 +77,11 @@ end args = ["--out=$test_out","--batch=1"] continue_sim && push!(args,"--continue") with_logger(warn_only_logger) do - MEDYANSimRunner.run_sim(args; + MEDYANSimRunner.run(args; UserCode.jobs, UserCode.setup, - UserCode.save_snapshot, - UserCode.load_snapshot, + UserCode.save, + UserCode.load, UserCode.loop, UserCode.done, ) @@ -93,4 +93,25 @@ end end end end + +@testset "deprecated full run example" begin + test_out = joinpath(@__DIR__, "example/output/deprecated-full") + rm(test_out; force=true, recursive=true) + args = ["--out=$test_out","--batch=1"] + with_logger(warn_only_logger) do + MEDYANSimRunner.run_sim(args; + UserCode.jobs, + UserCode.setup, + save_snapshot=UserCode.save, + load_snapshot=UserCode.load, + UserCode.loop, + UserCode.done, + ) + end + out_diff = sprint(MEDYANSimRunner.print_traj_diff, joinpath(ref_out,"a"), joinpath(test_out,"a")) + if !isempty(out_diff) + println(out_diff) + @test false + end +end end \ No newline at end of file