Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Load scenes from clients #79

Closed
wants to merge 11 commits into from
1 change: 1 addition & 0 deletions conf/scene/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ simulator:
num_steps_lax: 4
to_jit: true
use_fori_loop: false
has_changed: false
27 changes: 8 additions & 19 deletions scripts/run_server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import logging
import hydra

import hydra.core
import hydra.core.global_hydra
from omegaconf import DictConfig, OmegaConf

from vivarium.simulator import behaviors
from vivarium.simulator.states import init_simulator_state
from vivarium.simulator.states import init_agent_state
from vivarium.simulator.states import init_object_state
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.states import init_state_from_dict
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.physics_engine import dynamics_rigid
from vivarium.simulator.grpc_server.simulator_server import serve
Expand All @@ -21,23 +19,14 @@ def main(cfg: DictConfig = None) -> None:

args = OmegaConf.merge(cfg.default, cfg.scene)

simulator_state = init_simulator_state(**args.simulator)

agents_state = init_agent_state(simulator_state=simulator_state, **args.agents)

objects_state = init_object_state(simulator_state=simulator_state, **args.objects)

entities_state = init_entities_state(simulator_state=simulator_state, **args.entities)

state = init_state(
simulator_state=simulator_state,
agents_state=agents_state,
objects_state=objects_state,
entities_state=entities_state
)
state = init_state_from_dict(args)

simulator = Simulator(state, behaviors.behavior_bank, dynamics_rigid)

# necessary to be able to load other scenes
glob = hydra.core.global_hydra.GlobalHydra()
glob.clear()

lg.info('Simulator server started')
serve(simulator)

Expand Down
2 changes: 2 additions & 0 deletions vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class SimulatorConfig(Config):
use_fori_loop = param.Boolean(False)
collision_eps = param.Number(0.1)
collision_alpha = param.Number(0.5)
has_changed = param.Boolean(False)


def __init__(self, **params):
super().__init__(**params)
Expand Down
3 changes: 2 additions & 1 deletion vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def get_default_state(n_entities_dict):
neighbor_radius=jnp.array([1.]),
to_jit= jnp.array([1]), use_fori_loop=jnp.array([0]),
collision_alpha=jnp.array([0.]),
collision_eps=jnp.array([0.])),
collision_eps=jnp.array([0.]),
has_changed=jnp.array([0])),
entities_state=EntitiesState(position=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)),
momentum=None,
force=RigidBody(center=jnp.zeros((n_entities, 2)), orientation=jnp.zeros(n_entities)),
Expand Down
4 changes: 4 additions & 0 deletions vivarium/controllers/simulator_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def get_nve_state(self):
self.state = self.client.get_nve_state()
return self.state

def load_scene(self, scene):
self.client.load_scene(scene)
self.client.state = self.client.get_state()
self.__init__(client=self.client)

if __name__ == "__main__":

Expand Down
87 changes: 85 additions & 2 deletions vivarium/interface/panel_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,45 @@ def __init__(self, config, panel_configs, panel_simulator_config, selected, etyp
self.config[i].param.watch(self.hide_non_existing, "exists", onlychanged=True)

def drag_cb(self, attr, old, new):
"""Callback for the drag & drop of entities

:param attr: (unused)
:param old: (unused)
:param new: The event containing the new positions of the entities
"""
for i, c in enumerate(self.config):
c.x_position = new['x'][i]
c.y_position = new['y'][i]

@contextmanager
def no_drag_cb(self):
"""Prevent the CDS from updating the configs when the change comes from the
server
"""
self.cds.remove_on_change('data', self.drag_cb)
yield
self.cds.on_change('data', self.drag_cb)

def get_cds_data(self, state):
"""Update the ColumnDataSource with the new data

:param state: The state coming from the server
:return: Data dictionary for the ColumnDataSource
"""
raise NotImplementedError()

def update_cds(self, state):
"""Updates the ColumnDataSource with new data from server

:param state: The state coming from the server
"""
self.cds.data.update(self.get_cds_data(state))

def create_cds_view(self):
"""Creates a ColumnDataSource view for each visibility attribute

:return: A dictionary of ColumnDataSource views for each visibility attribute
"""
# For each attribute in the panel config, create a filter
# that is a logical AND of the visibility and the attribute
return {
Expand All @@ -67,32 +89,59 @@ def create_cds_view(self):
}

def update_cds_view(self, event):
"""Updates the view of the ColumnDataSource if the visibility of an entity
changes

:param event: The event containing the changed value
"""
n = event.name
for attr in [n] if n != "visible" else self.panel_configs[0].param_names():
f = [getattr(pc, attr) and pc.visible for pc in self.panel_configs]
self.cds_view[attr].filter = BooleanFilter(f)

def update_selected_plot(self, event):
"""Updates the selected entities in the plot

:param event: The event containing the new selected entities
"""
self.cds.selected.indices = event.new

def hide_all_non_existing(self, event):
"""Hides or shows all the entities that do not exist according to the global
visibility of non-existing entities

:param event: The event containing the new global "visibility of non-existing
entities" value
"""
for i, pc in enumerate(self.panel_configs):
if not self.config[i].exists:
pc.visible = not event.new

def hide_non_existing(self, event):
"""Hides or shows an entity that does not exist depending on the global
visibility of non-existing entities

:param event: The event containing the new existence value
"""
if not self.panel_simulator_config.hide_non_existing:
return
idx = self.config.index(event.obj)
self.panel_configs[idx].visible = event.new


def update_selected_simulator(self):
"""Updates the list of selected entities in the Selection list
"""
indices = self.cds.selected.indices
if len(indices) > 0 and indices != self.selected.selection:
self.selected.selection = indices

def plot(self, fig: figure):
"""Plot the objects on the bokeh figure

:param fig: A bokeh figure
:return: The figure with the objects plotted
"""
raise NotImplementedError()


Expand Down Expand Up @@ -205,6 +254,7 @@ class WindowManager(Parameterized):
align="center", value=config_types[1:])
update_switch = pn.widgets.Switch(name="Update plot", value=True, align="center")
update_timestep = pn.widgets.IntSlider(name="Timestep (ms)", value=1, start=1, end=1000)
scene_loader = pn.widgets.FileInput(accept=".yml", name="Load scene", align="center")
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.entity_manager_classes = {EntityType.AGENT: AgentManager,
Expand All @@ -224,6 +274,10 @@ def __init__(self, **kwargs):
self.set_callbacks()

def start_toggle_cb(self, event):
"""Callback for the start/stop button

:param event: The event for the new value of the button
"""
if event.new != self.controller.is_started():
if event.new:
self.controller.start()
Expand All @@ -237,12 +291,25 @@ def entity_toggle_cb(self, event):
cc.visible = cc.name in event.new

def update_timestep_cb(self, event):
"""Callback for the timestep of the plot update

:param event: The event for the new value of the timestep
"""
self.pcb_plot.period = event.new

def update_plot_cb(self):
"""Periodic callback for the plot update
"""
for em in self.entity_managers.values():
em.update_selected_simulator()
state = self.controller.update_state()
# TODO: change this part to use a function, that could be reused by the load button
if state.simulator_state.has_changed:
self = WindowManager()
self.controller.update_entity_list()
self.plot = self.create_plot()
self.app = self.create_app()
self.controller.simulator_config.has_changed = False
self.controller.pull_configs()
if self.controller.panel_simulator_config.config_update:
self.controller.pull_selected_configs()
Expand All @@ -251,12 +318,20 @@ def update_plot_cb(self):
em.update_cds(state)

def update_switch_cb(self, event):
"""Callback for the plot update switch

:param event: The event for the new value of the switch
"""
if event.new and not self.pcb_plot.running:
self.pcb_plot.start()
elif not event.new and self.pcb_plot.running:
self.pcb_plot.stop()

def create_plot(self):
"""Creates a bokeh plot for the simulator

:return: A bokeh plot
"""
p_tools = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select"
p = figure(tools=p_tools, active_drag="box_select")
p.axis.major_label_text_font_size = "24px"
Expand All @@ -270,6 +345,10 @@ def create_plot(self):
return p

def create_app(self):
"""Creates a panel app

:return: the panel app
"""
self.config_columns = pn.Row(*
[pn.Column(
pn.pane.Markdown("### SIMULATOR", align="center"),
Expand All @@ -287,7 +366,9 @@ def create_app(self):
for etype in self.entity_managers.keys()])

app = pn.Row(pn.Column(pn.Row(pn.pane.Markdown("### Start/Stop server", align="center"),
self.start_toggle),
self.start_toggle,
pn.pane.Markdown("### Load scene", align="center"),
self.scene_loader),
pn.Row(pn.pane.Markdown("### Start/Stop update", align="center"),
self.update_switch, self.update_timestep),
pn.panel(self.plot)),
Expand All @@ -296,7 +377,9 @@ def create_app(self):
return app

def set_callbacks(self):
# putting directly the slider value causes bugs on some OS
"""
Set the callbacks for all the widgets in the app
"""
self.pcb_plot = pn.state.add_periodic_callback(self.update_plot_cb,
self.update_timestep.value)
self.entity_toggle.param.watch(self.entity_toggle_cb, "value")
Expand Down
6 changes: 4 additions & 2 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def proto_to_simulator_state(simulator_state):
to_jit=proto_to_ndarray(simulator_state.to_jit).astype(int),
use_fori_loop=proto_to_ndarray(simulator_state.use_fori_loop).astype(int),
collision_eps=proto_to_ndarray(simulator_state.collision_eps).astype(float),
collision_alpha=proto_to_ndarray(simulator_state.collision_alpha).astype(float)
collision_alpha=proto_to_ndarray(simulator_state.collision_alpha).astype(float),
has_changed=proto_to_ndarray(simulator_state.has_changed).astype(int)
)


Expand Down Expand Up @@ -87,7 +88,8 @@ def simulator_state_to_proto(simulator_state):
to_jit=ndarray_to_proto(simulator_state.to_jit),
use_fori_loop=ndarray_to_proto(simulator_state.use_fori_loop),
collision_eps=ndarray_to_proto(simulator_state.collision_eps),
collision_alpha=ndarray_to_proto(simulator_state.collision_alpha)
collision_alpha=ndarray_to_proto(simulator_state.collision_alpha),
has_changed=ndarray_to_proto(simulator_state.has_changed)
)


Expand Down
6 changes: 6 additions & 0 deletions vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ service SimulatorServer {

rpc Stop(google.protobuf.Empty) returns (google.protobuf.Empty) {}

rpc LoadScene(Scene) returns (google.protobuf.Empty) {}
}

message AgentIdx {
Expand Down Expand Up @@ -68,6 +69,7 @@ message SimulatorState {
NDArray use_fori_loop = 10;
NDArray collision_eps = 11;
NDArray collision_alpha = 12;
NDArray has_changed = 13;
}

message EntitiesState {
Expand Down Expand Up @@ -124,3 +126,7 @@ message AddAgentInput {
message IsStartedState {
bool is_started = 1;
}

message Scene {
string scene = 1;
}
4 changes: 4 additions & 0 deletions vivarium/simulator/grpc_server/simulator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ def step(self):

def is_started(self):
return self.stub.IsStarted(Empty()).is_started

def load_scene(self, scene):
message = simulator_pb2.Scene(scene=scene)
return self.stub.LoadScene(message)
Loading
Loading