Skip to content

Commit

Permalink
pyserde 0.14 + lint
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jun 5, 2024
1 parent 352f68c commit d68da7f
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 76 deletions.
2 changes: 1 addition & 1 deletion experiments/cf_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def replace_net(
) -> tuple[NormalPPONet, optax.OptState]:
initialized = initialize_net(key)
pponet = eqx_where(flag, initialized, pponet)
opt_state = jax.tree_map(
opt_state = jax.tree_util.tree_map(
lambda a, b: jnp.where(
jnp.expand_dims(flag, tuple(range(1, a.ndim))),
b,
Expand Down
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ dependencies = [
"moderngl >= 5.6",
"moderngl-window >= 2.4",
"jax >= 0.4",
"pyarrow >= 8.0",
"pyserde[toml]", # TODO: update
"pyarrow >= 9.0",
"pyserde[toml] >= 0.14",
"optax >= 0.1",
]
dynamic = ["version"]
Expand Down Expand Up @@ -71,8 +71,6 @@ select = ["E", "F", "B", "UP"]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"src/emevo/reward_fn.py" = ["B023"]
# For pyserde
"src/emevo/exp_utils.py" = ["UP006", "UP007", "UP035"]
# For typer
"experiments/**/*.py" = ["B008", "UP006", "UP007"]
"smoke-tests/*.py" = ["B008", "UP006", "UP007"]
2 changes: 1 addition & 1 deletion smoke-tests/circle_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

def weight_summary(network):
params, _ = eqx.partition(network, eqx.is_inexact_array)
params_mean = jax.tree_map(jnp.mean, params)
params_mean = jax.tree_util.tree_map(jnp.mean, params)
for k, v in jax.tree_util.tree_leaves_with_path(params_mean):
print(k, v)

Expand Down
1 change: 1 addition & 0 deletions src/emevo/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Implementation of registry and built-in emevo environments.
"""

from emevo.environments.registry import register

register(
Expand Down
8 changes: 4 additions & 4 deletions src/emevo/environments/circle_foraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
)
from emevo.environments.env_utils import (
CircleCoordinate,
FoodNum,
FoodNumFn,
FoodNumState,
Locating,
LocatingFn,
LocatingState,
FoodNum,
FoodNumFn,
SquareCoordinate,
loc_gaussian,
first_to_nth_true,
loc_gaussian,
place,
place_multi,
)
Expand Down Expand Up @@ -233,7 +233,7 @@ def cr(shape: Circle, state: State) -> Raycast:
rc = segment_raycast(1.0, p1, p2, shaped.segment, stated.segment)
to_seg = jnp.where(rc.hit, 1.0 - rc.fraction, -1.0)
obs = jnp.concatenate(
jax.tree_map(
jax.tree_util.tree_map(
lambda arr: jnp.max(arr, keepdims=True),
(to_c, to_sc, to_seg),
),
Expand Down
4 changes: 2 additions & 2 deletions src/emevo/environments/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _update(self, internal: jax.Array) -> Self:
)

def get_slice(self, index: int) -> Self:
return jax.tree_map(lambda x: x[index], self)
return jax.tree_util.tree_map(lambda x: x[index], self)


class FoodNumFn(Protocol):
Expand Down Expand Up @@ -256,7 +256,7 @@ def increment(self, n: jax.Array | int = 1) -> Self:
return LocatingState(n_produced=self.n_produced + n)

def get_slice(self, index: int) -> Self:
return jax.tree_map(lambda x: x[index], self)
return jax.tree_util.tree_map(lambda x: x[index], self)


LocatingFn = Callable[[chex.PRNGKey, int, LocatingState], jax.Array]
Expand Down
55 changes: 31 additions & 24 deletions src/emevo/environments/phyjax2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,49 +43,49 @@ def empty(cls: type[T]) -> Callable[[], T]:
class PyTreeOps:
def __add__(self, o: Any) -> Self:
if o.__class__ is self.__class__:
return jax.tree_map(lambda x, y: x + y, self, o)
return jax.tree_util.tree_map(lambda x, y: x + y, self, o)
else:
return jax.tree_map(lambda x: x + o, self)
return jax.tree_util.tree_map(lambda x: x + o, self)

def __sub__(self, o: Any) -> Self:
if o.__class__ is self.__class__:
return jax.tree_map(lambda x, y: x - y, self, o)
return jax.tree_util.tree_map(lambda x, y: x - y, self, o)
else:
return jax.tree_map(lambda x: x - o, self)
return jax.tree_util.tree_map(lambda x: x - o, self)

def __mul__(self, o: float | jax.Array) -> Self:
return jax.tree_map(lambda x: x * o, self)
return jax.tree_util.tree_map(lambda x: x * o, self)

def __neg__(self) -> Self:
return jax.tree_map(lambda x: -x, self)
return jax.tree_util.tree_map(lambda x: -x, self)

def __truediv__(self, o: float | jax.Array) -> Self:
return jax.tree_map(lambda x: x / o, self)
return jax.tree_util.tree_map(lambda x: x / o, self)

@jax.jit
def get_slice(
self,
index: int | Sequence[int] | Sequence[bool] | jax.Array,
) -> Self:
return jax.tree_map(lambda x: x[index], self)
return jax.tree_util.tree_map(lambda x: x[index], self)

def reshape(self, shape: Sequence[int]) -> Self:
return jax.tree_map(lambda x: x.reshape(shape), self)
return jax.tree_util.tree_map(lambda x: x.reshape(shape), self)

def sum(self, axis: int | None = None) -> Self:
return jax.tree_map(lambda x: jnp.sum(x, axis=axis), self)
return jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=axis), self)

def tolist(self) -> list[Self]:
leaves, treedef = jax.tree_util.tree_flatten(self)
return [treedef.unflatten(leaf) for leaf in zip(*leaves)]

def zeros_like(self) -> Any:
return jax.tree_map(lambda x: jnp.zeros_like(x), self)
return jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), self)

@property
def shape(self) -> Any:
"""For debugging"""
return jax.tree_map(lambda x: x.shape, self)
return jax.tree_util.tree_map(lambda x: x.shape, self)


TWO_PI = jnp.pi * 2
Expand Down Expand Up @@ -478,7 +478,9 @@ class StateDict:

def concat(self) -> Self:
states = [s for s in self.values() if s.batch_size() > 0] # type: ignore
return jax.tree_map(lambda *args: jnp.concatenate(args, axis=0), *states)
return jax.tree_util.tree_map(
lambda *args: jnp.concatenate(args, axis=0), *states
)

def _get(self, name: str, statec: State) -> State:
state = self[name] # type: ignore
Expand Down Expand Up @@ -527,7 +529,9 @@ def concat(self) -> Shape:
shapes = [
s.to_shape() for s in self.values() if s.batch_size() > 0 # type: ignore
]
return jax.tree_map(lambda *args: jnp.concatenate(args, axis=0), *shapes)
return jax.tree_util.tree_map(
lambda *args: jnp.concatenate(args, axis=0), *shapes
)

def n_shapes(self) -> int:
return sum([s.batch_size() for s in self.values()]) # type: ignore
Expand Down Expand Up @@ -602,8 +606,8 @@ def pair_inner(x: jax.Array, reps: int) -> jax.Array:


def _circle_to_circle(ci: ContactIndices[Circle, Circle], stated: StateDict) -> Contact:
pos1 = jax.tree_map(lambda arr: arr[ci.index1], stated.circle.p)
pos2 = jax.tree_map(lambda arr: arr[ci.index2], stated.circle.p)
pos1 = jax.tree_util.tree_map(lambda arr: arr[ci.index1], stated.circle.p)
pos2 = jax.tree_util.tree_map(lambda arr: arr[ci.index2], stated.circle.p)
is_active1 = stated.circle.is_active[ci.index1]
is_active2 = stated.circle.is_active[ci.index2]
return _circle_to_circle_impl(
Expand All @@ -619,8 +623,8 @@ def _circle_to_static_circle(
ci: ContactIndices[Circle, Circle],
stated: StateDict,
) -> Contact:
pos1 = jax.tree_map(lambda arr: arr[ci.index1], stated.circle.p)
pos2 = jax.tree_map(lambda arr: arr[ci.index2], stated.static_circle.p)
pos1 = jax.tree_util.tree_map(lambda arr: arr[ci.index1], stated.circle.p)
pos2 = jax.tree_util.tree_map(lambda arr: arr[ci.index2], stated.static_circle.p)
is_active1 = stated.circle.is_active[ci.index1]
is_active2 = stated.static_circle.is_active[ci.index2]
return _circle_to_circle_impl(
Expand All @@ -636,8 +640,8 @@ def _capsule_to_circle(
ci: ContactIndices[Capsule, Circle],
stated: StateDict,
) -> Contact:
pos1 = jax.tree_map(lambda arr: arr[ci.index1], stated.capsule.p)
pos2 = jax.tree_map(lambda arr: arr[ci.index2], stated.circle.p)
pos1 = jax.tree_util.tree_map(lambda arr: arr[ci.index1], stated.capsule.p)
pos2 = jax.tree_util.tree_map(lambda arr: arr[ci.index2], stated.circle.p)
is_active1 = stated.capsule.is_active[ci.index1]
is_active2 = stated.circle.is_active[ci.index2]
return _capsule_to_circle_impl(
Expand All @@ -653,8 +657,8 @@ def _segment_to_circle(
ci: ContactIndices[Segment, Circle],
stated: StateDict,
) -> Contact:
pos1 = jax.tree_map(lambda arr: arr[ci.index1], stated.segment.p)
pos2 = jax.tree_map(lambda arr: arr[ci.index2], stated.circle.p)
pos1 = jax.tree_util.tree_map(lambda arr: arr[ci.index1], stated.segment.p)
pos2 = jax.tree_util.tree_map(lambda arr: arr[ci.index2], stated.circle.p)
is_active1 = stated.segment.is_active[ci.index1]
is_active2 = stated.circle.is_active[ci.index2]
return _segment_to_circle_impl(
Expand Down Expand Up @@ -732,7 +736,7 @@ def __post_init__(self) -> None:
index2=ci.index2 + offset2,
)
ci_slided_list.append(ci_slided)
self._ci_total = jax.tree_map(
self._ci_total = jax.tree_util.tree_map(
lambda *args: jnp.concatenate(args, axis=0),
*ci_slided_list,
)
Expand All @@ -744,7 +748,10 @@ def check_contacts(self, stated: StateDict) -> Contact:
if ci is not None:
contact = fn(ci, stated)
contacts.append(contact)
return jax.tree_map(lambda *args: jnp.concatenate(args, axis=0), *contacts)
return jax.tree_util.tree_map(
lambda *args: jnp.concatenate(args, axis=0),
*contacts,
)

def n_possible_contacts(self) -> int:
n = 0
Expand Down
2 changes: 1 addition & 1 deletion src/emevo/environments/phyjax2d_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _capsule_mass(

def _concat_or(sl: list[S], default_fn: Callable[[], S]) -> S:
if len(sl) > 0:
return jax.tree_map(lambda *args: jnp.concatenate(args, axis=0), *sl)
return jax.tree_util.tree_map(lambda *args: jnp.concatenate(args, axis=0), *sl)
else:
return default_fn()

Expand Down
4 changes: 2 additions & 2 deletions src/emevo/eqx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@

def get_slice(module: M, slice_idx: int | jax.Array) -> M:
dynamic, static = eqx.partition(module, eqx.is_array)
sliced_dyn = jax.tree_map(lambda item: item[slice_idx], dynamic)
sliced_dyn = jax.tree_util.tree_map(lambda item: item[slice_idx], dynamic)
return eqx.combine(sliced_dyn, static)


@eqx.filter_jit
def where(flag: jax.Array, mod_a: M, mod_b: M) -> M:
dyn_a, static = eqx.partition(mod_a, eqx.is_array)
dyn_b, _ = eqx.partition(mod_b, eqx.is_array)
dyn = jax.tree_map(
dyn = jax.tree_util.tree_map(
lambda a, b: jnp.where(jnp.expand_dims(flag, tuple(range(1, a.ndim))), a, b),
dyn_a,
dyn_b,
Expand Down
Loading

0 comments on commit d68da7f

Please sign in to comment.