Skip to content

Commit

Permalink
Update docs and clean files
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Nov 27, 2024
1 parent 9f8e2ba commit ce9af5a
Show file tree
Hide file tree
Showing 11 changed files with 318 additions and 191 deletions.
24 changes: 19 additions & 5 deletions vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

from vivarium.simulator.simulator_states import StateType


mass = monomer.mass()
mass_center = float(mass.center[0])
mass_orientation = float(mass.orientation[0])


class Config(Parameterized):

"""Base class for configuration objects"""
def to_dict(self, params=None):
"""Return a dictionary with the configuration parameters
:param params: params, defaults to None
:return: dictionary with the configuration parameters
"""
d = self.param.values()
del d['name']
if params is not None:
Expand All @@ -22,26 +26,34 @@ def to_dict(self, params=None):
return d

def param_names(self):
"""Return the names of the configuration parameters
:return: list of parameter names
"""
return list(self.to_dict().keys())

def json(self):
"""Return a JSON representation of the configuration
:return: JSON representation of the configuration
"""
return self.param.serialize_parameters(subset=self.param_names())


class AgentConfig(Config):
"""Configuration class for agents"""
idx = param.Integer()
# ent_sensedtype = param.Integer()
x_position = param.Number(0.)
y_position = param.Number(0.)
orientation = param.Number(0.)
mass_center = param.Number(mass_center)
mass_orientation = param.Number(mass_orientation)
# TODO : Change behavior back to a list of objects
# TODO : Change the behaviors to a list of objects in the future
behavior = param.Array(np.array([0.]))
left_motor = param.Number(0., bounds=(0., 1.))
right_motor = param.Number(0., bounds=(0., 1.))
# TODO : Will be problems here if proximeters if non occlusion mode (as many proximeter values as neighbors)
# TODO : Except if we only consider the non occlusion case where the sensors information is just the sensor of closest entity
# TODO : Will be problems here if proximeters if non occlusion mode (as many proximeter values as neighbors), except if we only consider the non occlusion case where the sensors information is just the sensor of closest entity
left_prox = param.Number(0., bounds=(0., 1.))
right_prox = param.Number(0., bounds=(0., 1.))
prox_sensed_ent_type = param.Array(np.array([0]))
Expand All @@ -67,6 +79,7 @@ def __init__(self, **params):


class ObjectConfig(Config):
"""Configuration class for objects"""
idx = param.Integer()
# ent_sensedtype = param.Integer()
x_position = param.Number(0.)
Expand All @@ -85,6 +98,7 @@ def __init__(self, **params):


class SimulatorConfig(Config):
"""Configuration class for the simulator"""
idx = param.Integer(0, constant=True)
time = param.Integer(0)
box_size = param.Number(100., bounds=(0, None))
Expand Down
238 changes: 136 additions & 102 deletions vivarium/controllers/converters.py

Large diffs are not rendered by default.

29 changes: 21 additions & 8 deletions vivarium/controllers/panel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,28 @@


class PanelConfig(Config):
"""Base class for panel configurations"""
pass


class PanelEntityConfig(PanelConfig):
"""Base class for panel configurations of entities"""
visible = param.Boolean(True)


class PanelAgentConfig(PanelEntityConfig):
"""Base class for panel configurations of agents"""
visible_wheels = param.Boolean(True)
visible_proxs = param.Boolean(True)


class PanelObjectConfig(PanelEntityConfig):
"""Base class for panel configurations of objects"""
pass


class PanelSimulatorConfig(Config):
"""Base class for panel configurations of the simulator"""
hide_non_existing = param.Boolean(True)
config_update = param.Boolean(False)


# Mapping between config classes and their corresponding state types
panel_config_to_stype = {
PanelSimulatorConfig: StateType.SIMULATOR,
PanelAgentConfig: StateType.AGENT,
Expand All @@ -46,6 +47,7 @@ class PanelSimulatorConfig(Config):


class Selected(param.Parameterized):
"""Class to store the selected entities in the interface"""
selection = param.ListSelector([0], objects=[0])

def selection_nve_idx(self, ent_idx):
Expand All @@ -56,7 +58,7 @@ def __len__(self):


class PanelController(SimulatorController):

"""Controller for the panel interface"""
def __init__(self, **params):
self._selected_configs_watchers = None
self._selected_panel_configs_watchers = None
Expand All @@ -83,17 +85,22 @@ def trigger_hide_non_existing(self):
self.panel_simulator_config.hide_non_existing = True

def watch_selected_configs(self):
watchers = {etype: config.param.watch(self.push_selected_to_config_list, config.param_names(), onlychanged=True)
for etype, config in self.selected_configs.items()}
"""Watch the selected configurations"""
watchers = {
etype: config.param.watch(self.push_selected_to_config_list, config.param_names(), onlychanged=True)
for etype, config in self.selected_configs.items()
}
return watchers

def watch_selected_panel_configs(self):
"""Watch the selected panel configurations"""
watchers = {etype: config.param.watch(self.push_selected_to_config_list, config.param_names(), onlychanged=True)
for etype, config in self.selected_panel_configs.items()}
return watchers

@contextmanager
def dont_push_selected_configs(self):
"""Context manager to avoid pushing the selected configurations"""
if self._selected_configs_watchers is not None:
for etype, config in self.selected_configs.items():
config.param.unwatch(self._selected_configs_watchers[etype])
Expand All @@ -104,6 +111,7 @@ def dont_push_selected_configs(self):

@contextmanager
def dont_push_selected_panel_configs(self):
"""Context manager to avoid pushing the selected panel configurations"""
if self._selected_panel_configs_watchers is not None:
for etype, config in self.selected_panel_configs.items():
config.param.unwatch(self._selected_panel_configs_watchers[etype])
Expand All @@ -113,11 +121,13 @@ def dont_push_selected_panel_configs(self):
self._selected_panel_configs_watchers = self.watch_selected_panel_configs()

def update_entity_list(self, *events):
"""Update the entity list"""
state = self.state
for etype, selected in self.selected_entities.items():
selected.param.selection.objects = state.entity_idx(etype).tolist()

def pull_selected_configs(self, *events):
"""Pull the selected configurations"""
state = self.state
config_dict = {etype.to_state_type(): [config] for etype, config in self.selected_configs.items()}
with self.dont_push_selected_configs():
Expand All @@ -128,15 +138,18 @@ def pull_selected_configs(self, *events):
return state

def pull_selected_panel_configs(self, *events):
"""Pull the selected panel configurations"""
with self.dont_push_selected_panel_configs():
for etype, panel_config in self.selected_panel_configs.items():
panel_config.param.update(**self.panel_configs[etype.to_state_type()][self.selected_entities[etype].selection[0]].to_dict())

def pull_all_data(self):
"""Pull all the data from the simulator"""
self.pull_selected_configs()
self.pull_configs({StateType.SIMULATOR: self.configs[StateType.SIMULATOR]})

def push_selected_to_config_list(self, *events):
"""Push the selected configurations to the configuration list"""
lg.info("Push_selected_to_config_list %d", len(events))
for e in events:
if isinstance(e.obj, PanelConfig):
Expand Down
17 changes: 14 additions & 3 deletions vivarium/controllers/simulator_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@


class SimulatorController(param.Parameterized):
"""Base controller class to interact with the simulator.
"""
"""Base controller class to interact with the simulator."""
configs = param.Dict({StateType.SIMULATOR: SimulatorConfig(), StateType.AGENT: [], StateType.OBJECT: []})
refresh_change_period = param.Number(1)
change_time = param.Integer(0)
Expand All @@ -35,6 +34,7 @@ def __init__(self, client=None, **params):
self.subtypes_labels = self.client.subtypes_labels

def watch_configs(self):
"""Watch the parameters of the configs to push the changes to the simulator."""
watchers = {etype: [config.param.watch(self.push_state, config.param_names(), onlychanged=True) for config in configs] for etype, configs in self.configs.items()}
return watchers

Expand All @@ -43,6 +43,7 @@ def simulator_config(self):
return self.configs[StateType.SIMULATOR][0]

def push_state(self, *events):
"""Push the state changes to the simulator."""
if self._in_batch:
self._event_list.extend(events)
return
Expand All @@ -53,6 +54,7 @@ def push_state(self, *events):

@contextmanager
def dont_push_entity_configs(self):
"""Context manager to avoid pushing the entity configs to the simulator."""
for etype, configs in self.configs.items():
for i, config in enumerate(configs):
config.param.unwatch(self._config_watchers[etype][i])
Expand All @@ -63,6 +65,7 @@ def dont_push_entity_configs(self):

@contextmanager
def batch_set_state(self):
"""Context manager to set the state changes in batch."""
self._in_batch = True
self._event_list = []
try:
Expand All @@ -73,35 +76,43 @@ def batch_set_state(self):
self._event_list = None

def pull_all_data(self):
"""Pull all the data from the simulator."""
self.pull_configs()

def pull_configs(self, configs=None):
"""Pull the configurations."""
configs = configs or self.configs
state = self.state
# lg.debug(f"Pull_configs; {state = } {configs = }")
with self.dont_push_entity_configs():
converters.set_configs_from_state(state, configs)
return state

def is_started(self):
"""Check if the simulator is started."""
return self.client.is_started()

def start(self):
"""Start the simulator."""
self.client.start()

def stop(self):
"""Stop the simulator."""
self.client.stop()

def update_state(self):
"""Update the state of the simulator."""
self.state = self.client.get_state()
return self.state

def get_nve_state(self):
"""Get the NVE state of the simulator."""
self.state = self.client.get_nve_state()
return self.state

def get_scene_name(self):
"""Get the scene name of the simulator."""
self.client.get_scene_name()

def get_subtypes_labels(self):
"""Get the subtypes labels of the simulator."""
self.client.get_subtypes_labels()
3 changes: 1 addition & 2 deletions vivarium/controllers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

class Logger(object):
def __init__(self):
"""Logger class that logs data for the agents
"""
"""Logger class that logs data for the agents"""
self.logs = {}

def add(self, log_field, data):
Expand Down
30 changes: 12 additions & 18 deletions vivarium/environments/base_env.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
# TODO : Update this file to make it match with current architecture
import logging as lg

from functools import partial
from typing import Tuple

import jax.numpy as jnp

from jax import jit
from flax import struct
from jax_md.dataclasses import dataclass as md_dataclass


@struct.dataclass
@md_dataclass
class BaseState:
time: jnp.int32
box_size: jnp.int32

@md_dataclass
class Neighbors:
neighbors: jnp.array
agents_neighs_idx: jnp.array
agents_idx_dense: jnp.array


class BaseEnv:
Expand All @@ -24,23 +28,13 @@ def init_state(self) -> BaseState:
raise(NotImplementedError)

@partial(jit, static_argnums=(0,))
def _step(self, state: BaseState, neighbors: jnp.array) -> Tuple[BaseState, jnp.array]:
def _step_env(self, state: BaseState, neighbors_storage: Neighbors) -> Tuple[BaseState, Neighbors]:
raise(NotImplementedError)

def step(self, state: BaseState) -> BaseState:
current_state = state
state, neighbors = self._step(current_state, self.neighbors)

if self.neighbors.did_buffer_overflow:
# reallocate neighbors and run the simulation from current_state
lg.warning('BUFFER OVERFLOW: rebuilding neighbors')
neighbors = self.allocate_neighbors(state)
assert not neighbors.did_buffer_overflow

self.neighbors = neighbors
return state
def step(self, state: BaseState, num_updates) -> BaseState:
raise(NotImplementedError)

def allocate_neighbors(self, state, position=None):
def allocate_neighbors(self, state: BaseState, position=None):
position = state.entities.position.center if position is None else position
neighbors = self.neighbor_fn.allocate(position)
return neighbors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,9 @@ def _step_env(self, state: State, neighbors_storage: Neighbors) -> Tuple[State,
)

# Update the entities and the state
state = state.replace(agents=agents)
state = state.set(agents=agents)
entities = self.apply_physics(state, neighbors)
state = state.replace(time=state.time+1, entities=entities)
state = state.set(time=state.time+1, entities=entities)

# Update the neighbors storage
neighbors = neighbors.update(state.entities.position.center)
Expand Down
Loading

0 comments on commit ce9af5a

Please sign in to comment.