-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
164 changed files
with
6,005 additions
and
14 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# %% | ||
import sys | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
sys.path.append("../") | ||
import proc_elo | ||
import proc_task_eval | ||
|
||
|
||
# Figure style-related | ||
TITLE_FONT = 24 | ||
LABEL_FONT = 18 | ||
TICK_FONT = 11 | ||
ELO_SAMPLE_TICKS = list(range(0, 150, 25)) | ||
ELO_SAMPLE_LABELS = ["0", "25M", "50M", "75M", "100M", "125M"] | ||
ELO_YTIKCS = list(range(850, 1101, 50)) | ||
|
||
MT_SAMPLE_TICKS = list(range(0, 200, 50)) | ||
MT_SAMPLE_LABELS = ["0", "50M", "100M", "150M"] | ||
MT_YTICKS = [y / 100 for y in range(0, 21, 4)] | ||
|
||
# blue-ish: #4c72b0, orange: #dd8452 | ||
SPECIALIST_MARKER = { | ||
"marker": "o", | ||
"s": 100, | ||
"c": "white", | ||
"edgecolors": "#4c72b0", | ||
"label": "Specialist", | ||
"linewidths": 3, | ||
} | ||
GENERALIST_MARKER = {"marker": "o", "s": 120, "c": "#dd8452", "label": "Generalist"} | ||
|
||
# Manually extract sampling ratio for each minigame, after running average smoothing with 100 | ||
# https://wandb.ai/kywch/meta-mmo/runs/lf95vvxr, see stats/Sampling/{minigame}_agent_steps | ||
generalist_sample_ratio = { | ||
"svonly": [0.288, 0.290, 0.293, 0.291], # Survival | ||
"tbonly": [0.305, 0.304, 0.299, 0.296], # Team Battle | ||
"mtonly": [0.407, 0.406, 0.408, 0.412], # Multi-task Traiing | ||
} | ||
|
||
|
||
def process_eval_data( | ||
data, specialist_prefix, sample_ratio, value_key="elo", generalist_prefix="general" | ||
): | ||
results = {"generalist": {"steps": [], "values": []}, "specialist": {"steps": [], "values": []}} | ||
for pol_data in data: | ||
pol_info = pol_data["policy"].split("_") | ||
assert pol_info[0] in [specialist_prefix, generalist_prefix], "" | ||
pol_type = "generalist" if pol_info[0] == generalist_prefix else "specialist" | ||
results[pol_type]["values"].append(pol_data[value_key]) | ||
results[pol_type]["steps"].append(int(pol_info[1][:-1])) | ||
|
||
# Apply sampling ratio to correct the steps | ||
gen_idx = np.argsort(results["generalist"]["steps"]) | ||
for ii, idx in enumerate(gen_idx): | ||
results["generalist"]["steps"][ii] *= sample_ratio[idx] | ||
|
||
return results | ||
|
||
|
||
if __name__ == "__main__": | ||
# Data for each minigame | ||
survive_elo = process_eval_data( | ||
proc_elo.process_eval_files("full_sv", "survive"), | ||
"svonly", | ||
generalist_sample_ratio["svonly"], | ||
) | ||
|
||
battle_elo = process_eval_data( | ||
proc_elo.process_eval_files("full_tb", "battle"), | ||
"tbonly", | ||
generalist_sample_ratio["tbonly"], | ||
) | ||
|
||
task_data = proc_task_eval.process_eval_files("full_mt", "curriculum").to_dicts() | ||
task_progress = process_eval_data( | ||
task_data, "mtonly", generalist_sample_ratio["mtonly"], value_key="task_progress" | ||
) | ||
|
||
# Create subplots with specified figure size | ||
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4)) | ||
|
||
# Remove top and right edges for each subplot | ||
for ax in [ax1, ax2, ax3]: | ||
ax.spines["top"].set_visible(False) | ||
ax.spines["right"].set_visible(False) | ||
ax.tick_params(axis="both", which="major", length=6, labelsize=TICK_FONT) | ||
|
||
# Marker styles for generalist and specialist | ||
|
||
# Survival subplot | ||
ax1.set_title("Survival", fontsize=TITLE_FONT) | ||
ax1.set_ylabel("Elo", fontsize=TITLE_FONT) | ||
ax1.set_xticks(ELO_SAMPLE_TICKS) | ||
ax1.set_xticklabels(ELO_SAMPLE_LABELS, fontsize=TICK_FONT) | ||
ax1.set_xlim(0, 125) | ||
ax1.set_yticks(ELO_YTIKCS) | ||
ax1.set_ylim(850, 1100) | ||
ax1.plot([0, 150], [1000, 1000], "k:") # Anchor line for ELO | ||
ax1.scatter( | ||
survive_elo["generalist"]["steps"], survive_elo["generalist"]["values"], **GENERALIST_MARKER | ||
) | ||
ax1.scatter( | ||
survive_elo["specialist"]["steps"], survive_elo["specialist"]["values"], **SPECIALIST_MARKER | ||
) | ||
ax1.legend(loc="lower right", fontsize=13) | ||
|
||
# Team Battle subplot | ||
ax2.set_title("Team Battle", fontsize=TITLE_FONT) | ||
ax2.set_xlabel("Training samples", fontsize=TITLE_FONT, labelpad=13) | ||
ax2.set_xticks(ELO_SAMPLE_TICKS) | ||
ax2.set_xticklabels(ELO_SAMPLE_LABELS, fontsize=TICK_FONT) | ||
ax2.set_xlim(0, 125) | ||
ax2.set_yticks(ELO_YTIKCS) | ||
ax2.set_ylim(850, 1100) | ||
ax2.plot([0, 150], [1000, 1000], "k:") # Anchor line for ELO | ||
ax2.scatter( | ||
battle_elo["specialist"]["steps"], battle_elo["specialist"]["values"], **SPECIALIST_MARKER | ||
) | ||
ax2.scatter( | ||
battle_elo["generalist"]["steps"], battle_elo["generalist"]["values"], **GENERALIST_MARKER | ||
) | ||
|
||
# Multi-task Eval subplot | ||
ax3.set_title("Multi-task Eval", fontsize=TITLE_FONT) | ||
ax3.set_ylabel("Task progress", fontsize=TITLE_FONT) | ||
ax3.set_xticks(MT_SAMPLE_TICKS) | ||
ax3.set_xticklabels(MT_SAMPLE_LABELS, fontsize=TICK_FONT) | ||
ax3.set_xlim(0, 175) | ||
ax3.set_yticks(MT_YTICKS) | ||
ax3.set_ylim(0, 0.2) | ||
ax3.scatter( | ||
task_progress["specialist"]["steps"], | ||
task_progress["specialist"]["values"], | ||
**SPECIALIST_MARKER, | ||
) | ||
ax3.scatter( | ||
task_progress["generalist"]["steps"], | ||
task_progress["generalist"]["values"], | ||
**GENERALIST_MARKER, | ||
) | ||
|
||
# Adjust spacing between subplots | ||
plt.tight_layout() | ||
plt.subplots_adjust(wspace=0.5) | ||
|
||
# Save the figure as a PNG file with specified size | ||
plt.savefig("fig_4.png", dpi=300, bbox_inches="tight") | ||
|
||
# Display the figure | ||
plt.show() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# %% | ||
import sys | ||
import matplotlib.pyplot as plt | ||
|
||
sys.path.append("../") | ||
import proc_elo | ||
from fig_4 import process_eval_data | ||
|
||
|
||
# Figure style-related | ||
TITLE_FONT = 26 | ||
TICK_FONT = 12 | ||
ELO_SAMPLE_TICKS = list(range(0, 150, 25)) | ||
ELO_SAMPLE_LABEL = ["0", "25M", "50M", "75M", "100M", "125M"] | ||
ELO_YTIKCS = list(range(850, 1101, 50)) | ||
|
||
# blue-ish: #4c72b0, orange: #dd8452 | ||
SPECIALIST_MARKER = { | ||
"marker": "o", | ||
"s": 100, | ||
"c": "white", | ||
"edgecolors": "#4c72b0", | ||
"label": "Specialist", | ||
"linewidths": 3, | ||
} | ||
GENERALIST_MARKER = {"marker": "o", "s": 120, "c": "#dd8452", "label": "Generalist"} | ||
|
||
# Manually extract sampling ratio for each minigame, after running average smoothing with 100 | ||
# https://wandb.ai/kywch/meta-mmo/runs/53n3yvnj, see stats/Sampling/{minigame}_agent_steps | ||
generalist_sample_ratio = { | ||
"battle": [0.218, 0.261, 0.274, 0.280], # Team Battle | ||
"ptk": [0.153, 0.188, 0.201, 0.217], # Protect the King | ||
"race": [0.164, 0.120, 0.100, 0.097], # Race to the Center | ||
"koh": [0.228, 0.231, 0.223, 0.210], # King of the Hill | ||
"sandwich": [0.237, 0.200, 0.202, 0.196], # Sandwich | ||
} | ||
|
||
minigame_info = { | ||
"battle": ("Team Battle", "mini_tb", "tbonly"), # name, directory, prefix | ||
"ptk": ("Protect the King", "mini_pk", "pkonly"), | ||
"race": ("Race to the Center", "mini_rc", "rconly"), | ||
"koh": ("King of the Hill", "mini_kh", "khonly"), | ||
"sandwich": ("Sandwich", "mini_sw", "swonly"), | ||
} | ||
|
||
if __name__ == "__main__": | ||
# Data for each minigame | ||
elo_data = { | ||
game: process_eval_data( | ||
proc_elo.process_eval_files(dir, game), prefix, generalist_sample_ratio[game] | ||
) | ||
for game, (_, dir, prefix) in minigame_info.items() | ||
} | ||
|
||
# Create subplots with specified figure size | ||
fig, axes = plt.subplots(1, 5, figsize=(20, 4)) | ||
|
||
# Remove top and right edges for each subplot | ||
for ax, game in zip(axes, minigame_info.keys()): | ||
ax.spines["top"].set_visible(False) | ||
ax.spines["right"].set_visible(False) | ||
ax.tick_params(axis="both", which="major", length=6, labelsize=TICK_FONT) | ||
|
||
ax.set_title(minigame_info[game][0], fontsize=TITLE_FONT) | ||
ax.set_xticks(ELO_SAMPLE_TICKS) | ||
ax.set_xticklabels(ELO_SAMPLE_LABEL, fontsize=TICK_FONT) | ||
ax.set_xlim(0, 105) | ||
if max(elo_data[game]["generalist"]["steps"]) > 100: | ||
ax.set_xlim(0, 125) | ||
ax.set_yticks(ELO_YTIKCS) | ||
ax.set_ylim(880, 1100) | ||
ax.plot([0, 150], [1000, 1000], "k:") # Anchor line for ELO | ||
ax.scatter( | ||
elo_data[game]["generalist"]["steps"], | ||
elo_data[game]["generalist"]["values"], | ||
**GENERALIST_MARKER, | ||
) | ||
ax.scatter( | ||
elo_data[game]["specialist"]["steps"], | ||
elo_data[game]["specialist"]["values"], | ||
**SPECIALIST_MARKER, | ||
) | ||
|
||
axes[0].set_ylabel("Elo", fontsize=TITLE_FONT + 2) | ||
axes[0].legend(loc="lower right", fontsize=15) | ||
axes[2].set_xlabel("Training samples", fontsize=TITLE_FONT + 6, labelpad=15) | ||
|
||
# Adjust spacing between subplots | ||
plt.tight_layout() | ||
plt.subplots_adjust(wspace=0.3) | ||
|
||
# Save the figure as a PNG file with specified size | ||
plt.savefig("fig_6.png", dpi=300, bbox_inches="tight") | ||
|
||
# Display the figure | ||
plt.show() |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_gen/model_003052.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_gen/model_006104.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_gen/model_009156.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_gen/model_012207.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_mt/model_003051.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_mt/model_000763.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_mt/model_001526.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../policies/full_mt/model_002289.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
policy mode seed count length task_progress weighted_score | ||
general_100m pvp 1 3328 485.417969 0.115763 10.896710 | ||
general_100m pvp 17590196 3328 458.769531 0.121940 11.712889 | ||
general_100m pvp 2 3328 472.893630 0.121988 10.945641 | ||
general_100m pvp 48692979 3328 492.171575 0.115905 11.119063 | ||
general_100m pvp 59315038 3328 482.833834 0.116905 10.826042 | ||
general_100m pvp 70970695 3328 473.437200 0.107732 11.079143 | ||
general_100m pvp 74435116 3328 477.912260 0.119933 10.914331 | ||
general_100m pvp 78309306 3328 477.221755 0.114937 10.389002 | ||
general_100m pvp 87951085 3328 476.138522 0.110778 10.684735 | ||
general_100m pvp 97056894 3328 465.522837 0.110429 9.705141 | ||
general_200m pvp 1 3328 539.848257 0.139947 12.805038 | ||
general_200m pvp 17590196 3328 545.559796 0.143204 13.091318 | ||
general_200m pvp 2 3328 536.153546 0.140734 13.049674 | ||
general_200m pvp 48692979 3328 526.338942 0.130081 11.990479 | ||
general_200m pvp 59315038 3328 533.338341 0.131956 12.664571 | ||
general_200m pvp 70970695 3328 537.543870 0.135842 12.342994 | ||
general_200m pvp 74435116 3328 546.435397 0.141714 12.721710 | ||
general_200m pvp 78309306 3328 540.789363 0.147472 13.563530 | ||
general_200m pvp 87951085 3328 542.996995 0.134235 12.576228 | ||
general_200m pvp 97056894 3328 538.321514 0.140068 12.713509 | ||
general_300m pvp 1 3328 586.922476 0.138772 13.681132 | ||
general_300m pvp 17590196 3328 591.190805 0.149476 13.643711 | ||
general_300m pvp 2 3328 588.495793 0.140983 13.714509 | ||
general_300m pvp 48692979 3328 586.776743 0.141291 13.870448 | ||
general_300m pvp 59315038 3328 583.858774 0.130881 12.397638 | ||
general_300m pvp 70970695 3328 587.692909 0.143695 13.458589 | ||
general_300m pvp 74435116 3328 582.948618 0.142102 12.950940 | ||
general_300m pvp 78309306 3328 595.240084 0.144205 13.517261 | ||
general_300m pvp 87951085 3328 584.113281 0.137986 12.871263 | ||
general_300m pvp 97056894 3328 586.717248 0.142228 12.274507 | ||
general_400m pvp 1 3328 641.062800 0.158408 14.048670 | ||
general_400m pvp 17590196 3328 631.736779 0.159788 14.052190 | ||
general_400m pvp 2 3328 642.314303 0.154533 14.025575 | ||
general_400m pvp 48692979 3328 634.712139 0.155592 14.582579 | ||
general_400m pvp 59315038 3328 638.046575 0.156290 14.099436 | ||
general_400m pvp 70970695 3328 630.991887 0.145236 13.386643 | ||
general_400m pvp 74435116 3328 633.663161 0.149206 14.425600 | ||
general_400m pvp 78309306 3328 649.272536 0.147923 14.380862 | ||
general_400m pvp 87951085 3328 634.288762 0.150661 13.728941 | ||
general_400m pvp 97056894 3328 647.101863 0.154756 13.684646 | ||
mtonly_100m pvp 1 3328 570.388822 0.149829 13.863780 | ||
mtonly_100m pvp 17590196 3328 562.064303 0.149719 13.478329 | ||
mtonly_100m pvp 2 3328 569.804688 0.155215 14.439638 | ||
mtonly_100m pvp 48692979 3328 574.843450 0.147685 13.431992 | ||
mtonly_100m pvp 59315038 3328 560.079928 0.140461 12.893354 | ||
mtonly_100m pvp 70970695 3328 559.964844 0.145116 13.588049 | ||
mtonly_100m pvp 74435116 3328 566.014123 0.146146 13.687470 | ||
mtonly_100m pvp 78309306 3328 569.761418 0.149076 13.819368 | ||
mtonly_100m pvp 87951085 3328 565.923077 0.139602 12.447987 | ||
mtonly_100m pvp 97056894 3328 565.586238 0.151044 14.106522 | ||
mtonly_25m pvp 1 3328 286.557392 0.081511 6.757035 | ||
mtonly_25m pvp 17590196 3328 294.131611 0.088139 7.095756 | ||
mtonly_25m pvp 2 3328 296.353365 0.092261 8.262692 | ||
mtonly_25m pvp 48692979 3328 295.110877 0.089124 7.094151 | ||
mtonly_25m pvp 59315038 3328 295.194712 0.084302 7.750465 | ||
mtonly_25m pvp 70970695 3328 282.779447 0.082892 7.430993 | ||
mtonly_25m pvp 74435116 3328 296.260517 0.088712 7.122447 | ||
mtonly_25m pvp 78309306 3328 282.268630 0.080598 7.305146 | ||
mtonly_25m pvp 87951085 3328 281.790565 0.084767 7.557862 | ||
mtonly_25m pvp 97056894 3328 287.589543 0.086848 7.185161 | ||
mtonly_50m pvp 1 3328 444.121394 0.110977 10.322101 | ||
mtonly_50m pvp 17590196 3328 460.631911 0.115711 10.504861 | ||
mtonly_50m pvp 2 3328 450.344952 0.124788 11.644909 | ||
mtonly_50m pvp 48692979 3328 449.207933 0.115114 10.659924 | ||
mtonly_50m pvp 59315038 3328 453.880409 0.115535 10.688714 | ||
mtonly_50m pvp 70970695 3328 456.128305 0.111703 10.935598 | ||
mtonly_50m pvp 74435116 3328 446.447716 0.117590 10.424055 | ||
mtonly_50m pvp 78309306 3328 443.463041 0.118123 11.175272 | ||
mtonly_50m pvp 87951085 3328 442.167969 0.112586 9.656979 | ||
mtonly_50m pvp 97056894 3328 453.640024 0.115860 10.637423 | ||
mtonly_75m pvp 1 3328 491.282151 0.130105 11.729226 | ||
mtonly_75m pvp 17590196 3328 499.215445 0.131843 11.945576 | ||
mtonly_75m pvp 2 3328 491.967548 0.130017 11.482762 | ||
mtonly_75m pvp 48692979 3328 499.919471 0.137459 11.812486 | ||
mtonly_75m pvp 59315038 3328 497.466046 0.128371 11.230510 | ||
mtonly_75m pvp 70970695 3328 496.729267 0.122593 11.319125 | ||
mtonly_75m pvp 74435116 3328 494.525240 0.124172 11.820290 | ||
mtonly_75m pvp 78309306 3328 493.586839 0.127388 11.466869 | ||
mtonly_75m pvp 87951085 3328 496.398438 0.121738 11.003502 | ||
mtonly_75m pvp 97056894 3328 498.067909 0.121682 11.521502 |
Oops, something went wrong.