Skip to content

Commit

Permalink
Enable multiple jitted steps in selective sensing and use back jax_md…
Browse files Browse the repository at this point in the history
… dataclasses
  • Loading branch information
corentinlger committed Aug 27, 2024
1 parent 771d689 commit 692c85c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 42 deletions.
130 changes: 98 additions & 32 deletions vivarium/experimental/environments/braitenberg/selective_sensing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from flax import struct
from jax_md.rigid_body import RigidBody
from jax_md import simulate
from jax_md import space, partition
from jax_md.dataclasses import dataclass as md_dataclass
from jax_md import space, partition, simulate

from vivarium.experimental.environments.utils import distance
from vivarium.experimental.environments.base_env import BaseState, BaseEnv
Expand All @@ -32,8 +32,11 @@ class EntityType(Enum):
AGENT = 0
OBJECT = 1

# Set the class of all states to jax_md dataclass instead of struct dataclass
# What could be done in the future is to set it back to struct and simplify client server connection

# Already incorporates position, momentum, force, mass and velocity
@struct.dataclass
@md_dataclass
class EntityState(simulate.NVEState):
entity_type: jnp.array
ent_subtype: jnp.array
Expand All @@ -42,12 +45,12 @@ class EntityState(simulate.NVEState):
friction: jnp.array
exists: jnp.array

@struct.dataclass
@md_dataclass
class ParticleState:
ent_idx: jnp.array
color: jnp.array

@struct.dataclass
@md_dataclass
class AgentState(ParticleState):
prox: jnp.array
motor: jnp.array
Expand All @@ -63,11 +66,11 @@ class AgentState(ParticleState):
proxs_dist_max: jnp.array
proxs_cos_min: jnp.array

@struct.dataclass
@md_dataclass
class ObjectState(ParticleState):
pass

@struct.dataclass
@md_dataclass
class State(BaseState):
max_agents: jnp.int32
max_objects: jnp.int32
Expand All @@ -80,7 +83,8 @@ class State(BaseState):
agents: AgentState
objects: ObjectState

@struct.dataclass
# Not part of the state but part of the environment
@md_dataclass
class Neighbors:
neighbors: jnp.array
agents_neighs_idx: jnp.array
Expand Down Expand Up @@ -344,7 +348,7 @@ def __init__(self, state, occlusion=True, seed=42):
self.init_fn, self.apply_physics = dynamics_fn(self.displacement, self.shift, braintenberg_force_fn)
# Do a warning at the moment if neighbor radius is < box_size
if state.neighbor_radius < state.box_size:
lg.warn("Neighbor radius < Box size, there might be problems for neighbors arrays computations")
lg.warn("Neighbor radius < Box size, this might cause problems for neighbors arrays and proximity maps updates")
self.neighbor_fn = partition.neighbor_list(
self.displacement,
state.box_size,
Expand Down Expand Up @@ -377,30 +381,36 @@ def choose_agent_prox_motor_function(self):
return prox_motor_function

@partial(jit, static_argnums=(0,))
def _step(self, state: State, neighbors_storage: Neighbors) -> Tuple[State, jnp.array]:
"""Do 1 jitted step in the environment and return the updated state
def _step_env(self, state: State, neighbors_storage: Neighbors) -> Tuple[State, Neighbors]:
"""Do one jitted step in the environment and return the updated state, as well as updated neighbors array
:param state: current state
:param neighbors_storage: class storing all neighbors information
:return: new sttae
:return: new state, neighbors storage wih updated neighbors
"""

# Retrieve different neighbors format
neighbors = neighbors_storage.neighbors
agents_neighs_idx = neighbors_storage.agents_neighs_idx
ag_idx_dense = neighbors_storage.agents_idx_dense
# Differences : compute raw proxs for all agents first
dist, relative_theta, proximity_dist_map, proximity_dist_theta = get_relative_displacement(state, agents_neighs_idx, displacement_fn=self.displacement)
senders, receivers = agents_neighs_idx
ag_idx_dense_senders, ag_idx_dense_receivers = ag_idx_dense

# Compute raw proxs for all agents first
dist, relative_theta, proximity_dist_map, proximity_dist_theta = get_relative_displacement(
state,
agents_neighs_idx,
displacement_fn=self.displacement
)

dist_max = state.agents.proxs_dist_max[senders]
cos_min = state.agents.proxs_cos_min[senders]
# TODO : shouldn't the agents_neighs_idx[1, :] be receivers ?
target_exist_mask = state.entities.exists[agents_neighs_idx[1, :]]
# Compute agents raw proximeters (proximeters for all neighbors)
raw_proxs = sensor_fn(dist, relative_theta, dist_max, cos_min, target_exist_mask)

# Could even just pass ag_idx_dense in the fn and do this inside
ag_idx_dense_senders, ag_idx_dense_receivers = ag_idx_dense

# Compute real agents proximeters and motors
agent_proxs, mean_agent_motors = self.compute_all_agents_proxs_motors(
state,
state.agents.ent_idx,
Expand All @@ -413,44 +423,91 @@ def _step(self, state: State, neighbors_storage: Neighbors) -> Tuple[State, jnp.
ag_idx_dense_receivers,
)

agents = state.agents.replace(
# Update agents state
agents = state.agents.set(
prox=agent_proxs,
proximity_map_dist=proximity_dist_map,
proximity_map_theta=proximity_dist_theta,
motor=mean_agent_motors
)

# Last block unchanged
# Update the entities and the state
state = state.replace(agents=agents)
entities = self.apply_physics(state, neighbors)
state = state.replace(time=state.time+1, entities=entities)

# Update the neighbors storage
neighbors = neighbors.update(state.entities.position.center)
neighbors_storage = Neighbors(
neighbors=neighbors,
agents_neighs_idx=agents_neighs_idx,
agents_idx_dense=ag_idx_dense
)

return state, neighbors
return state, neighbors_storage

def step(self, state: State) -> State:
"""Do 1 step in the environment and return the updated state. This function also handles the neighbors mechanism and hence isn't jitted

@partial(jax.jit, static_argnums=(0, 3))
def _steps(self, state, neighbor_storage, num_updates):
print('COMPILE')

"""Update the current state by doing a _step_env update num_updates times (this results in faster simulations)
:param state: _description_
:param neighbor_storage: _description_
:param num_updates: _description_
"""

def step_fn(carry, _):
"""Apply a step function to return new state and neighbors storage in a jax.lax.scan update
:param carry: tuple of (state, neighbors storage)
:param _: dummy xs for jax.lax.scan
:return: tuple of (carry, carry) with carry=(new_state, new_neighbors _sotrage)
"""
state, neighbors_storage = carry
new_state, new_neighbors_storage = self._step_env(state, neighbors_storage)
carry = (new_state, new_neighbors_storage)
return carry, carry

(state, neighbor_storage), _ = jax.lax.scan(
step_fn,
(state, neighbor_storage),
xs=None,
length=num_updates
)

return state, neighbor_storage


def step(self, state: State, num_updates: int = 4) -> State:
"""Do num_updates jitted steps in the environment and return the updated state. This function also handles the neighbors mechanism and hence isn't jitted
:param state: current state
:param num_updates: number of jitted_steps
:return: next state
"""
# Because momentum is initialized to None, need to initialize it with init_fn from jax_md
if state.entities.momentum is None:
state = self.init_fn(state, self.init_key)
# Compute next state

# Save the first state
current_state = state
state, neighbors = self._step(current_state, self.neighbors_storage)

state, neighbors_storage = self._steps(current_state, self.neighbors_storage, num_updates)

# Check if neighbors buffer overflowed
if neighbors.did_buffer_overflow:
# reallocate neighbors and run the simulation from current_state
if neighbors_storage.neighbors.did_buffer_overflow:
# reallocate neighbors and run the simulation from current_state if it is the case
lg.warning(f'NEIGHBORS BUFFER OVERFLOW at step {state.time}: rebuilding neighbors')
self.neighbors_storage = self.allocate_neighbors(state)
assert not neighbors.did_buffer_overflow
# Because there was an error, we need to re-run this simulation loop from the copy of the current_state we created (and check wether it worked or not after)
state, neighbors_storage = self._steps(current_state, self.neighbors_storage, num_updates)
assert not neighbors_storage.neighbors.did_buffer_overflow

return state


def allocate_neighbors(self, state, position=None):
"""Allocate the neighbors according to the state
Expand Down Expand Up @@ -681,6 +738,7 @@ def init_entities(

def init_agents(
max_agents,
max_objects,
params,
sensed,
behaviors,
Expand All @@ -707,8 +765,9 @@ def init_agents(
theta_mul=jnp.full((max_agents), theta_mul),
proxs_dist_max=jnp.full((max_agents), prox_dist_max),
proxs_cos_min=jnp.full((max_agents), prox_cos_min),
proximity_map_dist=jnp.zeros((max_agents, 1)),
proximity_map_theta=jnp.zeros((max_agents, 1)),
# Change shape of these maps so they stay constant (jax.lax.scan problem otherwise)
proximity_map_dist=jnp.zeros((max_agents, max_agents + max_objects)),
proximity_map_theta=jnp.zeros((max_agents, max_agents + max_objects)),
color=agents_color
)

Expand Down Expand Up @@ -758,7 +817,7 @@ def init_complete_state(


def init_state(
entities_data,
entities_data=entities_data,
box_size=box_size,
dt=dt,
neighbor_radius=neighbor_radius,
Expand Down Expand Up @@ -879,6 +938,7 @@ def init_state(

agents = init_agents(
max_agents=max_agents,
max_objects=max_objects,
params=params,
sensed=sensed,
behaviors=behaviors,
Expand Down Expand Up @@ -913,4 +973,10 @@ def init_state(

return state

state = init_state(entities_data=entities_data)

if __name__ == "__main__":
state = init_state()
env = SelectiveSensorsEnv(state)

env.step(state, num_updates=5)
env.step(state, num_updates=6)
18 changes: 10 additions & 8 deletions vivarium/experimental/environments/braitenberg/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from jax_md.rigid_body import RigidBody
from jax_md import simulate
from jax_md import space, rigid_body, partition, quantity
from jax_md.dataclasses import dataclass as md_dataclass


from vivarium.experimental.environments.utils import normal, distance, relative_position
from vivarium.experimental.environments.base_env import BaseState, BaseEnv
Expand All @@ -28,20 +30,20 @@ class EntityType(Enum):
OBJECT = 1

# Already incorporates position, momentum, force, mass and velocity
@struct.dataclass
@md_dataclass
class EntityState(simulate.NVEState):
entity_type: jnp.array
entity_idx: jnp.array
diameter: jnp.array
friction: jnp.array
exists: jnp.array

@struct.dataclass
@md_dataclass
class ParticleState:
ent_idx: jnp.array
color: jnp.array

@struct.dataclass
@md_dataclass
class AgentState(ParticleState):
prox: jnp.array
motor: jnp.array
Expand All @@ -56,11 +58,11 @@ class AgentState(ParticleState):
proxs_dist_max: jnp.array
proxs_cos_min: jnp.array

@struct.dataclass
@md_dataclass
class ObjectState(ParticleState):
pass

@struct.dataclass
@md_dataclass
class State(BaseState):
max_agents: jnp.int32
max_objects: jnp.int32
Expand Down Expand Up @@ -369,19 +371,19 @@ def _step(self, state: State, neighbors: jnp.array, agents_neighs_idx: jnp.array

# 2 : Compute motor activations according to new proximeter values
motor = compute_motor(prox, state.agents.params, state.agents.behavior, state.agents.motor)
agents = state.agents.replace(
agents = state.agents.set(
prox=prox,
proximity_map_dist=proximity_dist_map,
proximity_map_theta=proximity_dist_theta,
motor=motor
)

# 3 : Update the state with new agents proximeter and motor values
state = state.replace(agents=agents)
state = state.set(agents=agents)

# 4 : Move the entities by applying forces on them (collision, friction and motor forces for agents)
entities = self.apply_physics(state, neighbors)
state = state.replace(time=state.time+1, entities=entities)
state = state.set(time=state.time+1, entities=entities)

# 5 : Update neighbors
neighbors = neighbors.update(state.entities.position.center)
Expand Down
4 changes: 2 additions & 2 deletions vivarium/experimental/environments/physics_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def init_fn(state, key, kT=0.):
assert state.entities.momentum is None
assert not jnp.any(state.entities.force.center) and not jnp.any(state.entities.force.orientation)

state = state.replace(entities=simulate.initialize_momenta(state.entities, key, kT))
state = state.set(entities=simulate.initialize_momenta(state.entities, key, kT))
return state

def mask_momentum(entity_state, exists_mask):
Expand All @@ -142,7 +142,7 @@ def step_fn(state, neighbor):
entity_state = simulate.momentum_step(state.entities, dt_2)
# TODO : why do we used dt and not dt/2 in the line below ?
entity_state = simulate.position_step(entity_state, shift, dt_2, neighbor=neighbor)
entity_state = entity_state.replace(force=force)
entity_state = entity_state.set(force=force)
entity_state = simulate.momentum_step(entity_state, dt_2)
entity_state = mask_momentum(entity_state, exists_mask)
return entity_state
Expand Down

0 comments on commit 692c85c

Please sign in to comment.