Commit 8361055

feat: enhance plotting functions with customizable colors and improved layout
WiemKhlifi committed Jan 13, 2025
1 parent 5675ad6 commit 8361055
Showing 3 changed files with 123 additions and 137 deletions.
203 changes: 78 additions & 125 deletions examples/simple_example.py
@@ -27,20 +27,26 @@
     create_matrices_for_rliable,
     data_process_pipeline,
 )

 ##############################
 # Read in and process data
 ##############################
-METRICS_TO_NORMALIZE = ["return"]
-LEGEND_MAP = {
-    "algo_1": "Algorithm 1",
-    "algo_2": "Algorithm 2",
-    "algo_3": "Algorithm 3",
-    "algo_4": "Algorithm 4",
-    "algo_5": "Algorithm 5",
+from marl_eval.json_tools.json_utils import concatenate_json_files, pull_neptune_data
+import matplotlib.pyplot as plt
+import colorcet as cc
+import seaborn as sns
+import numpy as np
+
+METRICS_TO_NORMALIZE = []
+metric_name="episode_return"
+legend_map = {
+    "rec_mappo": "MAPPO",
+    "ff_ippo": "IPPO",
+    "rec_qmix": "QMIX",
+    "rec_iql": "IQL"
 }

-with open("examples/example_results.json") as f:
+algorithms = list(legend_map.values())
+colors = dict(zip(algorithms, sns.color_palette(cc.glasbey_category10)))
+
+env_name = "LevelBasedForaging"
+with open("concatenated_json_files/metrics.json") as f:
     raw_data = json.load(f)

 processed_data = data_process_pipeline(
@@ -49,7 +55,7 @@

 environment_comparison_matrix, sample_effeciency_matrix = create_matrices_for_rliable(
     data_dictionary=processed_data,
-    environment_name="env_1",
+    environment_name=env_name,
     metrics_to_normalize=METRICS_TO_NORMALIZE,
 )

@@ -60,127 +66,74 @@
 ##############################
 # Plot success rate data
 ##############################
-
-# Aggregate data over a single task.
-
-task = "task_1"
-fig = plot_single_task(
-    processed_data=processed_data,
-    environment_name="env_1",
-    task_name=task,
-    metric_name="success_rate",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    legend_map=LEGEND_MAP,
-)
-
-fig.figure.savefig(
-    f"examples/plots/env_1_{task}_agg_success_rate.png", bbox_inches="tight"
-)
-
-# Aggregate data over all environment tasks.
-
-fig = performance_profiles(
-    environment_comparison_matrix,
-    metric_name="success_rate",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig(
-    "examples/plots/success_rate_performance_profile.png", bbox_inches="tight"
-)
-
-fig, _, _ = aggregate_scores( # type: ignore
-    environment_comparison_matrix,
-    metric_name="success_rate",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    save_tabular_as_latex=True,
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig(
-    "examples/plots/success_rate_aggregate_scores.png", bbox_inches="tight"
-)
-
-fig = probability_of_improvement(
-    environment_comparison_matrix,
-    metric_name="success_rate",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    algorithms_to_compare=[
-        ["algo_1", "algo_2"],
-        ["algo_1", "algo_3"],
-        ["algo_2", "algo_4"],
-    ],
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig(
-    "examples/plots/success_rate_prob_of_improvement.png", bbox_inches="tight"
-)
+# # Aggregate data over all environment tasks.
+# fig = performance_profiles(
+#     environment_comparison_matrix,
+#     metric_name=metric_name,
+#     metrics_to_normalize=METRICS_TO_NORMALIZE,
+#     legend_map=legend_map,
+#     colors=colors,
+# )
+# # plt.legend(loc='lower center', bbox_to_anchor=(0.5, 0.9), prop={'size': 17}, ncol=5, bbox_transform=plt.gcf().transFigure, borderaxespad=0.2, frameon=True)
+# fig.figure.savefig("examples/plots/return_performance_profile.pdf", bbox_inches="tight")
+
+# fig, _, _ = aggregate_scores( # type: ignore
+#     environment_comparison_matrix,
+#     metric_name=metric_name,
+#     metrics_to_normalize=METRICS_TO_NORMALIZE,
+#     save_tabular_as_latex=True,
+#     legend_map=legend_map,
+# )
+# fig.figure.savefig( "examples/plots/return_aggregate_scores.pdf", bbox_inches="tight")
+
+# fig = probability_of_improvement(
+#     environment_comparison_matrix,
+#     metric_name=metric_name,
+#     metrics_to_normalize=METRICS_TO_NORMALIZE,
+#     algorithms_to_compare=[
+#         ["ff_mappo", "ff_ippo"],
+#         ["rec_mappo", "rec_ippo"],
+#         ["ff_mappo", "rec_mappo"],
+#         ["ff_ippo", "rec_ippo"],
+#     ],
+#     legend_map=legend_map,
+# )
+# fig.figure.savefig("examples/plots/return_prob_of_improvement.pdf", bbox_inches="tight")

 fig, _, _ = sample_efficiency_curves( # type: ignore
     sample_effeciency_matrix,
-    metric_name="success_rate",
+    metric_name=metric_name,
     metrics_to_normalize=METRICS_TO_NORMALIZE,
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig(
-    "examples/plots/success_rate_sample_effeciency_curve.png", bbox_inches="tight"
+    legend_map=legend_map,
+    colors=colors,
 )
+# legend = plt.legend(loc='lower center', bbox_to_anchor=(0.5, 0.9), prop={'size': 15}, ncol=7, bbox_transform=plt.gcf().transFigure, borderaxespad=0.2, frameon=True)
+# fig = legend.figure
+# fig.canvas.draw()
+# bbox = legend.get_window_extent()
+# bbox = bbox.from_extents(*(bbox.extents + np.array([-4,-4,4,4])))
+# bbox = bbox.transformed(fig.dpi_scale_trans.inverted())
+# fig.savefig('legend.png', dpi=1200, bbox_inches=bbox)
+# plt.hlines(y=2, xmin=0, xmax=4e7, colors='gray', linestyles='--', label='Threshold')
+# plt.legend()
+fig.figure.savefig(f"examples/plots/{env_name}_sample_effeciency_curve.pdf", bbox_inches="tight")


 ##############################
 # Plot episode return data
 ##############################
-
-# Aggregate data over a single task
-
-task = "task_1"
-fig = plot_single_task(
-    processed_data=processed_data,
-    environment_name="env_1",
-    task_name=task,
-    metric_name="return",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    legend_map=LEGEND_MAP,
-)
-
-fig.figure.savefig(f"examples/plots/env_1_{task}_agg_return.png", bbox_inches="tight")
-
-# Aggregate data over all environment tasks.
-
-fig = performance_profiles(
-    environment_comparison_matrix,
-    metric_name="return",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig("examples/plots/return_performance_profile.png", bbox_inches="tight")
-
-fig, _, _ = aggregate_scores( # type: ignore
-    environment_comparison_matrix,
-    metric_name="return",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    save_tabular_as_latex=True,
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig("examples/plots/return_aggregate_scores.png", bbox_inches="tight")
-
-fig = probability_of_improvement(
-    environment_comparison_matrix,
-    metric_name="return",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    algorithms_to_compare=[
-        ["algo_1", "algo_2"],
-        ["algo_1", "algo_3"],
-        ["algo_2", "algo_4"],
-    ],
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig("examples/plots/return_prob_of_improvement.png", bbox_inches="tight")
-
-fig, _, _ = sample_efficiency_curves( # type: ignore
-    sample_effeciency_matrix,
-    metric_name="return",
-    metrics_to_normalize=METRICS_TO_NORMALIZE,
-    legend_map=LEGEND_MAP,
-)
-fig.figure.savefig(
-    "examples/plots/return_sample_effeciency_curve.png", bbox_inches="tight"
-)
+for task in processed_data[env_name.lower()].keys():
+    fig = plot_single_task(
+        processed_data=processed_data,
+        environment_name=env_name,
+        task_name=task,
+        metric_name=metric_name,
+        metrics_to_normalize=METRICS_TO_NORMALIZE,
+        legend_map=legend_map,
+        colors=colors,
+    )
+
+    fig.figure.savefig(f"examples/plots/{env_name}_{task}_agg_return.pdf", bbox_inches="tight")
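A note on the new colour plumbing: `colors` is built from the values of `legend_map` (the display names), and the `plot_utils.py` change below looks colours up via `colors[algorithm_name]` rather than `colors[algorithm]`, so the two must stay keyed consistently. A minimal sketch of that contract, assuming seaborn and colorcet are installed (the two-entry legend map here is illustrative only, not the commit's):

    import colorcet as cc
    import seaborn as sns

    # Hypothetical two-algorithm legend map (the commit uses four).
    legend_map = {"rec_mappo": "MAPPO", "ff_ippo": "IPPO"}

    # Colours are keyed by display names, i.e. legend_map's *values*.
    algorithms = list(legend_map.values())
    colors = dict(zip(algorithms, sns.color_palette(cc.glasbey_category10)))

    # The per-curve lookup plot_utils.py now performs:
    assert all(name in colors for name in legend_map.values())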
12 changes: 8 additions & 4 deletions marl_eval/plotting_tools/plot_utils.py
@@ -82,7 +82,9 @@ def plot_single_task_curve(
         x_axis_len = len(aggregated_data[algorithm]["mean"])

         # Set x-axis values to match evaluation interval steps.
-        x_axis_values = np.arange(x_axis_len) * extra_info["evaluation_interval"]
+        # x_axis_values = np.arange(x_axis_len) * extra_info["evaluation_interval"]
+        # Note: This is hardcoded for now as well.
+        x_axis_values = np.linspace(0,20, x_axis_len)

         if run_times is not None:
             x_axis_values = np.linspace(0, run_times[algorithm] / 60, x_axis_len)
@@ -102,15 +104,17 @@
         ax.plot(
             x_axis_values,
             metric_values,
-            color=colors[algorithm],
+            color=colors[algorithm_name],
             marker=marker,
             linewidth=linewidth,
             label=algorithm_name,
         )
         ax.fill_between(
-            x_axis_values, y1=lower, y2=upper, color=colors[algorithm], alpha=0.2
+            x_axis_values, y1=lower, y2=upper, color=colors[algorithm_name], alpha=0.2
         )

+    # plt.hlines(y=0.15, xmin=0, xmax=2e7, colors=colors["MAPPO"], linestyles='--', label='JaxMARL MAPPO')
+    # plt.hlines(y=0.1, xmin=0, xmax=2e7, colors=colors["IPPO"], linestyles='--', label='JaxMARL IPPO')
+    # plt.legend()
     return _annotate_and_decorate_axis(
         ax,
         xlabel=xlabel,
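The first hunk above swaps the evaluation-interval scaling for a hardcoded `np.linspace(0, 20, x_axis_len)` x-axis, flagged as such in the committed comment. One possible way to keep both behaviours, sketched here as a hypothetical helper rather than anything in this commit:

    import numpy as np

    def x_axis_values_for(x_axis_len, extra_info=None):
        # Hypothetical fallback: use the logged evaluation interval when it is
        # available, otherwise the hardcoded 0-20 range this commit introduces.
        interval = (extra_info or {}).get("evaluation_interval")
        if interval is not None:
            return np.arange(x_axis_len) * interval
        return np.linspace(0, 20, x_axis_len)

    print(x_axis_values_for(5, {"evaluation_interval": 100}))  # [  0 100 200 300 400]
    print(x_axis_values_for(5))                                # [ 0.  5. 10. 15. 20.]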
