From adb6ef099cc0ce6918f18b445e1538a01bbad89e Mon Sep 17 00:00:00 2001 From: corentinlger Date: Fri, 15 Mar 2024 20:03:22 +0100 Subject: [PATCH 1/7] Add first saving functions in local simulation run --- .gitignore | 4 +++- scripts/run_simulation.py | 3 ++- vivarium/simulator/simulator.py | 37 ++++++++++++++++++++++++++++----- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 446d794..b9c67cb 100644 --- a/.gitignore +++ b/.gitignore @@ -119,6 +119,7 @@ venv/ ENV/ env.bak/ venv.bak/ +myvenv/ # Spyder project settings .spyderproject @@ -158,5 +159,6 @@ profiler_stats .vscode .pylintrc -myvenv/ + test.ipynb +Results/* \ No newline at end of file diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index e9ba40d..36d09fb 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -15,6 +15,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Simulator Configuration') # Experiment run arguments parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation loops') + parser.add_argument('--saving_name', type=str, default='', help='Name of the saving directory') # Simulator config arguments parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') parser.add_argument('--n_agents', type=int, default=10, help='Number of agents') @@ -77,6 +78,6 @@ def parse_args(): lg.info("Running simulation") - simulator.run(threaded=False, num_steps=args.num_steps) + simulator.run(threaded=False, num_steps=args.num_steps, saving_name=args.saving_name) lg.info("Simulation complete") diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index cd93c82..6ab07d5 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -1,7 +1,9 @@ +import os import time import threading import math import logging +import pickle from functools import partial from contextlib import contextmanager @@ -98,7 +100,7 @@ def step(self, state, neighbors): return new_state, neighbors - def run(self, threaded=False, num_steps=math.inf): + def run(self, threaded=False, num_steps=math.inf, saving_name=None): """Run the simulator for the desired number of timesteps, either in a separate thread or not :param threaded: wether to run the simulation in a thread or not, defaults to False @@ -111,12 +113,12 @@ def run(self, threaded=False, num_steps=math.inf): # Else run it either in a thread or not if threaded: # Set the num_loops attribute with a partial func to launch _run in a thread - _run = partial(self._run, num_steps=num_steps) + _run = partial(self._run, num_steps=num_steps, saving_name=saving_name) threading.Thread(target=_run).start() else: - self._run(num_steps) + self._run(num_steps=num_steps, saving_name=saving_name) - def _run(self, num_steps): + def _run(self, num_steps, saving_name): """Function that runs the simulator for the desired number of steps. Used to be called either normally or in a thread. :param num_steps: number of simulation steps @@ -127,6 +129,15 @@ def _run(self, num_steps): loop_count = 0 sleep_time = 0 + + if saving_name: + save = True + frames = [] + saving_dir = f"Results/{saving_name}" + os.makedirs(saving_dir, exist_ok=True) + lg.info(f'Saving directory {saving_dir} created') + else: + save = False # Update the simulation with step for num_steps while loop_count < num_steps: @@ -136,6 +147,10 @@ def _run(self, num_steps): break self.state, self.neighbors = self.step(state=self.state, neighbors=self.neighbors) + + if save: + # TODO : Cannot save the whole state because of the __getattr__ method (weird error when pickling) + frames.append(self.state.nve_state) loop_count += 1 # Sleep for updated sleep_time seconds @@ -143,10 +158,22 @@ def _run(self, num_steps): sleep_time = self.update_sleep_time(frequency=self.freq, elapsed_time=end-start) time.sleep(sleep_time) - # Encode that the simulation isn't started anymore + if save: + self.save_frames(frames, saving_dir) + + # Encode that the simulation isn't started anymore self._is_started = False lg.info('Run stops') + def save_frames(self, frames, saving_dir): + print(f"{type(frames) = }") + saving_path = f"{saving_dir}/frames.pkl" + frames.append(1) + print(frames[-1]) + with open(saving_path, 'wb') as f: + pickle.dump(frames, f) + lg.info('Simulation frames saved in {saving_path}') + def update_sleep_time(self, frequency, elapsed_time): """Compute the time we need to sleep to respect the update frequency From 0203273a8a6b05afdeda2b9c66153d2a2f2d7ac5 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Tue, 19 Mar 2024 11:56:25 +0100 Subject: [PATCH 2/7] update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b9c67cb..6be7df3 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,5 @@ profiler_stats .pylintrc test.ipynb +test.py Results/* \ No newline at end of file From cd60a84e7c708a8586ee1260a03ea34d12f86c6f Mon Sep 17 00:00:00 2001 From: corentinlger Date: Tue, 19 Mar 2024 11:57:39 +0100 Subject: [PATCH 3/7] Update saving mechanism with start and stop recording --- vivarium/simulator/simulator.py | 110 ++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 28 deletions(-) diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index 6ab07d5..fbd7670 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -4,6 +4,7 @@ import math import logging import pickle +import datetime from functools import partial from contextlib import contextmanager @@ -37,6 +38,11 @@ def __init__(self, state, behavior_bank, dynamics_fn): self._to_stop = False self.key = jax.random.PRNGKey(0) + # Attributes to record simulation + self.recording = False + self.records = [] + self.saving_dir = None + # TODO: Define which attributes are affected but these functions self.update_space(self.box_size) self.update_function_update() @@ -74,8 +80,72 @@ def select_simulation_loop_type(self): """ if self.state.simulator_state.use_fori_loop: return self.lax_simulation_loop + + return self.classic_simulation_loop + + def start_recording(self, saving_name): + """Start the recording of the simulation + + :param saving_name: optional name of the saving file + """ + if self.recording: + lg.warning('Already recording') + self.recording = True + + # Either create a savinf_dir with the given name or one with the current datetime + if saving_name: + saving_dir = f"Results/{saving_name}" else: - return self.classic_simulation_loop + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + saving_dir = f"Results/experiment_{current_time}" + + self.saving_dir = saving_dir + # Create a saving dir if it doesn't exist yet, TODO : Add a warning if risk of overwritting already existing content + os.makedirs(self.saving_dir, exist_ok=True) + lg.info('Saving directory %s created', self.saving_dir) + + def record(self, data): + """Record the desired data during a step + + :param data: saved data (e.g simulator.state) + """ + if not self.recording: + lg.warning('Recording not started yet.') + return + self.records.append(data) + + def save_records(self): + """Save the recorded steps in a pickle file""" + if not self.records: + lg.warning('No records to save.') + return + + saving_path = f"{self.saving_dir}/frames.pkl" + with open(saving_path, 'wb') as f: + pickle.dump(self.records, f) + lg.info('Simulation frames saved in %s', saving_path) + + def stop_recording(self): + """Stop the recording, save the recorded steps and reset recording information""" + if not self.recording: + lg.warning('Recording not started yet.') + return + + self.save_records() + self.recording = False + self.records = [] + + def load(self, saving_name): + """Load data corresponding to saving_name + + :param saving_name: name used while saving the data + :return: loaded data + """ + saving_path = f"Results/{saving_name}/frames.pkl" + with open(saving_path, 'rb') as f: + data = pickle.load(f) + lg.info('Simulation loaded from %s', saving_path) + return data def step(self, state, neighbors): """Do a step in the simulation by applying the update function a few iterations on the state and the neighbors @@ -97,10 +167,13 @@ def step(self, state, neighbors): new_state, neighbors = self.simulation_loop(state=current_state, neighbors=neighbors, num_iterations=self.num_steps_lax) # Check that neighbors array is now ok but should be the case (allocate neighbors tries to compute a new list that is large enough according to the simulation state) assert not neighbors.did_buffer_overflow + + if self.recording: + self.record(new_state.nve_state) return new_state, neighbors - def run(self, threaded=False, num_steps=math.inf, saving_name=None): + def run(self, threaded=False, num_steps=math.inf, save=False, saving_name=None): """Run the simulator for the desired number of timesteps, either in a separate thread or not :param threaded: wether to run the simulation in a thread or not, defaults to False @@ -113,12 +186,12 @@ def run(self, threaded=False, num_steps=math.inf, saving_name=None): # Else run it either in a thread or not if threaded: # Set the num_loops attribute with a partial func to launch _run in a thread - _run = partial(self._run, num_steps=num_steps, saving_name=saving_name) + _run = partial(self._run, num_steps=num_steps, save=save, saving_name=saving_name) threading.Thread(target=_run).start() else: - self._run(num_steps=num_steps, saving_name=saving_name) + self._run(num_steps=num_steps, save=save, saving_name=saving_name) - def _run(self, num_steps, saving_name): + def _run(self, num_steps, save, saving_name): """Function that runs the simulator for the desired number of steps. Used to be called either normally or in a thread. :param num_steps: number of simulation steps @@ -130,14 +203,8 @@ def _run(self, num_steps, saving_name): loop_count = 0 sleep_time = 0 - if saving_name: - save = True - frames = [] - saving_dir = f"Results/{saving_name}" - os.makedirs(saving_dir, exist_ok=True) - lg.info(f'Saving directory {saving_dir} created') - else: - save = False + if save: + self.start_recording(saving_name) # Update the simulation with step for num_steps while loop_count < num_steps: @@ -147,11 +214,7 @@ def _run(self, num_steps, saving_name): break self.state, self.neighbors = self.step(state=self.state, neighbors=self.neighbors) - - if save: - # TODO : Cannot save the whole state because of the __getattr__ method (weird error when pickling) - frames.append(self.state.nve_state) - loop_count += 1 + loop_count += 1 # Sleep for updated sleep_time seconds end = time.time() @@ -159,21 +222,12 @@ def _run(self, num_steps, saving_name): time.sleep(sleep_time) if save: - self.save_frames(frames, saving_dir) + self.stop_recording() # Encode that the simulation isn't started anymore self._is_started = False lg.info('Run stops') - def save_frames(self, frames, saving_dir): - print(f"{type(frames) = }") - saving_path = f"{saving_dir}/frames.pkl" - frames.append(1) - print(frames[-1]) - with open(saving_path, 'wb') as f: - pickle.dump(frames, f) - lg.info('Simulation frames saved in {saving_path}') - def update_sleep_time(self, frequency, elapsed_time): """Compute the time we need to sleep to respect the update frequency From 1bdf86cdee039839398c3b8cf3ec9fad29920867 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Tue, 19 Mar 2024 11:58:46 +0100 Subject: [PATCH 4/7] Use saving in run_simulation --- scripts/run_simulation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/run_simulation.py b/scripts/run_simulation.py index 36d09fb..4619eda 100644 --- a/scripts/run_simulation.py +++ b/scripts/run_simulation.py @@ -15,6 +15,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Simulator Configuration') # Experiment run arguments parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation loops') + parser.add_argument('--save', action='store_true', help='Save the simulation or not') parser.add_argument('--saving_name', type=str, default='', help='Name of the saving directory') # Simulator config arguments parser.add_argument('--box_size', type=float, default=100.0, help='Size of the simulation box') @@ -78,6 +79,6 @@ def parse_args(): lg.info("Running simulation") - simulator.run(threaded=False, num_steps=args.num_steps, saving_name=args.saving_name) + simulator.run(threaded=False, num_steps=args.num_steps, save=args.save, saving_name=args.saving_name) lg.info("Simulation complete") From e610d6e85db9f48b85f6acdf5bb83307c95b3087 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Tue, 19 Mar 2024 11:59:19 +0100 Subject: [PATCH 5/7] Add first dummy test for saving and loading --- tests/test_saving_loading.py | 97 ++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 tests/test_saving_loading.py diff --git a/tests/test_saving_loading.py b/tests/test_saving_loading.py new file mode 100644 index 0000000..00553e3 --- /dev/null +++ b/tests/test_saving_loading.py @@ -0,0 +1,97 @@ +# !!! TODO : currently the file isn't testing anything because there are problems while comparing jax objects +# !!! + numerical errors while saving and loading with pickle, a better option would be to directly use jax / flax options to save + +import logging + +import jax.numpy as jnp +import numpy as np + +from vivarium.simulator import behaviors +from vivarium.simulator.sim_computation import dynamics_rigid, StateType +from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig +from vivarium.controllers import converters +from vivarium.simulator.simulator import Simulator + +lg = logging.getLogger(__name__) + +num_steps = 10 +save = False +saving_name = '' +box_size = 100.0 +n_agents = 10 +n_objects = 2 +num_steps_lax = 4 +dt = 0.1 +freq = 40.0 +neighbor_radius = 100.0 +to_jit = True +use_fori_loop = False +log_level = 'INFO' + + +logging.basicConfig(level=log_level.upper()) + +simulator_config = SimulatorConfig( + box_size=box_size, + n_agents=n_agents, + n_objects=n_objects, + num_steps_lax=num_steps_lax, + dt=dt, + freq=freq, + neighbor_radius=neighbor_radius, + to_jit=to_jit, + use_fori_loop=use_fori_loop +) + +agent_configs = [ + AgentConfig(idx=i, + x_position=np.random.rand() * simulator_config.box_size, + y_position=np.random.rand() * simulator_config.box_size, + orientation=np.random.rand() * 2. * np.pi) + for i in range(simulator_config.n_agents) + ] + +object_configs = [ + ObjectConfig(idx=simulator_config.n_agents + i, + x_position=np.random.rand() * simulator_config.box_size, + y_position=np.random.rand() * simulator_config.box_size, + orientation=np.random.rand() * 2. * np.pi) + for i in range(simulator_config.n_objects) + ] + +state = converters.set_state_from_config_dict( + { + StateType.AGENT: agent_configs, + StateType.OBJECT: object_configs, + StateType.SIMULATOR: [simulator_config] + } + ) + +num_steps = 10 +saving_name = "test_dir" +nve_states = [] + +simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) +assert not simulator.recording + +lg.info("Running simulation") +simulator.start_recording(saving_name) +assert simulator.recording + +# Run the simulation for num_steps and save the nve_state +for _ in range(num_steps): + simulator.state, simulator.neighbors = simulator.step(simulator.state, simulator.neighbors) + nve_states.append(state.nve_state) + +simulator.stop_recording() +assert not simulator.recording +assert not simulator.records + +loaded_nve_states = simulator.load(saving_name) +assert loaded_nve_states + +lg.info("Simulation complete") + +# At the momet the saving and the loading works but +for state, loaded_state in zip(nve_states, loaded_nve_states): + assert jnp.array_equal(state.position, loaded_state.position) From ef532e3c52cba6b088352eaf9c2219464306c617 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 21 Mar 2024 17:52:12 +0100 Subject: [PATCH 6/7] Change the way records is instanciated --- vivarium/simulator/simulator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index fbd7670..0bcd69c 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -40,7 +40,7 @@ def __init__(self, state, behavior_bank, dynamics_fn): # Attributes to record simulation self.recording = False - self.records = [] + self.records = None self.saving_dir = None # TODO: Define which attributes are affected but these functions @@ -91,6 +91,7 @@ def start_recording(self, saving_name): if self.recording: lg.warning('Already recording') self.recording = True + self.records = [] # Either create a savinf_dir with the given name or one with the current datetime if saving_name: From 511f78301924ff1b9a4287ec6850a571f6ab408e Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 27 Mar 2024 20:00:56 +0100 Subject: [PATCH 7/7] Update tests with new step function --- tests/test_saving_loading.py | 132 +++++++++++++++++------------------ 1 file changed, 65 insertions(+), 67 deletions(-) diff --git a/tests/test_saving_loading.py b/tests/test_saving_loading.py index 00553e3..803ad98 100644 --- a/tests/test_saving_loading.py +++ b/tests/test_saving_loading.py @@ -3,8 +3,8 @@ import logging -import jax.numpy as jnp import numpy as np +import jax.numpy as jnp from vivarium.simulator import behaviors from vivarium.simulator.sim_computation import dynamics_rigid, StateType @@ -26,72 +26,70 @@ neighbor_radius = 100.0 to_jit = True use_fori_loop = False -log_level = 'INFO' - - -logging.basicConfig(level=log_level.upper()) - -simulator_config = SimulatorConfig( - box_size=box_size, - n_agents=n_agents, - n_objects=n_objects, - num_steps_lax=num_steps_lax, - dt=dt, - freq=freq, - neighbor_radius=neighbor_radius, - to_jit=to_jit, - use_fori_loop=use_fori_loop -) - -agent_configs = [ - AgentConfig(idx=i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_agents) - ] - -object_configs = [ - ObjectConfig(idx=simulator_config.n_agents + i, - x_position=np.random.rand() * simulator_config.box_size, - y_position=np.random.rand() * simulator_config.box_size, - orientation=np.random.rand() * 2. * np.pi) - for i in range(simulator_config.n_objects) - ] - -state = converters.set_state_from_config_dict( - { - StateType.AGENT: agent_configs, - StateType.OBJECT: object_configs, - StateType.SIMULATOR: [simulator_config] - } - ) - -num_steps = 10 -saving_name = "test_dir" -nve_states = [] - -simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) -assert not simulator.recording -lg.info("Running simulation") -simulator.start_recording(saving_name) -assert simulator.recording -# Run the simulation for num_steps and save the nve_state -for _ in range(num_steps): - simulator.state, simulator.neighbors = simulator.step(simulator.state, simulator.neighbors) - nve_states.append(state.nve_state) - -simulator.stop_recording() -assert not simulator.recording -assert not simulator.records - -loaded_nve_states = simulator.load(saving_name) -assert loaded_nve_states - -lg.info("Simulation complete") +def test_saving_loading(): + simulator_config = SimulatorConfig( + box_size=box_size, + n_agents=n_agents, + n_objects=n_objects, + num_steps_lax=num_steps_lax, + dt=dt, + freq=freq, + neighbor_radius=neighbor_radius, + to_jit=to_jit, + use_fori_loop=use_fori_loop + ) -# At the momet the saving and the loading works but -for state, loaded_state in zip(nve_states, loaded_nve_states): - assert jnp.array_equal(state.position, loaded_state.position) + agent_configs = [ + AgentConfig(idx=i, + x_position=np.random.rand() * simulator_config.box_size, + y_position=np.random.rand() * simulator_config.box_size, + orientation=np.random.rand() * 2. * np.pi) + for i in range(simulator_config.n_agents) + ] + + object_configs = [ + ObjectConfig(idx=simulator_config.n_agents + i, + x_position=np.random.rand() * simulator_config.box_size, + y_position=np.random.rand() * simulator_config.box_size, + orientation=np.random.rand() * 2. * np.pi) + for i in range(simulator_config.n_objects) + ] + + state = converters.set_state_from_config_dict( + { + StateType.AGENT: agent_configs, + StateType.OBJECT: object_configs, + StateType.SIMULATOR: [simulator_config] + } + ) + + saving_name = "test_dir" + nve_states = [] + + simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + assert not simulator.recording + + lg.info("Running simulation") + simulator.start_recording(saving_name) + assert simulator.recording + + # Run the simulation for num_steps and save the nve_state + for _ in range(num_steps): + simulator.step() + nve_states.append(simulator.state.nve_state) + + simulator.stop_recording() + assert not simulator.recording + assert not simulator.records + + loaded_nve_states = simulator.load(saving_name) + assert loaded_nve_states + + lg.info("Simulation complete") + + # At the momet the saving and the loading works but there are numerical errors + for state, loaded_state in zip(nve_states, loaded_nve_states): + # This will therefore raise an error in the test : + assert jnp.array_equal(state.position, loaded_state.position)