Skip to content

Commit

Permalink
Refactor behaviors imports
Browse files Browse the repository at this point in the history
  • Loading branch information
corentinlger committed Nov 18, 2024
1 parent 1d48c1e commit fa78108
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 58 deletions.
27 changes: 15 additions & 12 deletions notebooks/server_side/3_prey_predator_braitenberg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -34,8 +34,11 @@
"from jax import vmap, jit\n",
"from jax_md.dataclasses import dataclass as md_dataclass\n",
"\n",
"from vivarium.environments.braitenberg.simple import BraitenbergEnv, AgentState, State, EntityType, init_complete_state, init_entities, init_objects\n",
"from vivarium.environments.braitenberg.simple import compute_motor, compute_prox, behavior_to_params, Behaviors"
"from vivarium.environments.braitenberg.simple.simple_env import BraitenbergEnv\n",
"from vivarium.environments.braitenberg.simple.classes import AgentState, State, EntityType\n",
"from vivarium.environments.braitenberg.simple.init import init_complete_state, init_entities, init_objects\n",
"from vivarium.environments.braitenberg.simple.simple_env import compute_motor, compute_prox\n",
"from vivarium.environments.braitenberg.behaviors import behavior_to_params, Behaviors"
]
},
{
Expand All @@ -49,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -73,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -100,7 +103,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -160,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -265,7 +268,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -274,7 +277,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -286,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -313,7 +316,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -327,7 +330,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down
41 changes: 41 additions & 0 deletions vivarium/environments/braitenberg/behaviors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from enum import Enum

import jax.numpy as jnp


class Behaviors(Enum):
FEAR = 0
AGGRESSION = 1
LOVE = 2
SHY = 3
NOOP = 4
MANUAL = 5

behavior_params = {
Behaviors.FEAR.value: jnp.array(
[[1., 0., 0.],
[0., 1., 0.]]),
Behaviors.AGGRESSION.value: jnp.array(
[[0., 1., 0.],
[1., 0., 0.]]),
Behaviors.LOVE.value: jnp.array(
[[-1., 0., 1.],
[0., -1., 1.]]),
Behaviors.SHY.value: jnp.array(
[[0., -1., 1.],
[-1., 0., 1.]]),
Behaviors.NOOP.value: jnp.array(
[[0., 0., 0.],
[0., 0., 0.]]),
Behaviors.MANUAL.value: jnp.array(
[[0., 0., 0.],
[0., 0., 0.]])
}

def behavior_to_params(behavior):
"""Return the params associated to a behavior.
:param behavior: behavior id (int)
:return: params
"""
return behavior_params[behavior]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jax import random
from jax_md.rigid_body import RigidBody

from vivarium.environments.braitenberg.simple.simple_env import Behaviors, behavior_to_params
from vivarium.environments.braitenberg.behaviors import Behaviors, behavior_to_params
from vivarium.utils.scene_configs import load_default_config
from vivarium.environments.braitenberg.selective_sensing.classes import State, AgentState, ObjectState, EntityState, EntityType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from vivarium.environments.utils import distance
from vivarium.environments.base_env import BaseEnv
from vivarium.environments.physics_engine import dynamics_fn
from vivarium.environments.braitenberg.behaviors import Behaviors
from vivarium.environments.braitenberg.selective_sensing.classes import State, Neighbors, EntityType
from vivarium.environments.braitenberg.selective_sensing.init import init_state

from vivarium.environments.braitenberg.simple.simple_env import (
proximity_map,
sensor_fn,
Behaviors,
linear_behavior,
braintenberg_force_fn
)
Expand Down
39 changes: 1 addition & 38 deletions vivarium/environments/braitenberg/simple/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax_md.dataclasses import dataclass as md_dataclass
from jax_md import simulate

from vivarium.environments.base_env import BaseState, BaseEnv
from vivarium.environments.base_env import BaseState


class EntityType(Enum):
Expand Down Expand Up @@ -56,40 +56,3 @@ class State(BaseState):
entities: EntityState
agents: AgentState
objects: ObjectState

class Behaviors(Enum):
FEAR = 0
AGGRESSION = 1
LOVE = 2
SHY = 3
NOOP = 4
MANUAL = 5

behavior_params = {
Behaviors.FEAR.value: jnp.array(
[[1., 0., 0.],
[0., 1., 0.]]),
Behaviors.AGGRESSION.value: jnp.array(
[[0., 1., 0.],
[1., 0., 0.]]),
Behaviors.LOVE.value: jnp.array(
[[-1., 0., 1.],
[0., -1., 1.]]),
Behaviors.SHY.value: jnp.array(
[[0., -1., 1.],
[-1., 0., 1.]]),
Behaviors.NOOP.value: jnp.array(
[[0., 0., 0.],
[0., 0., 0.]]),
Behaviors.MANUAL.value: jnp.array(
[[0., 0., 0.],
[0., 0., 0.]])
}

def behavior_to_params(behavior):
"""Return the params associated to a behavior.
:param behavior: behavior id (int)
:return: params
"""
return behavior_params[behavior]
5 changes: 3 additions & 2 deletions vivarium/environments/braitenberg/simple/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from jax import random
from jax_md.rigid_body import RigidBody

from vivarium.environments.braitenberg.behaviors import Behaviors, behavior_to_params
from vivarium.environments.braitenberg.simple.classes import (
EntityType,
State,
AgentState,
ObjectState,
EntityState,
Behaviors,
behavior_to_params
)

# Constants
SEED = 0
MAX_AGENTS = 10
MAX_OBJECTS = 2
Expand All @@ -37,6 +37,7 @@
OBJECTS_COLOR = jnp.array([1.0, 0.0, 0.0])
BEHAVIOR = Behaviors.AGGRESSION.value


def init_state(
box_size=BOX_SIZE,
dt=DT,
Expand Down
7 changes: 3 additions & 4 deletions vivarium/environments/braitenberg/simple/simple_env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import logging as lg

from enum import Enum
from functools import partial
from typing import Tuple

import numpy as np
import jax.numpy as jnp

from jax import vmap, jit
Expand All @@ -13,10 +11,11 @@
from jax_md import space, rigid_body, partition, quantity

from vivarium.environments.base_env import BaseEnv
from vivarium.environments.braitenberg.simple.init import init_state
from vivarium.environments.utils import normal, distance, relative_position
from vivarium.environments.physics_engine import total_collision_energy, friction_force, dynamics_fn
from vivarium.environments.braitenberg.simple.classes import EntityType, State, Behaviors, behavior_to_params
from vivarium.environments.braitenberg.simple.init import init_state
from vivarium.environments.braitenberg.behaviors import Behaviors
from vivarium.environments.braitenberg.simple.classes import EntityType, State


### Define the constants and the classes of the environment to store its state ###
Expand Down

0 comments on commit fa78108

Please sign in to comment.