Skip to content

Commit

Permalink
feat: restructuring folder
Browse files Browse the repository at this point in the history
  • Loading branch information
RuanJohn committed Aug 28, 2024
1 parent b85d822 commit 2a857cc
Show file tree
Hide file tree
Showing 22 changed files with 607 additions and 276 deletions.
166 changes: 166 additions & 0 deletions data_plotting_scripts/ablation_plotting/timestep_ablation_rware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# python3
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import matplotlib.pyplot as plt

from marl_eval.plotting_tools.plotting import (
aggregate_scores,
performance_profiles,
plot_single_task,
sample_efficiency_curves,
)
from marl_eval.utils.data_processing_utils import (
create_matrices_for_rliable,
data_process_pipeline,
)

# Plot RetMAT memory-chunk-size (timestep) ablation results on RobotWarehouse:
# aggregate scores, performance profiles, per-task returns and sample-efficiency
# curves, saved as PNG (and optionally PDF) under the plots/ directory.
base_folder_name = "retmat-timestep-chunkwise-ablation"

ENV_NAME = "RobotWarehouse"
SAVE_PDF = False  # Set to True to also write a PDF copy of every plot.

data_dir = f"data/{base_folder_name}/merged_data/metrics.json"
png_plot_dir = f"plots/{base_folder_name}/rware_timestep_ablation/png/"
pdf_plot_dir = f"plots/{base_folder_name}/rware_timestep_ablation/pdf/"

# Map raw experiment identifiers to human-readable legend labels
# (RetMAT-<k> = RetMAT with a memory chunk size of k timesteps).
legend_map = {
    "retmat_memory-4": "RetMAT-4",
    "retmat_memory-8": "RetMAT-8",
    "retmat_memory-16": "RetMAT-16",
    "retmat_memory-32": "RetMAT-32",
    "retmat_memory-64": "RetMAT-64",
    "retmat_memory-128": "RetMAT-128",
}

##############################
# Read in and process data
##############################
METRICS_TO_NORMALIZE = ["mean_episode_return"]

with open(data_dir) as f:
    raw_data = json.load(f)

processed_data = data_process_pipeline(
    raw_data=raw_data, metrics_to_normalize=METRICS_TO_NORMALIZE
)

# NOTE: output file names below keep the historical "effeciency" misspelling so
# existing downstream references to the saved plots remain valid.
environment_comparison_matrix, sample_efficiency_matrix = create_matrices_for_rliable(
    data_dictionary=processed_data,
    environment_name=ENV_NAME,
    metrics_to_normalize=METRICS_TO_NORMALIZE,
)

# Create folders for storing plots (no-op if they already exist; avoids the
# check-then-create race of os.path.exists + os.makedirs).
os.makedirs(png_plot_dir, exist_ok=True)
os.makedirs(pdf_plot_dir, exist_ok=True)


##############################
# Probability of improvement
# Aggregate scores
# Performance profiles
##############################

# Aggregate scores (also saved as a LaTeX table next to the png/ directory;
# png_plot_dir[:-4] strips the trailing "png/" path component).
fig, _, _ = aggregate_scores(  # type: ignore
    environment_comparison_matrix,
    metric_name="mean_episode_return",
    metrics_to_normalize=METRICS_TO_NORMALIZE,
    save_tabular_as_latex=True,
    legend_map=legend_map,
    tabular_results_file_path=f"{png_plot_dir[:-4]}aggregated_score",
)
fig.figure.savefig(f"{png_plot_dir}aggregate_scores.png", bbox_inches="tight")
if SAVE_PDF:
    fig.figure.savefig(f"{pdf_plot_dir}aggregate_scores.pdf", bbox_inches="tight")

# Performance profiles.
fig = performance_profiles(
    environment_comparison_matrix,
    metric_name="mean_episode_return",
    metrics_to_normalize=METRICS_TO_NORMALIZE,
    legend_map=legend_map,
)
fig.figure.savefig(f"{png_plot_dir}performance_profile.png", bbox_inches="tight")
if SAVE_PDF:
    fig.figure.savefig(f"{pdf_plot_dir}performance_profile.pdf", bbox_inches="tight")


##############################
# Plot episode return data
##############################

# Raw episode returns should be plotted unnormalized, so re-run the pipeline
# with an empty normalization list.

METRICS_TO_NORMALIZE = []

with open(data_dir) as f:
    raw_data = json.load(f)

processed_data = data_process_pipeline(
    raw_data=raw_data, metrics_to_normalize=METRICS_TO_NORMALIZE
)

environment_comparison_matrix, sample_efficiency_matrix = create_matrices_for_rliable(
    data_dictionary=processed_data,
    environment_name=ENV_NAME,
    metrics_to_normalize=METRICS_TO_NORMALIZE,
)

# Get all tasks.
tasks = list(processed_data[ENV_NAME.lower()].keys())

# Aggregate data over each task individually.
for task in tasks:
    fig = plot_single_task(
        processed_data=processed_data,
        environment_name=ENV_NAME,
        task_name=task,
        metric_name="mean_episode_return",
        metrics_to_normalize=METRICS_TO_NORMALIZE,
        legend_map=legend_map,
    )

    fig.figure.savefig(
        f"{png_plot_dir}rware_{task}_agg_return.png", bbox_inches="tight"
    )
    if SAVE_PDF:
        fig.figure.savefig(
            f"{pdf_plot_dir}rware_{task}_agg_return.pdf", bbox_inches="tight"
        )

    # Close the figure object so memory does not grow across the task loop.
    plt.close(fig.figure)

# Aggregate data over all environment tasks.

fig, _, _ = sample_efficiency_curves(  # type: ignore
    sample_efficiency_matrix,
    metric_name="mean_episode_return",
    metrics_to_normalize=METRICS_TO_NORMALIZE,
    legend_map=legend_map,
)
fig.figure.savefig(
    f"{png_plot_dir}return_sample_effeciency_curve.png", bbox_inches="tight"
)
if SAVE_PDF:
    fig.figure.savefig(
        f"{pdf_plot_dir}return_sample_effeciency_curve.pdf", bbox_inches="tight"
    )
164 changes: 164 additions & 0 deletions data_plotting_scripts/ablation_plotting/timestep_ablation_smax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# python3
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import matplotlib.pyplot as plt

from marl_eval.plotting_tools.plotting import (
aggregate_scores,
performance_profiles,
plot_single_task,
sample_efficiency_curves,
)
from marl_eval.utils.data_processing_utils import (
create_matrices_for_rliable,
data_process_pipeline,
)

# Plot RetMAT memory-chunk-size (timestep) ablation results on SMAX:
# aggregate scores, performance profiles, per-task win rates and
# sample-efficiency curves, saved as PNG (and optionally PDF).
base_folder_name = "retmat-timestep-chunkwise-ablation"

ENV_NAME = "Smax"
SAVE_PDF = False  # Set to True to also write a PDF copy of every plot.

data_dir = f"data/{base_folder_name}/merged_data/metrics.json"
png_plot_dir = f"plots/{base_folder_name}/smax_timestep_ablation/png/"
pdf_plot_dir = f"plots/{base_folder_name}/smax_timestep_ablation/pdf/"

# Map raw experiment identifiers to human-readable legend labels
# (RetMAT-<k> = RetMAT with a memory chunk size of k timesteps).
legend_map = {
    "retmat_memory-4": "RetMAT-4",
    "retmat_memory-8": "RetMAT-8",
    "retmat_memory-16": "RetMAT-16",
    "retmat_memory-32": "RetMAT-32",
    "retmat_memory-64": "RetMAT-64",
    "retmat_memory-128": "RetMAT-128",
}

##############################
# Read in and process data
##############################
METRICS_TO_NORMALIZE = ["mean_episode_return"]

with open(data_dir) as f:
    raw_data = json.load(f)

processed_data = data_process_pipeline(
    raw_data=raw_data, metrics_to_normalize=METRICS_TO_NORMALIZE
)

# NOTE: output file names below keep the historical "effeciency" misspelling so
# existing downstream references to the saved plots remain valid.
environment_comparison_matrix, sample_efficiency_matrix = create_matrices_for_rliable(
    data_dictionary=processed_data,
    environment_name=ENV_NAME,
    metrics_to_normalize=METRICS_TO_NORMALIZE,
)

# Create folders for storing plots (no-op if they already exist; avoids the
# check-then-create race of os.path.exists + os.makedirs).
os.makedirs(png_plot_dir, exist_ok=True)
os.makedirs(pdf_plot_dir, exist_ok=True)


##############################
# Probability of improvement
# Aggregate scores
# Performance profiles
##############################

# Aggregate scores (also saved as a LaTeX table next to the png/ directory;
# png_plot_dir[:-4] strips the trailing "png/" path component).
fig, _, _ = aggregate_scores(  # type: ignore
    environment_comparison_matrix,
    metric_name="win_rate",
    metrics_to_normalize=METRICS_TO_NORMALIZE,
    save_tabular_as_latex=True,
    legend_map=legend_map,
    tabular_results_file_path=f"{png_plot_dir[:-4]}aggregated_score",
)
fig.figure.savefig(f"{png_plot_dir}aggregate_scores.png", bbox_inches="tight")
if SAVE_PDF:
    fig.figure.savefig(f"{pdf_plot_dir}aggregate_scores.pdf", bbox_inches="tight")

# Performance profiles.
fig = performance_profiles(
    environment_comparison_matrix,
    metric_name="win_rate",
    metrics_to_normalize=METRICS_TO_NORMALIZE,
    legend_map=legend_map,
)
fig.figure.savefig(f"{png_plot_dir}performance_profile.png", bbox_inches="tight")
if SAVE_PDF:
    fig.figure.savefig(f"{pdf_plot_dir}performance_profile.pdf", bbox_inches="tight")


##############################
# Plot win rate data
##############################

# Raw win rates should be plotted unnormalized, so re-run the pipeline with an
# empty normalization list.

METRICS_TO_NORMALIZE = []

with open(data_dir) as f:
    raw_data = json.load(f)

processed_data = data_process_pipeline(
    raw_data=raw_data, metrics_to_normalize=METRICS_TO_NORMALIZE
)

environment_comparison_matrix, sample_efficiency_matrix = create_matrices_for_rliable(
    data_dictionary=processed_data,
    environment_name=ENV_NAME,
    metrics_to_normalize=METRICS_TO_NORMALIZE,
)

# Get all tasks.
tasks = list(processed_data[ENV_NAME.lower()].keys())

# Aggregate data over each task individually.
for task in tasks:
    fig = plot_single_task(
        processed_data=processed_data,
        environment_name=ENV_NAME,
        task_name=task,
        metric_name="win_rate",
        metrics_to_normalize=METRICS_TO_NORMALIZE,
        legend_map=legend_map,
    )

    # Fix: use an explicit "smax_" prefix instead of a stray leading
    # underscore, matching the "rware_" prefix used by the sibling
    # RobotWarehouse ablation script.
    fig.figure.savefig(f"{png_plot_dir}smax_{task}_agg_win_rate.png", bbox_inches="tight")
    if SAVE_PDF:
        fig.figure.savefig(
            f"{pdf_plot_dir}smax_{task}_agg_win_rate.pdf", bbox_inches="tight"
        )

    # Close the figure object so memory does not grow across the task loop.
    plt.close(fig.figure)

# Aggregate data over all environment tasks.

fig, _, _ = sample_efficiency_curves(  # type: ignore
    sample_efficiency_matrix,
    metric_name="win_rate",
    metrics_to_normalize=METRICS_TO_NORMALIZE,
    legend_map=legend_map,
)
fig.figure.savefig(
    f"{png_plot_dir}win_rate_sample_effeciency_curve.png", bbox_inches="tight"
)
if SAVE_PDF:
    fig.figure.savefig(
        f"{pdf_plot_dir}win_rate_sample_effeciency_curve.pdf", bbox_inches="tight"
    )
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
ENV_NAME = "RobotWarehouse"
SAVE_PDF = False

data_dir = f"data/{base_folder_name}/merged_data/metrics_seed_processed.json"
png_plot_dir = f"plots/{base_folder_name}/rware_no_retmat/png/"
pdf_plot_dir = f"plots/{base_folder_name}/rware_no_retmat/pdf/"
data_dir = f"data/{base_folder_name}/merged_data/metrics.json"
png_plot_dir = f"plots/{base_folder_name}/rware_timestep_ablation/png/"
pdf_plot_dir = f"plots/{base_folder_name}/rware_timestep_ablation/pdf/"

legend_map = {
"rec_mappo": "Rec MAPPO",
Expand All @@ -49,8 +49,6 @@
"retmat_memory": "RetMAT Memory",
"ff_happo": "FF HAPPO",
"rec_happo": "Rec HAPPO",
# "retmat_main_memory": "RetMAT Main Memory",
# "retmat_yarn_memory": "RetMAT Yarn Memory",
}

##############################
Expand Down
File renamed without changes.
14 changes: 4 additions & 10 deletions data_processing_scripts/keep_certain_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,12 @@ def filter_json(data, tasks_to_keep):
return filtered_data


base_folder_name = "biggest-benchmark-sac"
base_folder_name = "retmat-timestep-chunkwise-ablation"

# Example usage:
input_file = f"./data/{base_folder_name}/merged_data/metrics_winrate_processed.json"
output_file = f"./data/{base_folder_name}/merged_data/metrics_seed_processed.json"
tasks_to_keep = [
"hopper_3x1",
"halfcheetah_6x1",
"walker2d_2x3",
"ant_4x2",
# "humanoid_9|8",
] # Replace with your list of tasks to keep
input_file = f"./data/{base_folder_name}/merged_data/metrics.json"
output_file = f"./data/{base_folder_name}/merged_data/metrics.json"
tasks_to_keep = ["tiny-4ag", "medium-4ag"] # Replace with your list of tasks to keep

# Read the input JSON file
with open(input_file, "r") as f:
Expand Down
4 changes: 2 additions & 2 deletions data_processing_scripts/process_env_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def main(json_filename, new_json_filename):
save_json(new_json_filename, data)


base_folder_name = "biggest-benchmark-sac"
base_folder_name = "retmat-timestep-chunkwise-ablation"

# Replace 'your_file.json' with your actual JSON file name
json_filename = f"./data/{base_folder_name}/merged_data/metrics.json"
new_json_filename = f"./data/{base_folder_name}/merged_data/metrics_name_processed.json"
new_json_filename = f"./data/{base_folder_name}/merged_data/metrics.json"
main(json_filename, new_json_filename)
Loading

0 comments on commit 2a857cc

Please sign in to comment.