diff --git a/conf/scene/default.yaml b/conf/scene/default.yaml index bfbe8bc..a559b5b 100644 --- a/conf/scene/default.yaml +++ b/conf/scene/default.yaml @@ -34,3 +34,4 @@ simulator: num_steps_lax: 4 to_jit: true use_fori_loop: false + has_changed: false diff --git a/scripts/run_server.py b/scripts/run_server.py index feb6f98..240a400 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -1,14 +1,12 @@ import logging import hydra +import hydra.core +import hydra.core.global_hydra from omegaconf import DictConfig, OmegaConf from vivarium.simulator import behaviors -from vivarium.simulator.states import init_simulator_state -from vivarium.simulator.states import init_agent_state -from vivarium.simulator.states import init_object_state -from vivarium.simulator.states import init_entities_state -from vivarium.simulator.states import init_state +from vivarium.simulator.states import init_state_from_dict from vivarium.simulator.simulator import Simulator from vivarium.simulator.physics_engine import dynamics_rigid from vivarium.simulator.grpc_server.simulator_server import serve @@ -21,23 +19,14 @@ def main(cfg: DictConfig = None) -> None: args = OmegaConf.merge(cfg.default, cfg.scene) - simulator_state = init_simulator_state(**args.simulator) - - agents_state = init_agent_state(simulator_state=simulator_state, **args.agents) - - objects_state = init_object_state(simulator_state=simulator_state, **args.objects) - - entities_state = init_entities_state(simulator_state=simulator_state, **args.entities) - - state = init_state( - simulator_state=simulator_state, - agents_state=agents_state, - objects_state=objects_state, - entities_state=entities_state - ) + state = init_state_from_dict(args) simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid) + # necessary to be able to load other scenes + glob = hydra.core.global_hydra.GlobalHydra() + glob.clear() + lg.info('Simulator server started') serve(simulator) diff --git a/vivarium/controllers/config.py b/vivarium/controllers/config.py index 234d839..44d3002 100644 --- a/vivarium/controllers/config.py +++ b/vivarium/controllers/config.py @@ -86,6 +86,8 @@ class SimulatorConfig(Config): use_fori_loop = param.Boolean(False) collision_eps = param.Number(0.1) collision_alpha = param.Number(0.5) + has_changed = param.Boolean(False) + def __init__(self, **params): super().__init__(**params) diff --git a/vivarium/controllers/converters.py b/vivarium/controllers/converters.py index 82c7654..8925f3a 100644 --- a/vivarium/controllers/converters.py +++ b/vivarium/controllers/converters.py @@ -106,7 +106,8 @@ def get_default_state(n_entities_dict): neighbor_radius=jnp.array([1.]), to_jit= jnp.array([1]), use_fori_loop=jnp.array([0]), collision_alpha=jnp.array([0.]), - collision_eps=jnp.array([0.])), + collision_eps=jnp.array([0.]), + has_changed=jnp.array([0])), entities_state=EntitiesState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), momentum=None, force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)), diff --git a/vivarium/controllers/simulator_controller.py b/vivarium/controllers/simulator_controller.py index f2bceae..6b5ec91 100644 --- a/vivarium/controllers/simulator_controller.py +++ b/vivarium/controllers/simulator_controller.py @@ -104,6 +104,10 @@ def get_nve_state(self): self.state = self.client.get_nve_state() return self.state + def load_scene(self, scene): + self.client.load_scene(scene) + self.client.state = self.client.get_state() + self.__init__(client=self.client) if __name__ == "__main__": diff --git a/vivarium/interface/panel_app.py b/vivarium/interface/panel_app.py index db89e99..6c5e2bf 100644 --- a/vivarium/interface/panel_app.py +++ b/vivarium/interface/panel_app.py @@ -41,23 +41,45 @@ def __init__(self, config, panel_configs, panel_simulator_config, selected, etyp self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=True) def drag_cb(self, attr, old, new): + """Callback for the drag & drop of entities + + :param attr: (unused) + :param old: (unused) + :param new: The event containing the new positions of the entities + """ for i, c in enumerate(self.config): c.x_position = new['x'][i] c.y_position = new['y'][i] @contextmanager def no_drag_cb(self): + """Prevent the CDS from updating the configs when the change comes from the + server + """ self.cds.remove_on_change('data', self.drag_cb) yield self.cds.on_change('data', self.drag_cb) def get_cds_data(self, state): + """Update the ColumnDataSource with the new data + + :param state: The state coming from the server + :return: Data dictionary for the ColumnDataSource + """ raise NotImplementedError() def update_cds(self, state): + """Updates the ColumnDataSource with new data from server + + :param state: The state coming from the server + """ self.cds.data.update(self.get_cds_data(state)) def create_cds_view(self): + """Creates a ColumnDataSource view for each visibility attribute + + :return: A dictionary of ColumnDataSource views for each visibility attribute + """ # For each attribute in the panel config, create a filter # that is a logical AND of the visibility and the attribute return { @@ -67,20 +89,40 @@ def create_cds_view(self): } def update_cds_view(self, event): + """Updates the view of the ColumnDataSource if the visibility of an entity + changes + + :param event: The event containing the changed value + """ n = event.name for attr in [n] if n != "visible" else self.panel_configs[0].param_names(): f = [getattr(pc, attr) and pc.visible for pc in self.panel_configs] self.cds_view[attr].filter = BooleanFilter(f) def update_selected_plot(self, event): + """Updates the selected entities in the plot + + :param event: The event containing the new selected entities + """ self.cds.selected.indices = event.new def hide_all_non_existing(self, event): + """Hides or shows all the entities that do not exist according to the global + visibility of non-existing entities + + :param event: The event containing the new global "visibility of non-existing + entities" value + """ for i, pc in enumerate(self.panel_configs): if not self.config[i].exists: pc.visible = not event.new def hide_non_existing(self, event): + """Hides or shows an entity that does not exist depending on the global + visibility of non-existing entities + + :param event: The event containing the new existence value + """ if not self.panel_simulator_config.hide_non_existing: return idx = self.config.index(event.obj) @@ -88,11 +130,18 @@ def hide_non_existing(self, event): def update_selected_simulator(self): + """Updates the list of selected entities in the Selection list + """ indices = self.cds.selected.indices if len(indices) > 0 and indices != self.selected.selection: self.selected.selection = indices def plot(self, fig: figure): + """Plot the objects on the bokeh figure + + :param fig: A bokeh figure + :return: The figure with the objects plotted + """ raise NotImplementedError() @@ -205,6 +254,7 @@ class WindowManager(Parameterized): align="center", value=config_types[1:]) update_switch = pn.widgets.Switch(name="Update plot", value=True, align="center") update_timestep = pn.widgets.IntSlider(name="Timestep (ms)", value=1, start=1, end=1000) + scene_loader = pn.widgets.FileInput(accept=".yml", name="Load scene", align="center") def __init__(self, **kwargs): super().__init__(**kwargs) self.entity_manager_classes = {EntityType.AGENT: AgentManager, @@ -224,6 +274,10 @@ def __init__(self, **kwargs): self.set_callbacks() def start_toggle_cb(self, event): + """Callback for the start/stop button + + :param event: The event for the new value of the button + """ if event.new != self.controller.is_started(): if event.new: self.controller.start() @@ -237,12 +291,25 @@ def entity_toggle_cb(self, event): cc.visible = cc.name in event.new def update_timestep_cb(self, event): + """Callback for the timestep of the plot update + + :param event: The event for the new value of the timestep + """ self.pcb_plot.period = event.new def update_plot_cb(self): + """Periodic callback for the plot update + """ for em in self.entity_managers.values(): em.update_selected_simulator() state = self.controller.update_state() + # TODO: change this part to use a function, that could be reused by the load button + if state.simulator_state.has_changed: + self = WindowManager() + self.controller.update_entity_list() + self.plot = self.create_plot() + self.app = self.create_app() + self.controller.simulator_config.has_changed = False self.controller.pull_configs() if self.controller.panel_simulator_config.config_update: self.controller.pull_selected_configs() @@ -251,12 +318,20 @@ def update_plot_cb(self): em.update_cds(state) def update_switch_cb(self, event): + """Callback for the plot update switch + + :param event: The event for the new value of the switch + """ if event.new and not self.pcb_plot.running: self.pcb_plot.start() elif not event.new and self.pcb_plot.running: self.pcb_plot.stop() def create_plot(self): + """Creates a bokeh plot for the simulator + + :return: A bokeh plot + """ p_tools = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select" p = figure(tools=p_tools, active_drag="box_select") p.axis.major_label_text_font_size = "24px" @@ -270,6 +345,10 @@ def create_plot(self): return p def create_app(self): + """Creates a panel app + + :return: the panel app + """ self.config_columns = pn.Row(* [pn.Column( pn.pane.Markdown("### SIMULATOR", align="center"), @@ -287,7 +366,9 @@ def create_app(self): for etype in self.entity_managers.keys()]) app = pn.Row(pn.Column(pn.Row(pn.pane.Markdown("### Start/Stop server", align="center"), - self.start_toggle), + self.start_toggle, + pn.pane.Markdown("### Load scene", align="center"), + self.scene_loader), pn.Row(pn.pane.Markdown("### Start/Stop update", align="center"), self.update_switch, self.update_timestep), pn.panel(self.plot)), @@ -296,7 +377,9 @@ def create_app(self): return app def set_callbacks(self): - # putting directly the slider value causes bugs on some OS + """ + Set the callbacks for all the widgets in the app + """ self.pcb_plot = pn.state.add_periodic_callback(self.update_plot_cb, self.update_timestep.value) self.entity_toggle.param.watch(self.entity_toggle_cb, "value") diff --git a/vivarium/simulator/grpc_server/converters.py b/vivarium/simulator/grpc_server/converters.py index 182259b..4415bb6 100644 --- a/vivarium/simulator/grpc_server/converters.py +++ b/vivarium/simulator/grpc_server/converters.py @@ -25,7 +25,8 @@ def proto_to_simulator_state(simulator_state): to_jit=proto_to_ndarray(simulator_state.to_jit).astype(int), use_fori_loop=proto_to_ndarray(simulator_state.use_fori_loop).astype(int), collision_eps=proto_to_ndarray(simulator_state.collision_eps).astype(float), - collision_alpha=proto_to_ndarray(simulator_state.collision_alpha).astype(float) + collision_alpha=proto_to_ndarray(simulator_state.collision_alpha).astype(float), + has_changed=proto_to_ndarray(simulator_state.has_changed).astype(int) ) @@ -87,7 +88,8 @@ def simulator_state_to_proto(simulator_state): to_jit=ndarray_to_proto(simulator_state.to_jit), use_fori_loop=ndarray_to_proto(simulator_state.use_fori_loop), collision_eps=ndarray_to_proto(simulator_state.collision_eps), - collision_alpha=ndarray_to_proto(simulator_state.collision_alpha) + collision_alpha=ndarray_to_proto(simulator_state.collision_alpha), + has_changed=ndarray_to_proto(simulator_state.has_changed) ) diff --git a/vivarium/simulator/grpc_server/protos/simulator.proto b/vivarium/simulator/grpc_server/protos/simulator.proto index dc92806..9c42eb8 100644 --- a/vivarium/simulator/grpc_server/protos/simulator.proto +++ b/vivarium/simulator/grpc_server/protos/simulator.proto @@ -40,6 +40,7 @@ service SimulatorServer { rpc Stop(google.protobuf.Empty) returns (google.protobuf.Empty) {} + rpc LoadScene(Scene) returns (google.protobuf.Empty) {} } message AgentIdx { @@ -68,6 +69,7 @@ message SimulatorState { NDArray use_fori_loop = 10; NDArray collision_eps = 11; NDArray collision_alpha = 12; + NDArray has_changed = 13; } message EntitiesState { @@ -124,3 +126,7 @@ message AddAgentInput { message IsStartedState { bool is_started = 1; } + +message Scene { + string scene = 1; +} \ No newline at end of file diff --git a/vivarium/simulator/grpc_server/simulator_client.py b/vivarium/simulator/grpc_server/simulator_client.py index a43e7a7..c95b7b4 100644 --- a/vivarium/simulator/grpc_server/simulator_client.py +++ b/vivarium/simulator/grpc_server/simulator_client.py @@ -57,3 +57,7 @@ def step(self): def is_started(self): return self.stub.IsStarted(Empty()).is_started + + def load_scene(self, scene): + message = simulator_pb2.Scene(scene=scene) + return self.stub.LoadScene(message) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.py b/vivarium/simulator/grpc_server/simulator_pb2.py index e459f53..7d79553 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.py +++ b/vivarium/simulator/grpc_server/simulator_pb2.py @@ -15,7 +15,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x03\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\x32\xbb\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0fsimulator.proto\x12\tsimulator\x1a\x1bgoogle/protobuf/empty.proto\"\x17\n\x08\x41gentIdx\x12\x0b\n\x03idx\x18\x01 \x03(\x05\"\x1a\n\x07NDArray\x12\x0f\n\x07ndarray\x18\x01 \x01(\x0c\"X\n\tRigidBody\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0borientation\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\"\x92\x04\n\x0eSimulatorState\x12\x1f\n\x03idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62ox_size\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nmax_agents\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bmax_objects\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rnum_steps_lax\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\x1e\n\x02\x64t\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04\x66req\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0fneighbor_radius\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06to_jit\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\ruse_fori_loop\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rcollision_eps\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\x12+\n\x0f\x63ollision_alpha\x18\x0c \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0bhas_changed\x18\r \x01(\x0b\x32\x12.simulator.NDArray\"\xe9\x02\n\rEntitiesState\x12&\n\x08position\x18\x01 \x01(\x0b\x32\x14.simulator.RigidBody\x12&\n\x08momentum\x18\x02 \x01(\x0b\x32\x14.simulator.RigidBody\x12#\n\x05\x66orce\x18\x03 \x01(\x0b\x32\x14.simulator.RigidBody\x12\"\n\x04mass\x18\x04 \x01(\x0b\x32\x14.simulator.RigidBody\x12$\n\x08\x64iameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12\'\n\x0b\x65ntity_type\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12&\n\nentity_idx\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x66riction\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12\"\n\x06\x65xists\x18\t \x01(\x0b\x32\x12.simulator.NDArray\"\xb7\x03\n\nAgentState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12 \n\x04prox\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05motor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\x12$\n\x08\x62\x65havior\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0ewheel_diameter\x18\x05 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tspeed_mul\x18\x06 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\tmax_speed\x18\x07 \x01(\x0b\x32\x12.simulator.NDArray\x12%\n\ttheta_mul\x18\x08 \x01(\x0b\x32\x12.simulator.NDArray\x12*\n\x0eproxs_dist_max\x18\t \x01(\x0b\x32\x12.simulator.NDArray\x12)\n\rproxs_cos_min\x18\n \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x0b \x01(\x0b\x32\x12.simulator.NDArray\"\x7f\n\x0bObjectState\x12#\n\x07nve_idx\x18\x01 \x01(\x0b\x32\x12.simulator.NDArray\x12(\n\x0c\x63ustom_field\x18\x02 \x01(\x0b\x32\x12.simulator.NDArray\x12!\n\x05\x63olor\x18\x03 \x01(\x0b\x32\x12.simulator.NDArray\"\xc7\x01\n\x05State\x12\x32\n\x0fsimulator_state\x18\x01 \x01(\x0b\x32\x19.simulator.SimulatorState\x12\x30\n\x0e\x65ntities_state\x18\x02 \x01(\x0b\x32\x18.simulator.EntitiesState\x12*\n\x0b\x61gent_state\x18\x03 \x01(\x0b\x32\x15.simulator.AgentState\x12,\n\x0cobject_state\x18\x04 \x01(\x0b\x32\x16.simulator.ObjectState\"h\n\x0bStateChange\x12\x0f\n\x07nve_idx\x18\x01 \x03(\x05\x12\x0f\n\x07\x63ol_idx\x18\x02 \x03(\x05\x12\x14\n\x0cnested_field\x18\x03 \x03(\t\x12!\n\x05value\x18\x04 \x01(\x0b\x32\x12.simulator.NDArray\">\n\rAddAgentInput\x12\x12\n\nmax_agents\x18\x01 \x01(\x05\x12\x19\n\x11serialized_config\x18\x02 \x01(\t\"$\n\x0eIsStartedState\x12\x12\n\nis_started\x18\x01 \x01(\x08\"\x16\n\x05Scene\x12\r\n\x05scene\x18\x01 \x01(\t2\xf4\x04\n\x0fSimulatorServer\x12\x32\n\x04Step\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x36\n\x08GetState\x12\x16.google.protobuf.Empty\x1a\x10.simulator.State\"\x00\x12\x41\n\x0bGetNVEState\x12\x16.google.protobuf.Empty\x1a\x18.simulator.EntitiesState\"\x00\x12@\n\rGetAgentState\x12\x16.google.protobuf.Empty\x1a\x15.simulator.AgentState\"\x00\x12\x42\n\x0eGetObjectState\x12\x16.google.protobuf.Empty\x1a\x16.simulator.ObjectState\"\x00\x12<\n\x08SetState\x12\x16.simulator.StateChange\x1a\x16.google.protobuf.Empty\"\x00\x12@\n\tIsStarted\x12\x16.google.protobuf.Empty\x1a\x19.simulator.IsStartedState\"\x00\x12\x39\n\x05Start\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x38\n\x04Stop\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x37\n\tLoadScene\x12\x10.simulator.Scene\x1a\x16.google.protobuf.Empty\"\x00\x42\x34\n\x1aio.grpc.examples.simulatorB\x0eSimulatorProtoP\x01\xa2\x02\x03SIMb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -30,21 +30,23 @@ _globals['_RIGIDBODY']._serialized_start=112 _globals['_RIGIDBODY']._serialized_end=200 _globals['_SIMULATORSTATE']._serialized_start=203 - _globals['_SIMULATORSTATE']._serialized_end=692 - _globals['_ENTITIESSTATE']._serialized_start=695 - _globals['_ENTITIESSTATE']._serialized_end=1056 - _globals['_AGENTSTATE']._serialized_start=1059 - _globals['_AGENTSTATE']._serialized_end=1498 - _globals['_OBJECTSTATE']._serialized_start=1500 - _globals['_OBJECTSTATE']._serialized_end=1627 - _globals['_STATE']._serialized_start=1630 - _globals['_STATE']._serialized_end=1829 - _globals['_STATECHANGE']._serialized_start=1831 - _globals['_STATECHANGE']._serialized_end=1935 - _globals['_ADDAGENTINPUT']._serialized_start=1937 - _globals['_ADDAGENTINPUT']._serialized_end=1999 - _globals['_ISSTARTEDSTATE']._serialized_start=2001 - _globals['_ISSTARTEDSTATE']._serialized_end=2037 - _globals['_SIMULATORSERVER']._serialized_start=2040 - _globals['_SIMULATORSERVER']._serialized_end=2611 + _globals['_SIMULATORSTATE']._serialized_end=733 + _globals['_ENTITIESSTATE']._serialized_start=736 + _globals['_ENTITIESSTATE']._serialized_end=1097 + _globals['_AGENTSTATE']._serialized_start=1100 + _globals['_AGENTSTATE']._serialized_end=1539 + _globals['_OBJECTSTATE']._serialized_start=1541 + _globals['_OBJECTSTATE']._serialized_end=1668 + _globals['_STATE']._serialized_start=1671 + _globals['_STATE']._serialized_end=1870 + _globals['_STATECHANGE']._serialized_start=1872 + _globals['_STATECHANGE']._serialized_end=1976 + _globals['_ADDAGENTINPUT']._serialized_start=1978 + _globals['_ADDAGENTINPUT']._serialized_end=2040 + _globals['_ISSTARTEDSTATE']._serialized_start=2042 + _globals['_ISSTARTEDSTATE']._serialized_end=2078 + _globals['_SCENE']._serialized_start=2080 + _globals['_SCENE']._serialized_end=2102 + _globals['_SIMULATORSERVER']._serialized_start=2105 + _globals['_SIMULATORSERVER']._serialized_end=2733 # @@protoc_insertion_point(module_scope) diff --git a/vivarium/simulator/grpc_server/simulator_pb2.pyi b/vivarium/simulator/grpc_server/simulator_pb2.pyi index 269cdd5..9d3631d 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2.pyi +++ b/vivarium/simulator/grpc_server/simulator_pb2.pyi @@ -27,7 +27,7 @@ class RigidBody(_message.Message): def __init__(self, center: _Optional[_Union[NDArray, _Mapping]] = ..., orientation: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class SimulatorState(_message.Message): - __slots__ = ("idx", "box_size", "max_agents", "max_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha") + __slots__ = ("idx", "box_size", "max_agents", "max_objects", "num_steps_lax", "dt", "freq", "neighbor_radius", "to_jit", "use_fori_loop", "collision_eps", "collision_alpha", "has_changed") IDX_FIELD_NUMBER: _ClassVar[int] BOX_SIZE_FIELD_NUMBER: _ClassVar[int] MAX_AGENTS_FIELD_NUMBER: _ClassVar[int] @@ -40,6 +40,7 @@ class SimulatorState(_message.Message): USE_FORI_LOOP_FIELD_NUMBER: _ClassVar[int] COLLISION_EPS_FIELD_NUMBER: _ClassVar[int] COLLISION_ALPHA_FIELD_NUMBER: _ClassVar[int] + HAS_CHANGED_FIELD_NUMBER: _ClassVar[int] idx: NDArray box_size: NDArray max_agents: NDArray @@ -52,7 +53,8 @@ class SimulatorState(_message.Message): use_fori_loop: NDArray collision_eps: NDArray collision_alpha: NDArray - def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., max_agents: _Optional[_Union[NDArray, _Mapping]] = ..., max_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ..., collision_eps: _Optional[_Union[NDArray, _Mapping]] = ..., collision_alpha: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... + has_changed: NDArray + def __init__(self, idx: _Optional[_Union[NDArray, _Mapping]] = ..., box_size: _Optional[_Union[NDArray, _Mapping]] = ..., max_agents: _Optional[_Union[NDArray, _Mapping]] = ..., max_objects: _Optional[_Union[NDArray, _Mapping]] = ..., num_steps_lax: _Optional[_Union[NDArray, _Mapping]] = ..., dt: _Optional[_Union[NDArray, _Mapping]] = ..., freq: _Optional[_Union[NDArray, _Mapping]] = ..., neighbor_radius: _Optional[_Union[NDArray, _Mapping]] = ..., to_jit: _Optional[_Union[NDArray, _Mapping]] = ..., use_fori_loop: _Optional[_Union[NDArray, _Mapping]] = ..., collision_eps: _Optional[_Union[NDArray, _Mapping]] = ..., collision_alpha: _Optional[_Union[NDArray, _Mapping]] = ..., has_changed: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ... class EntitiesState(_message.Message): __slots__ = ("position", "momentum", "force", "mass", "diameter", "entity_type", "entity_idx", "friction", "exists") @@ -149,3 +151,9 @@ class IsStartedState(_message.Message): IS_STARTED_FIELD_NUMBER: _ClassVar[int] is_started: bool def __init__(self, is_started: bool = ...) -> None: ... + +class Scene(_message.Message): + __slots__ = ("scene",) + SCENE_FIELD_NUMBER: _ClassVar[int] + scene: str + def __init__(self, scene: _Optional[str] = ...) -> None: ... diff --git a/vivarium/simulator/grpc_server/simulator_pb2_grpc.py b/vivarium/simulator/grpc_server/simulator_pb2_grpc.py index f11b97c..509fa15 100644 --- a/vivarium/simulator/grpc_server/simulator_pb2_grpc.py +++ b/vivarium/simulator/grpc_server/simulator_pb2_grpc.py @@ -61,6 +61,11 @@ def __init__(self, channel): request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) + self.LoadScene = channel.unary_unary( + '/simulator.SimulatorServer/LoadScene', + request_serializer=simulator__pb2.Scene.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) class SimulatorServerServicer(object): @@ -121,6 +126,12 @@ def Stop(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def LoadScene(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_SimulatorServerServicer_to_server(servicer, server): rpc_method_handlers = { @@ -169,6 +180,11 @@ def add_SimulatorServerServicer_to_server(servicer, server): request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, ), + 'LoadScene': grpc.unary_unary_rpc_method_handler( + servicer.LoadScene, + request_deserializer=simulator__pb2.Scene.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'simulator.SimulatorServer', rpc_method_handlers) @@ -332,3 +348,20 @@ def Stop(request, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def LoadScene(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/simulator.SimulatorServer/LoadScene', + simulator__pb2.Scene.SerializeToString, + google_dot_protobuf_dot_empty__pb2.Empty.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/vivarium/simulator/grpc_server/simulator_server.py b/vivarium/simulator/grpc_server/simulator_server.py index aa919a1..07103ff 100644 --- a/vivarium/simulator/grpc_server/simulator_server.py +++ b/vivarium/simulator/grpc_server/simulator_server.py @@ -41,7 +41,8 @@ def __init__(self, simulator): self._lock = Lock() def GetState(self, request, context): - state = self.simulator.state + with self._lock: + state = self.simulator.state return state_to_proto(state) def GetNVEState(self, request, context): @@ -62,6 +63,11 @@ def Start(self, request, context): def IsStarted(self, request, context): return simulator_pb2.IsStartedState(is_started=self.simulator.is_started()) + + def LoadScene(self, request, context): + with self._lock: + self.simulator.load_scene(request.scene) + return Empty() def Stop(self, request, context): self.simulator.stop() diff --git a/vivarium/simulator/simulator.py b/vivarium/simulator/simulator.py index 964d794..bb08757 100644 --- a/vivarium/simulator/simulator.py +++ b/vivarium/simulator/simulator.py @@ -13,8 +13,11 @@ from jax import lax from jax_md import space, partition, dataclasses +from hydra import compose, initialize +from omegaconf import OmegaConf + from vivarium.controllers import converters -from vivarium.simulator.states import EntityType, SimulatorState +from vivarium.simulator.states import EntityType, SimulatorState, init_state_from_dict lg = logging.getLogger(__name__) @@ -26,7 +29,7 @@ def __init__(self, state, behavior_bank, dynamics_fn): self.behavior_bank = behavior_bank self.dynamics_fn = dynamics_fn - # TODO: explicitely copy the attributes of simulator_state (prevents linting errors and easier to understand which element is an attriute of the class) + # TODO: explicitly copy the attributes of simulator_state (prevents linting errors and easier to understand which element is an attribute of the class) all_attrs = [f.name for f in dataclasses.fields(SimulatorState)] for attr in all_attrs: self.update_attr(attr, SimulatorState.get_type(attr)) @@ -188,6 +191,15 @@ def set_state(self, nested_field, nve_idx, column_idx, value): if nested_field in (('simulator_state', 'box_size'), ('simulator_state', 'dt'), ('simulator_state', 'to_jit')): self.update_function_update() + def load_scene(self, scene): + with initialize(version_base=None, config_path="../../conf"): + args = compose(config_name="config", overrides=[f"scene={scene}"]) + + args = OmegaConf.merge(args.default, args.scene) + state = init_state_from_dict(args) + self. __init__(state, self.behavior_bank, self.dynamics_fn) + self.set_state(("simulator_state", "has_changed"), [0], None, jnp.array([True])) + # Functions to start, stop, pause @@ -243,6 +255,11 @@ def init_state(self, state): lg.info('init_state') self.state = self.init_fn(state, self.key) + def load_state(self, state): + lg.info('load_state') + # the pause may be unnecessary + with self.pause(): + self.__init__(state, self.behavior_bank, self.dynamics_fn) # Neighbor functions diff --git a/vivarium/simulator/states.py b/vivarium/simulator/states.py index 40c931f..76b1ad6 100644 --- a/vivarium/simulator/states.py +++ b/vivarium/simulator/states.py @@ -1,6 +1,5 @@ from enum import Enum from typing import Optional, List, Union -from collections import OrderedDict import inspect import yaml @@ -82,6 +81,7 @@ class SimulatorState: use_fori_loop: util.Array collision_alpha: util.Array collision_eps: util.Array + has_changed: util.Array @staticmethod def get_type(attr): @@ -89,7 +89,7 @@ def get_type(attr): return int elif attr in ['box_size', 'dt', 'freq', 'neighbor_radius', 'collision_alpha', 'collision_eps']: return float - elif attr in ['to_jit', 'use_fori_loop']: + elif attr in ['to_jit', 'use_fori_loop', 'has_changed']: return bool else: raise ValueError(f"Unknown attribute {attr}") @@ -154,7 +154,8 @@ def init_simulator_state( to_jit: bool = True, use_fori_loop: bool = False, collision_alpha: float = 0.5, - collision_eps: float = 0.1 + collision_eps: float = 0.1, + has_changed: bool = False ) -> SimulatorState: """ Initialize simulator state with given parameters @@ -172,7 +173,8 @@ def init_simulator_state( to_jit= jnp.array([1*to_jit]), use_fori_loop=jnp.array([1*use_fori_loop]), collision_alpha=jnp.array([collision_alpha]), - collision_eps=jnp.array([collision_eps])) + collision_eps=jnp.array([collision_eps]), + has_changed=jnp.array([1*has_changed])) def _init_positions(key_pos, positions, n_entities, box_size, n_dims=2): @@ -311,6 +313,21 @@ def init_state( ) +def init_state_from_dict(dictionary: dict): + simulator_state = init_simulator_state(**dictionary.simulator) + + agents_state = init_agent_state(simulator_state=simulator_state, **dictionary.agents) + + objects_state = init_object_state(simulator_state=simulator_state, **dictionary.objects) + + entities_state = init_entities_state(simulator_state=simulator_state, **dictionary.entities) + + return init_state(simulator_state = simulator_state, + agents_state = agents_state, + objects_state = objects_state, + entities_state = entities_state) + + def generate_default_config_files(): """ Generate a default yaml file with all the default arguments in the init_params_fns (see dict below)