Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add first saving functions in local simulation run #50

Closed
wants to merge 9 commits into from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ myvenv/
ENV/
env.bak/
venv.bak/
myvenv/

# Spyder project settings
.spyderproject
Expand Down
4 changes: 3 additions & 1 deletion scripts/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
# Experiment run arguments
parser.add_argument('--save', action='store_true', help='Save the simulation or not')
parser.add_argument('--saving_name', type=str, default='', help='Name of the saving directory')
parser.add_argument('--log_level', type=str, default='INFO', help='Logging level')
parser.add_argument('--num_steps', type=int, default=10, help='Number of simulation steps')
# Simulator config arguments
Expand Down Expand Up @@ -81,6 +83,6 @@ def parse_args():

lg.info("Running simulation")

simulator.run(threaded=False, num_steps=args.num_steps)
simulator.run(threaded=False, num_steps=args.num_steps, save=args.save, saving_name=args.saving_name)

lg.info("Simulation complete")
95 changes: 95 additions & 0 deletions tests/test_saving_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# !!! TODO : currently the file isn't testing anything because there are problems while comparing jax objects
# !!! + numerical errors while saving and loading with pickle, a better option would be to directly use jax / flax options to save

import logging

import numpy as np
import jax.numpy as jnp

from vivarium.simulator import behaviors
from vivarium.simulator.sim_computation import dynamics_rigid, StateType
from vivarium.controllers.config import AgentConfig, ObjectConfig, SimulatorConfig
from vivarium.controllers import converters
from vivarium.simulator.simulator import Simulator

lg = logging.getLogger(__name__)

num_steps = 10
save = False
saving_name = ''
box_size = 100.0
n_agents = 10
n_objects = 2
num_steps_lax = 4
dt = 0.1
freq = 40.0
neighbor_radius = 100.0
to_jit = True
use_fori_loop = False


def test_saving_loading():
simulator_config = SimulatorConfig(
box_size=box_size,
n_agents=n_agents,
n_objects=n_objects,
num_steps_lax=num_steps_lax,
dt=dt,
freq=freq,
neighbor_radius=neighbor_radius,
to_jit=to_jit,
use_fori_loop=use_fori_loop
)

agent_configs = [
AgentConfig(idx=i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_agents)
]

object_configs = [
ObjectConfig(idx=simulator_config.n_agents + i,
x_position=np.random.rand() * simulator_config.box_size,
y_position=np.random.rand() * simulator_config.box_size,
orientation=np.random.rand() * 2. * np.pi)
for i in range(simulator_config.n_objects)
]

state = converters.set_state_from_config_dict(
{
StateType.AGENT: agent_configs,
StateType.OBJECT: object_configs,
StateType.SIMULATOR: [simulator_config]
}
)

saving_name = "test_dir"
nve_states = []

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)
assert not simulator.recording

lg.info("Running simulation")
simulator.start_recording(saving_name)
assert simulator.recording

# Run the simulation for num_steps and save the nve_state
for _ in range(num_steps):
simulator.step()
nve_states.append(simulator.state.nve_state)

simulator.stop_recording()
assert not simulator.recording
assert not simulator.records

loaded_nve_states = simulator.load(saving_name)
assert loaded_nve_states

lg.info("Simulation complete")

# At the momet the saving and the loading works but there are numerical errors
for state, loaded_state in zip(nve_states, loaded_nve_states):
# This will therefore raise an error in the test :
assert jnp.array_equal(state.position, loaded_state.position)
94 changes: 88 additions & 6 deletions vivarium/simulator/simulator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import time
import threading
import math
import logging
import pickle
import datetime

from functools import partial
from contextlib import contextmanager
Expand Down Expand Up @@ -35,6 +38,11 @@ def __init__(self, state, behavior_bank, dynamics_fn):
self._to_stop = False
self.key = jax.random.PRNGKey(0)

# Attributes to record simulation
self.recording = False
self.records = None
self.saving_dir = None

# TODO: Define which attributes are affected but these functions
self.update_space(self.box_size)
self.update_function_update()
Expand Down Expand Up @@ -72,8 +80,73 @@ def select_simulation_loop_type(self):
"""
if self.state.simulator_state.use_fori_loop:
return self.lax_simulation_loop

return self.classic_simulation_loop

def start_recording(self, saving_name):
"""Start the recording of the simulation

:param saving_name: optional name of the saving file
"""
if self.recording:
lg.warning('Already recording')
self.recording = True
self.records = []

# Either create a savinf_dir with the given name or one with the current datetime
if saving_name:
saving_dir = f"Results/{saving_name}"
else:
return self.classic_simulation_loop
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
saving_dir = f"Results/experiment_{current_time}"

self.saving_dir = saving_dir
# Create a saving dir if it doesn't exist yet, TODO : Add a warning if risk of overwritting already existing content
os.makedirs(self.saving_dir, exist_ok=True)
lg.info('Saving directory %s created', self.saving_dir)

def record(self, data):
"""Record the desired data during a step

:param data: saved data (e.g simulator.state)
"""
if not self.recording:
lg.warning('Recording not started yet.')
return
self.records.append(data)

def save_records(self):
"""Save the recorded steps in a pickle file"""
if not self.records:
lg.warning('No records to save.')
return

saving_path = f"{self.saving_dir}/frames.pkl"
with open(saving_path, 'wb') as f:
pickle.dump(self.records, f)
lg.info('Simulation frames saved in %s', saving_path)

def stop_recording(self):
"""Stop the recording, save the recorded steps and reset recording information"""
if not self.recording:
lg.warning('Recording not started yet.')
return

self.save_records()
self.recording = False
self.records = []

def load(self, saving_name):
"""Load data corresponding to saving_name

:param saving_name: name used while saving the data
:return: loaded data
"""
saving_path = f"Results/{saving_name}/frames.pkl"
with open(saving_path, 'rb') as f:
data = pickle.load(f)
lg.info('Simulation loaded from %s', saving_path)
return data

def _step(self, state, neighbors, num_iterations):
"""Do a step in the simulation by applying the update function a few iterations on the state and the neighbors
Expand All @@ -95,6 +168,9 @@ def _step(self, state, neighbors, num_iterations):
new_state, neighbors = self.simulation_loop(state=current_state, neighbors=neighbors, num_iterations=num_iterations)
# Check that neighbors array is now ok but should be the case (allocate neighbors tries to compute a new list that is large enough according to the simulation state)
assert not neighbors.did_buffer_overflow

if self.recording:
self.record(new_state.nve_state)

return new_state, neighbors

Expand All @@ -104,7 +180,7 @@ def step(self):
num_iterations = self.num_steps_lax
self.state, self.neighbors = self._step(state, neighbors, num_iterations)

def run(self, threaded=False, num_steps=math.inf):
def run(self, threaded=False, num_steps=math.inf, save=False, saving_name=None):
"""Run the simulator for the desired number of timesteps, either in a separate thread or not

:param threaded: wether to run the simulation in a thread or not, defaults to False
Expand All @@ -117,12 +193,12 @@ def run(self, threaded=False, num_steps=math.inf):
# Else run it either in a thread or not
if threaded:
# Set the num_loops attribute with a partial func to launch _run in a thread
_run = partial(self._run, num_steps=num_steps)
_run = partial(self._run, num_steps=num_steps, save=save, saving_name=saving_name)
threading.Thread(target=_run).start()
else:
self._run(num_steps)
self._run(num_steps=num_steps, save=save, saving_name=saving_name)

def _run(self, num_steps):
def _run(self, num_steps, save, saving_name):
"""Function that runs the simulator for the desired number of steps. Used to be called either normally or in a thread.

:param num_steps: number of simulation steps
Expand All @@ -133,6 +209,9 @@ def _run(self, num_steps):

loop_count = 0
sleep_time = 0

if save:
self.start_recording(saving_name)

# Update the simulation with step for num_steps
while loop_count < num_steps:
Expand All @@ -149,7 +228,10 @@ def _run(self, num_steps):
sleep_time = self.update_sleep_time(frequency=self.freq, elapsed_time=end-start)
time.sleep(sleep_time)

# Encode that the simulation isn't started anymore
if save:
self.stop_recording()

# Encode that the simulation isn't started anymore
self._is_started = False
lg.info('Run stops')

Expand Down
Loading