Skip to content

Commit

Permalink
Fix index problem in make_web_data
Browse files: browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jul 10, 2024
1 parent 6597a81 commit 8788f36
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions scripts/make_web_data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,7 +49,7 @@ def _agg_df(
sx_list, sy_list, slab_list = [], [], []
uniqueid_list, c_nsteps_list, s_nsteps_list = [], [], []
for i in range(length):
active_slots = np.nonzero(cact[i])
(active_slots,) = np.nonzero(cact[i])
caxy_i = caxy[i][active_slots]
saxy_i = saxy[i][sact[i]]

Expand All @@ -64,12 +64,18 @@ def _agg_df(
if len(df) != len(caxy_i):
warnings.warn(
"Number of active agents doesn't match"
+ f"State: {len(saxy_i)} Log: {len(df)}"
+ f"State: {len(caxy_i)} Log: {len(df)}"
+ f"at step {ldf_offset + start + i}",
stacklevel=1,
)
df = df.unique(subset="unique_id", keep="first")
df = df.filter(((pl.col("unique_id") == 0) & (pl.col("slots") != 0)).not_())
if len(df) != len(caxy_i):
ldf_slots = set(df["slots"])
for slot in active_slots:
if int(slot) not in ldf_slots:
print(f"Active slot {slot} does not appear in log")
exit(1)
uniqueid_list.append(df["unique_id"])
# Num. steps
c_nsteps_list.append(df["step"])
Expand Down Expand Up @@ -100,6 +106,7 @@ def main(
starting_points: List[int],
write_dir: Optional[Path] = None,
length: int = 100,
deincr_log_idx: bool = False, # For backward compatibility
) -> None:
if write_dir is None:
write_dir = Path("saved-web-data")
Expand All @@ -109,24 +116,25 @@ def main(

log_path = profile_and_rewards_path.parent.expanduser()

for point in starting_points:
for i, point in enumerate(starting_points):
index = point // 1024000
ld_start = point - 1 * int(deincr_log_idx)
ldfi = ldf.filter(
(pl.col("step") >= point) & (pl.col("step") < point + length)
(pl.col("step") >= ld_start) & (pl.col("step") <= ld_start + length)
).collect() # Offloading here for speedup
cxy_df, sxy_df = _agg_df(
log_path / f"state-{index + 1}.npz",
point - index * 1024000,
length,
ldfi,
index * 1024000,
index * 1024000 - 1 * int(deincr_log_idx),
)
cxy_df.write_parquet(
write_dir / f"saved_cpos-{point}.parqut",
write_dir / f"saved_cpos-{i}.parquet",
compression="snappy",
)
sxy_df.write_parquet(
write_dir / f"saved_spos-{point}.parqut",
write_dir / f"saved_spos-{i}.parquet",
compression="snappy",
)

Expand Down

0 comments on commit 8788f36

Please sign in to comment.