Skip to content

Commit

Permalink
Added max_speed attribute to agents (#72)
Browse files Browse the repository at this point in the history
---------
  • Loading branch information
Marsolo1 authored Apr 11, 2024
1 parent a04fbb3 commit 613cd43
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 316 deletions.
1 change: 1 addition & 0 deletions conf/scene/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ agents:
behavior: 1
wheel_diameter: 2.
speed_mul: 1.
max_speed: 10.
theta_mul: 1.
prox_dist_max: 40.
prox_cos_min: 0.
Expand Down
4 changes: 3 additions & 1 deletion tests/test_simulator_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.simulator.physics_engine import dynamics_rigid


def test_init_simulator_no_args():
Expand Down Expand Up @@ -40,6 +40,7 @@ def test_init_simulator_args():
behavior = 1
wheel_diameter = 2.0
speed_mul = 1.0
max_speed = 10.0
theta_mul = 1.0
prox_dist_max = 20.0
prox_cos_min = 0.0
Expand All @@ -62,6 +63,7 @@ def test_init_simulator_args():
behavior=behavior,
wheel_diameter=wheel_diameter,
speed_mul=speed_mul,
max_speed=max_speed,
theta_mul=theta_mul,
prox_dist_max=prox_dist_max,
prox_cos_min=prox_cos_min)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulator_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vivarium.simulator.states import init_entities_state
from vivarium.simulator.states import init_state
from vivarium.simulator.simulator import Simulator
from vivarium.simulator.sim_computation import dynamics_rigid
from vivarium.simulator.physics_engine import dynamics_rigid

NUM_STEPS = 50

Expand Down
1 change: 1 addition & 0 deletions vivarium/controllers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class AgentConfig(Config):
wheel_diameter = param.Number(2.)
diameter = param.Number(5.)
speed_mul = param.Number(1.)
max_speed = param.Number(10.)
theta_mul = param.Number(1.)
proxs_dist_max = param.Number(100., bounds=(0, None))
proxs_cos_min = param.Number(0., bounds=(-1., 1.))
Expand Down
1 change: 1 addition & 0 deletions vivarium/controllers/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def get_default_state(n_entities_dict):
behavior=jnp.zeros(max_agents, dtype=int),
wheel_diameter=jnp.zeros(max_agents),
speed_mul=jnp.zeros(max_agents),
max_speed=jnp.zeros(max_agents),
theta_mul=jnp.zeros(max_agents),
proxs_dist_max=jnp.zeros(max_agents),
proxs_cos_min=jnp.zeros(max_agents),
Expand Down
2 changes: 2 additions & 0 deletions vivarium/simulator/grpc_server/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def proto_to_agent_state(agent_state):
behavior=proto_to_ndarray(agent_state.behavior).astype(int),
wheel_diameter=proto_to_ndarray(agent_state.wheel_diameter).astype(float),
speed_mul=proto_to_ndarray(agent_state.speed_mul).astype(float),
max_speed=proto_to_ndarray(agent_state.max_speed).astype(float),
theta_mul=proto_to_ndarray(agent_state.theta_mul).astype(float),
proxs_dist_max=proto_to_ndarray(agent_state.proxs_dist_max).astype(float),
proxs_cos_min=proto_to_ndarray(agent_state.proxs_cos_min).astype(float),
Expand Down Expand Up @@ -114,6 +115,7 @@ def agent_state_to_proto(agent_state):
behavior=ndarray_to_proto(agent_state.behavior),
wheel_diameter=ndarray_to_proto(agent_state.wheel_diameter),
speed_mul=ndarray_to_proto(agent_state.speed_mul),
max_speed=ndarray_to_proto(agent_state.max_speed),
theta_mul=ndarray_to_proto(agent_state.theta_mul),
proxs_dist_max=ndarray_to_proto(agent_state.proxs_dist_max),
proxs_cos_min=ndarray_to_proto(agent_state.proxs_cos_min),
Expand Down
9 changes: 5 additions & 4 deletions vivarium/simulator/grpc_server/protos/simulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ message AgentState {
NDArray behavior = 4;
NDArray wheel_diameter = 5;
NDArray speed_mul = 6;
NDArray theta_mul = 7;
NDArray proxs_dist_max = 8;
NDArray proxs_cos_min = 9;
NDArray color = 10;
NDArray max_speed = 7;
NDArray theta_mul = 8;
NDArray proxs_dist_max = 9;
NDArray proxs_cos_min = 10;
NDArray color = 11;
}

message ObjectState {
Expand Down
28 changes: 14 additions & 14 deletions vivarium/simulator/grpc_server/simulator_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions vivarium/simulator/grpc_server/simulator_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ class EntitiesState(_message.Message):
def __init__(self, position: _Optional[_Union[RigidBody, _Mapping]] = ..., momentum: _Optional[_Union[RigidBody, _Mapping]] = ..., force: _Optional[_Union[RigidBody, _Mapping]] = ..., mass: _Optional[_Union[RigidBody, _Mapping]] = ..., diameter: _Optional[_Union[NDArray, _Mapping]] = ..., entity_type: _Optional[_Union[NDArray, _Mapping]] = ..., entity_idx: _Optional[_Union[NDArray, _Mapping]] = ..., friction: _Optional[_Union[NDArray, _Mapping]] = ..., exists: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class AgentState(_message.Message):
__slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color")
__slots__ = ("nve_idx", "prox", "motor", "behavior", "wheel_diameter", "speed_mul", "max_speed", "theta_mul", "proxs_dist_max", "proxs_cos_min", "color")
NVE_IDX_FIELD_NUMBER: _ClassVar[int]
PROX_FIELD_NUMBER: _ClassVar[int]
MOTOR_FIELD_NUMBER: _ClassVar[int]
BEHAVIOR_FIELD_NUMBER: _ClassVar[int]
WHEEL_DIAMETER_FIELD_NUMBER: _ClassVar[int]
SPEED_MUL_FIELD_NUMBER: _ClassVar[int]
MAX_SPEED_FIELD_NUMBER: _ClassVar[int]
THETA_MUL_FIELD_NUMBER: _ClassVar[int]
PROXS_DIST_MAX_FIELD_NUMBER: _ClassVar[int]
PROXS_COS_MIN_FIELD_NUMBER: _ClassVar[int]
Expand All @@ -94,11 +95,12 @@ class AgentState(_message.Message):
behavior: NDArray
wheel_diameter: NDArray
speed_mul: NDArray
max_speed: NDArray
theta_mul: NDArray
proxs_dist_max: NDArray
proxs_cos_min: NDArray
color: NDArray
def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...
def __init__(self, nve_idx: _Optional[_Union[NDArray, _Mapping]] = ..., prox: _Optional[_Union[NDArray, _Mapping]] = ..., motor: _Optional[_Union[NDArray, _Mapping]] = ..., behavior: _Optional[_Union[NDArray, _Mapping]] = ..., wheel_diameter: _Optional[_Union[NDArray, _Mapping]] = ..., speed_mul: _Optional[_Union[NDArray, _Mapping]] = ..., max_speed: _Optional[_Union[NDArray, _Mapping]] = ..., theta_mul: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_dist_max: _Optional[_Union[NDArray, _Mapping]] = ..., proxs_cos_min: _Optional[_Union[NDArray, _Mapping]] = ..., color: _Optional[_Union[NDArray, _Mapping]] = ...) -> None: ...

class ObjectState(_message.Message):
__slots__ = ("nve_idx", "custom_field", "color")
Expand Down
2 changes: 2 additions & 0 deletions vivarium/simulator/physics_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def motor_force(state, exists_mask):
state.entities_state.diameter[agent_idx],
state.agent_state.wheel_diameter
)
# `a_max` arg is deprecated in recent versions of jax, replaced by `max`
fwd = jnp.clip(fwd, a_max=state.agent_state.max_speed)

cur_vel = state.entities_state.momentum.center[agent_idx] / state.entities_state.mass.center[agent_idx]
cur_fwd_vel = vmap(jnp.dot)(cur_vel, n)
Expand Down
Loading

0 comments on commit 613cd43

Please sign in to comment.