Skip to content

Commit

Permalink
added exp dir
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Jun 6, 2024
1 parent 43896ed commit f046b30
Show file tree
Hide file tree
Showing 164 changed files with 6,005 additions and 14 deletions.
Binary file added experiments/fig_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
152 changes: 152 additions & 0 deletions experiments/fig_4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# %%
import sys
import numpy as np
import matplotlib.pyplot as plt

sys.path.append("../")
import proc_elo
import proc_task_eval


# Plot styling: font sizes, then tick positions/labels for each panel type
TITLE_FONT = 24
LABEL_FONT = 18
TICK_FONT = 11

# Elo panels (Survival, Team Battle): tick labels read as millions of samples
ELO_SAMPLE_TICKS = [25 * step for step in range(6)]
ELO_SAMPLE_LABELS = ["0", "25M", "50M", "75M", "100M", "125M"]
# NOTE: the name is misspelled ("YTIKCS") but is referenced under this spelling
ELO_YTIKCS = [850 + 50 * step for step in range(6)]

# Multi-task panel: y ticks run from 0.0 to 0.2 in increments of 0.04
MT_SAMPLE_TICKS = [0, 50, 100, 150]
MT_SAMPLE_LABELS = ["0", "50M", "100M", "150M"]
MT_YTICKS = [hundredths / 100 for hundredths in (0, 4, 8, 12, 16, 20)]

# Scatter keyword arguments; palette is blue-ish #4c72b0 and orange #dd8452
SPECIALIST_MARKER = dict(
    marker="o",
    s=100,
    c="white",
    edgecolors="#4c72b0",
    label="Specialist",
    linewidths=3,
)
GENERALIST_MARKER = dict(marker="o", s=120, c="#dd8452", label="Generalist")

# Sampling ratio of each minigame, extracted by hand after running-average
# smoothing with 100.
# https://wandb.ai/kywch/meta-mmo/runs/lf95vvxr, see stats/Sampling/{minigame}_agent_steps
generalist_sample_ratio = {
    # Survival
    "svonly": [0.288, 0.290, 0.293, 0.291],
    # Team Battle
    "tbonly": [0.305, 0.304, 0.299, 0.296],
    # Multi-task Training
    "mtonly": [0.407, 0.406, 0.408, 0.412],
}


def process_eval_data(
    data, specialist_prefix, sample_ratio, value_key="elo", generalist_prefix="general"
):
    """Split evaluation records into generalist and specialist series.

    Each record's ``"policy"`` name is split on ``"_"`` (e.g. ``general_100m``):
    the first part selects the policy type, and the second part, with its last
    character dropped, is parsed as the integer step count.

    Args:
        data: Iterable of mappings, each with a ``"policy"`` entry and a
            ``value_key`` entry.
        specialist_prefix: Policy-name prefix that marks specialist policies.
        sample_ratio: Multipliers applied to the generalist step counts. Entry
            ``k`` is applied to the generalist record with the ``k``-th
            smallest step count, regardless of the order of ``data``.
        value_key: Key of the metric to collect from each record.
        generalist_prefix: Policy-name prefix that marks generalist policies.

    Returns:
        A dict ``{"generalist": {"steps", "values"}, "specialist": {"steps",
        "values"}}`` whose lists follow the input order. Generalist steps are
        scaled by their sampling ratio; specialist steps are left as parsed.

    Raises:
        AssertionError: If a policy name starts with neither prefix.
        IndexError: If there are more generalist records than ``sample_ratio``
            entries.
    """
    results = {"generalist": {"steps": [], "values": []}, "specialist": {"steps": [], "values": []}}
    for pol_data in data:
        pol_info = pol_data["policy"].split("_")
        assert pol_info[0] in [specialist_prefix, generalist_prefix], (
            f"Unexpected policy prefix in {pol_data['policy']!r}; "
            f"expected {specialist_prefix!r} or {generalist_prefix!r}"
        )
        pol_type = "generalist" if pol_info[0] == generalist_prefix else "specialist"
        results[pol_type]["values"].append(pol_data[value_key])
        # Drop the trailing unit character (e.g. the "m" in "100m")
        results[pol_type]["steps"].append(int(pol_info[1][:-1]))

    # Apply sampling ratio to correct the steps. argsort yields, for each rank,
    # the position of the record holding that rank, so the ratio is looked up
    # by rank and written back at the record's own position.
    gen_steps = results["generalist"]["steps"]
    for rank, idx in enumerate(np.argsort(gen_steps, kind="stable")):
        gen_steps[idx] *= sample_ratio[rank]

    return results


if __name__ == "__main__":
    # Load the evaluation results: Elo for Survival and Team Battle,
    # task progress for the Multi-task evaluation
    survival = process_eval_data(
        data=proc_elo.process_eval_files("full_sv", "survive"),
        specialist_prefix="svonly",
        sample_ratio=generalist_sample_ratio["svonly"],
    )
    team_battle = process_eval_data(
        data=proc_elo.process_eval_files("full_tb", "battle"),
        specialist_prefix="tbonly",
        sample_ratio=generalist_sample_ratio["tbonly"],
    )
    curriculum_rows = proc_task_eval.process_eval_files("full_mt", "curriculum").to_dicts()
    multi_task = process_eval_data(
        data=curriculum_rows,
        specialist_prefix="mtonly",
        sample_ratio=generalist_sample_ratio["mtonly"],
        value_key="task_progress",
    )

    # One row with three panels
    fig, panels = plt.subplots(1, 3, figsize=(12, 4))
    survival_ax, battle_ax, task_ax = panels

    # Common look: hide the top/right frame lines and size the ticks
    for panel in panels:
        for side in ("top", "right"):
            panel.spines[side].set_visible(False)
        panel.tick_params(axis="both", which="major", length=6, labelsize=TICK_FONT)

    marker_style = {"generalist": GENERALIST_MARKER, "specialist": SPECIALIST_MARKER}

    # The two Elo panels share axes settings; only the title, the data and the
    # order in which the two scatter series are drawn differ.
    elo_panels = (
        (survival_ax, "Survival", survival, ("generalist", "specialist")),
        (battle_ax, "Team Battle", team_battle, ("specialist", "generalist")),
    )
    for panel, title, series, draw_order in elo_panels:
        panel.set_title(title, fontsize=TITLE_FONT)
        panel.set_xticks(ELO_SAMPLE_TICKS)
        panel.set_xticklabels(ELO_SAMPLE_LABELS, fontsize=TICK_FONT)
        panel.set_xlim(0, 125)
        panel.set_yticks(ELO_YTIKCS)
        panel.set_ylim(850, 1100)
        panel.plot([0, 150], [1000, 1000], "k:")  # Dotted anchor line at Elo 1000
        for pol_type in draw_order:
            panel.scatter(
                series[pol_type]["steps"], series[pol_type]["values"], **marker_style[pol_type]
            )

    # Per-panel extras: y label and legend on the left, x label in the middle
    survival_ax.set_ylabel("Elo", fontsize=TITLE_FONT)
    survival_ax.legend(loc="lower right", fontsize=13)
    battle_ax.set_xlabel("Training samples", fontsize=TITLE_FONT, labelpad=13)

    # Multi-task Eval panel
    task_ax.set_title("Multi-task Eval", fontsize=TITLE_FONT)
    task_ax.set_ylabel("Task progress", fontsize=TITLE_FONT)
    task_ax.set_xticks(MT_SAMPLE_TICKS)
    task_ax.set_xticklabels(MT_SAMPLE_LABELS, fontsize=TICK_FONT)
    task_ax.set_xlim(0, 175)
    task_ax.set_yticks(MT_YTICKS)
    task_ax.set_ylim(0, 0.2)
    for pol_type in ("specialist", "generalist"):
        task_ax.scatter(
            multi_task[pol_type]["steps"],
            multi_task[pol_type]["values"],
            **marker_style[pol_type],
        )

    # Lay the panels out, then widen the horizontal gap between them
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.5)

    # Write the figure to disk as a PNG, then display it
    plt.savefig("fig_4.png", dpi=300, bbox_inches="tight")
    plt.show()
Binary file added experiments/fig_6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
96 changes: 96 additions & 0 deletions experiments/fig_6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# %%
import sys
import matplotlib.pyplot as plt

sys.path.append("../")
import proc_elo
from fig_4 import process_eval_data


# Plot styling: font sizes and the shared tick positions/labels of the Elo panels
TITLE_FONT = 26
TICK_FONT = 12
ELO_SAMPLE_TICKS = [25 * step for step in range(6)]
ELO_SAMPLE_LABEL = ["0", "25M", "50M", "75M", "100M", "125M"]
# NOTE: the name is misspelled ("YTIKCS") but is referenced under this spelling
ELO_YTIKCS = [850 + 50 * step for step in range(6)]

# Scatter keyword arguments; palette is blue-ish #4c72b0 and orange #dd8452
SPECIALIST_MARKER = dict(
    marker="o",
    s=100,
    c="white",
    edgecolors="#4c72b0",
    label="Specialist",
    linewidths=3,
)
GENERALIST_MARKER = dict(marker="o", s=120, c="#dd8452", label="Generalist")

# Sampling ratio of each minigame, extracted by hand after running-average
# smoothing with 100.
# https://wandb.ai/kywch/meta-mmo/runs/53n3yvnj, see stats/Sampling/{minigame}_agent_steps
generalist_sample_ratio = {
    # Team Battle
    "battle": [0.218, 0.261, 0.274, 0.280],
    # Protect the King
    "ptk": [0.153, 0.188, 0.201, 0.217],
    # Race to the Center
    "race": [0.164, 0.120, 0.100, 0.097],
    # King of the Hill
    "koh": [0.228, 0.231, 0.223, 0.210],
    # Sandwich
    "sandwich": [0.237, 0.200, 0.202, 0.196],
}

# Per-minigame tuple: (display name, results directory, specialist policy prefix)
minigame_info = {
    "battle": ("Team Battle", "mini_tb", "tbonly"),
    "ptk": ("Protect the King", "mini_pk", "pkonly"),
    "race": ("Race to the Center", "mini_rc", "rconly"),
    "koh": ("King of the Hill", "mini_kh", "khonly"),
    "sandwich": ("Sandwich", "mini_sw", "swonly"),
}

if __name__ == "__main__":
    # Load the Elo results of every minigame and split them by policy type
    elo_data = {}
    for game, (_title, eval_dir, prefix) in minigame_info.items():
        raw_results = proc_elo.process_eval_files(eval_dir, game)
        elo_data[game] = process_eval_data(raw_results, prefix, generalist_sample_ratio[game])

    # One row with one panel per minigame
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))

    series_styles = (("generalist", GENERALIST_MARKER), ("specialist", SPECIALIST_MARKER))
    for ax, (game, info) in zip(axes, minigame_info.items()):
        series = elo_data[game]

        # Hide the top/right frame lines and size the ticks
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        ax.tick_params(axis="both", which="major", length=6, labelsize=TICK_FONT)

        ax.set_title(info[0], fontsize=TITLE_FONT)
        ax.set_xticks(ELO_SAMPLE_TICKS)
        ax.set_xticklabels(ELO_SAMPLE_LABEL, fontsize=TICK_FONT)
        # Use the wider x range only when a generalist point lies beyond 100
        x_upper = 125 if max(series["generalist"]["steps"]) > 100 else 105
        ax.set_xlim(0, x_upper)
        ax.set_yticks(ELO_YTIKCS)
        ax.set_ylim(880, 1100)
        ax.plot([0, 150], [1000, 1000], "k:")  # Dotted anchor line at Elo 1000
        for pol_type, style in series_styles:
            ax.scatter(series[pol_type]["steps"], series[pol_type]["values"], **style)

    # Shared labels: y label and legend on the first panel, x label on the middle one
    axes[0].set_ylabel("Elo", fontsize=TITLE_FONT + 2)
    axes[0].legend(loc="lower right", fontsize=15)
    axes[2].set_xlabel("Training samples", fontsize=TITLE_FONT + 6, labelpad=15)

    # Lay the panels out, then set the horizontal gap between them
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.3)

    # Write the figure to disk as a PNG, then display it
    plt.savefig("fig_6.png", dpi=300, bbox_inches="tight")
    plt.show()
1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_1.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_17590196.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_2.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_48692979.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_59315038.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_70970695.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_74435116.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_78309306.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_87951085.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/curriculum_info_97056894.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions experiments/full_mt/general_100m.pt
1 change: 1 addition & 0 deletions experiments/full_mt/general_200m.pt
1 change: 1 addition & 0 deletions experiments/full_mt/general_300m.pt
1 change: 1 addition & 0 deletions experiments/full_mt/general_400m.pt
1 change: 1 addition & 0 deletions experiments/full_mt/mtonly_100m.pt
1 change: 1 addition & 0 deletions experiments/full_mt/mtonly_25m.pt
1 change: 1 addition & 0 deletions experiments/full_mt/mtonly_50m.pt
1 change: 1 addition & 0 deletions experiments/full_mt/mtonly_75m.pt
81 changes: 81 additions & 0 deletions experiments/full_mt/score_by_seed.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
policy mode seed count length task_progress weighted_score
general_100m pvp 1 3328 485.417969 0.115763 10.896710
general_100m pvp 17590196 3328 458.769531 0.121940 11.712889
general_100m pvp 2 3328 472.893630 0.121988 10.945641
general_100m pvp 48692979 3328 492.171575 0.115905 11.119063
general_100m pvp 59315038 3328 482.833834 0.116905 10.826042
general_100m pvp 70970695 3328 473.437200 0.107732 11.079143
general_100m pvp 74435116 3328 477.912260 0.119933 10.914331
general_100m pvp 78309306 3328 477.221755 0.114937 10.389002
general_100m pvp 87951085 3328 476.138522 0.110778 10.684735
general_100m pvp 97056894 3328 465.522837 0.110429 9.705141
general_200m pvp 1 3328 539.848257 0.139947 12.805038
general_200m pvp 17590196 3328 545.559796 0.143204 13.091318
general_200m pvp 2 3328 536.153546 0.140734 13.049674
general_200m pvp 48692979 3328 526.338942 0.130081 11.990479
general_200m pvp 59315038 3328 533.338341 0.131956 12.664571
general_200m pvp 70970695 3328 537.543870 0.135842 12.342994
general_200m pvp 74435116 3328 546.435397 0.141714 12.721710
general_200m pvp 78309306 3328 540.789363 0.147472 13.563530
general_200m pvp 87951085 3328 542.996995 0.134235 12.576228
general_200m pvp 97056894 3328 538.321514 0.140068 12.713509
general_300m pvp 1 3328 586.922476 0.138772 13.681132
general_300m pvp 17590196 3328 591.190805 0.149476 13.643711
general_300m pvp 2 3328 588.495793 0.140983 13.714509
general_300m pvp 48692979 3328 586.776743 0.141291 13.870448
general_300m pvp 59315038 3328 583.858774 0.130881 12.397638
general_300m pvp 70970695 3328 587.692909 0.143695 13.458589
general_300m pvp 74435116 3328 582.948618 0.142102 12.950940
general_300m pvp 78309306 3328 595.240084 0.144205 13.517261
general_300m pvp 87951085 3328 584.113281 0.137986 12.871263
general_300m pvp 97056894 3328 586.717248 0.142228 12.274507
general_400m pvp 1 3328 641.062800 0.158408 14.048670
general_400m pvp 17590196 3328 631.736779 0.159788 14.052190
general_400m pvp 2 3328 642.314303 0.154533 14.025575
general_400m pvp 48692979 3328 634.712139 0.155592 14.582579
general_400m pvp 59315038 3328 638.046575 0.156290 14.099436
general_400m pvp 70970695 3328 630.991887 0.145236 13.386643
general_400m pvp 74435116 3328 633.663161 0.149206 14.425600
general_400m pvp 78309306 3328 649.272536 0.147923 14.380862
general_400m pvp 87951085 3328 634.288762 0.150661 13.728941
general_400m pvp 97056894 3328 647.101863 0.154756 13.684646
mtonly_100m pvp 1 3328 570.388822 0.149829 13.863780
mtonly_100m pvp 17590196 3328 562.064303 0.149719 13.478329
mtonly_100m pvp 2 3328 569.804688 0.155215 14.439638
mtonly_100m pvp 48692979 3328 574.843450 0.147685 13.431992
mtonly_100m pvp 59315038 3328 560.079928 0.140461 12.893354
mtonly_100m pvp 70970695 3328 559.964844 0.145116 13.588049
mtonly_100m pvp 74435116 3328 566.014123 0.146146 13.687470
mtonly_100m pvp 78309306 3328 569.761418 0.149076 13.819368
mtonly_100m pvp 87951085 3328 565.923077 0.139602 12.447987
mtonly_100m pvp 97056894 3328 565.586238 0.151044 14.106522
mtonly_25m pvp 1 3328 286.557392 0.081511 6.757035
mtonly_25m pvp 17590196 3328 294.131611 0.088139 7.095756
mtonly_25m pvp 2 3328 296.353365 0.092261 8.262692
mtonly_25m pvp 48692979 3328 295.110877 0.089124 7.094151
mtonly_25m pvp 59315038 3328 295.194712 0.084302 7.750465
mtonly_25m pvp 70970695 3328 282.779447 0.082892 7.430993
mtonly_25m pvp 74435116 3328 296.260517 0.088712 7.122447
mtonly_25m pvp 78309306 3328 282.268630 0.080598 7.305146
mtonly_25m pvp 87951085 3328 281.790565 0.084767 7.557862
mtonly_25m pvp 97056894 3328 287.589543 0.086848 7.185161
mtonly_50m pvp 1 3328 444.121394 0.110977 10.322101
mtonly_50m pvp 17590196 3328 460.631911 0.115711 10.504861
mtonly_50m pvp 2 3328 450.344952 0.124788 11.644909
mtonly_50m pvp 48692979 3328 449.207933 0.115114 10.659924
mtonly_50m pvp 59315038 3328 453.880409 0.115535 10.688714
mtonly_50m pvp 70970695 3328 456.128305 0.111703 10.935598
mtonly_50m pvp 74435116 3328 446.447716 0.117590 10.424055
mtonly_50m pvp 78309306 3328 443.463041 0.118123 11.175272
mtonly_50m pvp 87951085 3328 442.167969 0.112586 9.656979
mtonly_50m pvp 97056894 3328 453.640024 0.115860 10.637423
mtonly_75m pvp 1 3328 491.282151 0.130105 11.729226
mtonly_75m pvp 17590196 3328 499.215445 0.131843 11.945576
mtonly_75m pvp 2 3328 491.967548 0.130017 11.482762
mtonly_75m pvp 48692979 3328 499.919471 0.137459 11.812486
mtonly_75m pvp 59315038 3328 497.466046 0.128371 11.230510
mtonly_75m pvp 70970695 3328 496.729267 0.122593 11.319125
mtonly_75m pvp 74435116 3328 494.525240 0.124172 11.820290
mtonly_75m pvp 78309306 3328 493.586839 0.127388 11.466869
mtonly_75m pvp 87951085 3328 496.398438 0.121738 11.003502
mtonly_75m pvp 97056894 3328 498.067909 0.121682 11.521502
Loading

0 comments on commit f046b30

Please sign in to comment.