Skip to content

Commit

Permalink
Merge pull request #7 from oist/webdata
Browse files Browse the repository at this point in the history
Script for aggregating log data for web-based visualizer
  • Loading branch information
kngwyu authored Jul 2, 2024
2 parents fbbbb95 + 1b727fa commit 8a1a37c
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 11 deletions.
6 changes: 1 addition & 5 deletions experiments/cf_nolearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@
SavedProfile,
is_cuda_ready,
)
from emevo.rl.ppo_normal import (
NormalPPONet,
vmap_apply,
vmap_net,
)
from emevo.rl.ppo_normal import NormalPPONet, vmap_apply, vmap_net
from emevo.spaces import BoxSpace

PROJECT_ROOT = Path(__file__).parent.parent
Expand Down
5 changes: 3 additions & 2 deletions experiments/cf_simple.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Asexual reward evolution with Circle Foraging"""

import dataclasses
import json
from pathlib import Path
Expand Down Expand Up @@ -56,11 +57,11 @@ class RewardExtractor:
_max_norm: jax.Array = dataclasses.field(init=False)

def __post_init__(self) -> None:
self._max_norm = jnp.sqrt(jnp.sum(self.act_space.high ** 2, axis=-1))
self._max_norm = jnp.sqrt(jnp.sum(self.act_space.high**2, axis=-1))

def normalize_action(self, action: jax.Array) -> jax.Array:
scaled = self.act_space.sigmoid_scale(action)
norm = jnp.sqrt(jnp.sum(scaled ** 2, axis=-1, keepdims=True))
norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1, keepdims=True))
return norm / self._max_norm

def extract(
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import nox

SOURCES = ["src/emevo", "tests", "smoke-tests", "experiments"]
SOURCES = ["src/emevo", "tests", "smoke-tests", "experiments", "scripts"]


def _sync(session: nox.Session, requirements: str) -> None:
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,6 @@ select = ["E", "F", "B", "UP"]
"__init__.py" = ["F401"]
"src/emevo/reward_fn.py" = ["B023"]
# For typer
"experiments/**/*.py" = ["B008", "UP006", "UP007"]
"smoke-tests/*.py" = ["B008", "UP006", "UP007"]
"experiments/*.py" = ["B008", "UP006", "UP007"]
"smoke-tests/*.py" = ["B008", "UP006", "UP007"]
"scripts/*.py" = ["UP006", "UP035"]
141 changes: 141 additions & 0 deletions scripts/make_web_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Asexual reward evolution with Circle Foraging"""

import warnings
from pathlib import Path
from typing import List, Optional

import numpy as np
import polars as pl
import typer

from emevo.analysis.log_plotting import load_log

PROJECT_ROOT = Path(__file__).parent.parent


def _make_stats_df(profile_and_rewards_path: Path) -> tuple[pl.DataFrame, pl.DataFrame]:
rdf = pl.read_parquet(profile_and_rewards_path)
ldf = load_log(profile_and_rewards_path.parent).cast({"unique_id": pl.Int64})
nc_df = rdf.group_by("parent").agg(n_children=pl.col("unique_id").len())
age_df = (
ldf.group_by("unique_id").agg(lifetime=pl.col("unique_id").count()).collect()
)
food_df = ldf.group_by("unique_id").agg(eaten=pl.col("n_got_food").sum()).collect()
df = (
rdf.join(
nc_df, left_on="unique_id", right_on="parent", how="left", coalesce=True
)
.with_columns(pl.col("n_children").replace(None, 0))
.join(age_df, left_on="unique_id", right_on="unique_id")
.join(food_df, left_on="unique_id", right_on="unique_id")
)
return df, ldf


def _agg_df(
path: Path,
start: int,
length: int,
ldf: pl.DataFrame,
ldf_offset: int,
) -> tuple[pl.DataFrame, pl.DataFrame]:
npzfile = np.load(path)
caxy = npzfile["circle_axy"][start : start + length] # (length, 200, 3)
cact = npzfile["circle_is_active"][start : start + length] # (length, 200)
saxy = npzfile["static_circle_axy"][start : start + length]
sact = npzfile["static_circle_is_active"][start : start + length]
cx_list, cy_list, ca_list = [], [], []
sx_list, sy_list = [], []
uniqueid_list, c_nsteps_list, s_nsteps_list = [], [], []
for i in range(length):
active_slots = np.nonzero(cact[i])
caxy_i = caxy[i][active_slots]
saxy_i = saxy[i][sact[i]]

sx_list.append(saxy_i[:, 1])
sy_list.append(saxy_i[:, 2])

ca_list.append(caxy_i[:, 0])
cx_list.append(caxy_i[:, 1])
cy_list.append(caxy_i[:, 2])
df = ldf.filter(pl.col("step") == ldf_offset + start + i).sort("slots")
if len(df) != len(caxy_i):
warnings.warn(
"Number of active agents doesn't match"
+ f"State: {len(saxy_i)} Log: {len(df)}"
+ f"at step {ldf_offset + start + i}",
stacklevel=1,
)
df = df.unique(subset="unique_id", keep="first")
df = df.filter(((pl.col("unique_id") == 0) & (pl.col("slots") != 0)).not_())
uniqueid_list.append(df["unique_id"])
# Num. steps
c_nsteps_list.append(df["step"])
s_nsteps_list.append([ldf_offset + start + i] * len(saxy_i))

cx = np.concatenate(cx_list)
cy = np.concatenate(cy_list)
ca = np.concatenate(ca_list)
unique_id = pl.concat(uniqueid_list)
c_nsteps = pl.concat(c_nsteps_list)

sx = np.concatenate(sx_list)
sy = np.concatenate(sy_list)
s_nsteps = np.concatenate(s_nsteps_list)
cxy_df = pl.DataFrame(
{
"angle": ca,
"x": cx,
"y": cy,
"unique_id": unique_id,
"nsteps": c_nsteps,
}
)
sxy_df = pl.DataFrame(
{
"x": sx,
"y": sy,
"nsteps": s_nsteps,
}
)
return cxy_df, sxy_df


def main(
profile_and_rewards_path: Path,
starting_points: List[int],
write_dir: Optional[Path] = None,
length: int = 100,
) -> None:
if write_dir is None:
write_dir = Path("saved-web-data")

stats_df, ldf = _make_stats_df(profile_and_rewards_path)
stats_df.write_parquet(write_dir / "stats.parqut", compression="snappy")

log_path = profile_and_rewards_path.parent.expanduser()

for point in starting_points:
index = point // 1024000
ldfi = ldf.filter(
(pl.col("step") >= point) & (pl.col("step") < point + length)
).collect() # Offloading here for speedup
cxy_df, sxy_df = _agg_df(
log_path / f"state-{index + 1}.npz",
point - index * 1024000,
length,
ldfi,
index * 1024000,
)
cxy_df.write_parquet(
write_dir / f"saved_cpos-{point}.parqut",
compression="snappy",
)
sxy_df.write_parquet(
write_dir / f"saved_spos-{point}.parqut",
compression="snappy",
)


if __name__ == "__main__":
typer.run(main)
2 changes: 1 addition & 1 deletion src/emevo/exp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class GopsConfig:
path: str
init_std: float
init_mean: float
params: dict[str, float | dict[str, float]]
params: dict[str, float | dict[str, Any]]
init_kwargs: dict[str, float] = dataclasses.field(default_factory=dict)

def load_model(self) -> gops.Mutation | gops.Crossover:
Expand Down

0 comments on commit 8a1a37c

Please sign in to comment.