[WIP] Adding a population-training feature for n-1-sp methods #48

Draft
wants to merge 62 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
eb028ea
update generate_hdim_and_seed for population methods
Ttopiac Nov 4, 2024
eb1d84e
Replace aamas25 in population.py by tag.CheckedPoints.FINAL_TRAINED_M…
Ttopiac Nov 4, 2024
2a0febc
Replace the first loop in get_best_SP_agent by a line of code.
Ttopiac Nov 5, 2024
b592e6c
Add two CheckedPoints tags, including CheckedModelPrefix and REWARD_S…
Ttopiac Nov 5, 2024
9bc2e37
Add a function for us to list agent's checked tags and also a test fu…
Ttopiac Nov 6, 2024
4bf7932
rewrite a comment
Ttopiac Nov 6, 2024
a9a874e
update generate_hdim_and_seed for population methods
Ttopiac Nov 4, 2024
e418ebf
Replace aamas25 in population.py by tag.CheckedPoints.FINAL_TRAINED_M…
Ttopiac Nov 4, 2024
b751163
Replace the first loop in get_best_SP_agent by a line of code.
Ttopiac Nov 5, 2024
0d4f99d
Fix merge conflicts
Ttopiac Nov 7, 2024
c159690
Add a function for us to list agent's checked tags and also a test fu…
Ttopiac Nov 6, 2024
1763e96
rewrite a comment
Ttopiac Nov 6, 2024
51c44db
fixed conflicts
Ttopiac Nov 7, 2024
b4cb9e8
Replace the first loop in get_best_SP_agent by a line of code.
Ttopiac Nov 5, 2024
ed9153d
Add two CheckedPoints tags, including CheckedModelPrefix and REWARD_S…
Ttopiac Nov 5, 2024
18d734a
Merge branch 'pop-n-1-sp' of github.com:HIRO-group/multiHRI into pop-…
Ttopiac Nov 7, 2024
1f27549
Replace CheckedPoints by KeyCheckpoints
Ttopiac Nov 7, 2024
706a84c
Fix uncommited merged conflicts in rl.py
Ttopiac Nov 7, 2024
a758251
Fix conflicts in base_agent.py
Ttopiac Nov 7, 2024
a5a1fcb
Replace get_KeyCheckpoints_agents by get_checkedpoints_agents
Ttopiac Nov 7, 2024
1858053
?
Ttopiac Nov 7, 2024
821a46e
Fix train_helper.py
Ttopiac Nov 7, 2024
610eb96
little things
Ttopiac Nov 7, 2024
4d5780b
Add SPN_XSPCKP_HP_TYPE
Ttopiac Nov 7, 2024
9d95ed4
Renmae get_poulation by get_categorized_population
Ttopiac Nov 11, 2024
0b49c5c
Rename save_population by save_categorized_population
Ttopiac Nov 11, 2024
59b9824
Ensure the save the last model with last tag.
Ttopiac Nov 11, 2024
a59aa60
Remove tag from RLAgentTrainer.train_agents
Ttopiac Nov 11, 2024
9dfdea0
use checked_model_name_handler.generate_checked_model
Ttopiac Nov 11, 2024
eee6bf1
Clean checked_model_name_handler.py
Ttopiac Nov 11, 2024
f8797f6
Add new layout and new agents info to evaluate_agents.py and eval_con…
Ttopiac Nov 14, 2024
1391d24
update generate_hdim_and_seed for population methods
Ttopiac Nov 4, 2024
cff8151
Replace aamas25 in population.py by tag.CheckedPoints.FINAL_TRAINED_M…
Ttopiac Nov 4, 2024
d1ebd7f
Replace the first loop in get_best_SP_agent by a line of code.
Ttopiac Nov 5, 2024
a4b1fdf
Fix merge conflicts
Ttopiac Nov 7, 2024
4c0f0d1
Add a function for us to list agent's checked tags and also a test fu…
Ttopiac Nov 6, 2024
48f7afc
rewrite a comment
Ttopiac Nov 6, 2024
25fa6c3
Replace the first loop in get_best_SP_agent by a line of code.
Ttopiac Nov 5, 2024
d5e7263
Add two CheckedPoints tags, including CheckedModelPrefix and REWARD_S…
Ttopiac Nov 5, 2024
40e96ea
update generate_hdim_and_seed for population methods
Ttopiac Nov 4, 2024
060db24
Replace the first loop in get_best_SP_agent by a line of code.
Ttopiac Nov 5, 2024
b66f06b
Add two CheckedPoints tags, including CheckedModelPrefix and REWARD_S…
Ttopiac Nov 5, 2024
f90ccde
Replace CheckedPoints by KeyCheckpoints
Ttopiac Nov 7, 2024
2af78f8
Replace get_KeyCheckpoints_agents by get_checkedpoints_agents
Ttopiac Nov 7, 2024
06c8631
little things
Ttopiac Nov 7, 2024
a8793a7
Add SPN_XSPCKP_HP_TYPE
Ttopiac Nov 7, 2024
4338577
Renmae get_poulation by get_categorized_population
Ttopiac Nov 11, 2024
7d55c9f
Rename save_population by save_categorized_population
Ttopiac Nov 11, 2024
93ed176
Ensure the save the last model with last tag.
Ttopiac Nov 11, 2024
9af6a21
use checked_model_name_handler.generate_checked_model
Ttopiac Nov 11, 2024
64756b9
Clean checked_model_name_handler.py
Ttopiac Nov 11, 2024
c914ddf
Add new layout and new agents info to evaluate_agents.py and eval_con…
Ttopiac Nov 14, 2024
31b1c3c
Merge branch 'pop-n-1-sp' of github.com:HIRO-group/multiHRI into pop-…
Ttopiac Nov 14, 2024
ea87768
Add tag_for_returning_agent input to a call on train_agents()
Ttopiac Nov 14, 2024
b4e535e
Add get_model_path and move get_most_recent_ckeckpints to base_agent.…
Ttopiac Nov 14, 2024
ac0f68a
Add feature to allow user to use generate_hdim_and_seed for not only …
Ttopiac Nov 15, 2024
c824b64
Rename methods for generating SP agent populations to distinguish the…
Ttopiac Nov 15, 2024
6d291e6
Update resume: train agent if no checkpoint exists
Ttopiac Nov 15, 2024
54750c0
Remove a bug in ensure_enough_SP_agents
Ttopiac Nov 16, 2024
bd4a6fa
Increase readability for functions with long arugments.
Ttopiac Nov 18, 2024
92a55bf
Add a MultiSetupTrainer
Ttopiac Nov 20, 2024
8a0ac4f
Modify multi_setup_trainer methods and make them to be the same as wh…
Ttopiac Dec 2, 2024
85 changes: 74 additions & 11 deletions oai_agents/agents/base_agent.py
@@ -3,6 +3,7 @@
from oai_agents.common.state_encodings import ENCODING_SCHEMES
from oai_agents.common.subtasks import calculate_completed_subtask, get_doable_subtasks, Subtasks
from oai_agents.common.tags import AgentPerformance, TeamType, KeyCheckpoints
from oai_agents.common.checked_model_name_handler import CheckedModelNameHandler
from oai_agents.gym_environments.base_overcooked_env import USEABLE_COUNTERS

from overcooked_ai_py.mdp.overcooked_mdp import Action
@@ -16,14 +17,15 @@
import numpy as np
import torch as th
import torch.nn as nn
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Optional
import stable_baselines3.common.distributions as sb3_distributions
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
import wandb
import os
import random
import pickle as pkl
import re

class OAIAgent(nn.Module, ABC):
"""
@@ -464,11 +466,9 @@ def get_agents(self) -> List[OAIAgent]:

def save_agents(self, path: Union[Path, None] = None, tag: Union[str, None] = None):
''' Saves each agent that the trainer is training '''
if not path:
if self.args.exp_dir:
path = self.args.base_dir / 'agent_models' / self.args.exp_dir / self.name
else:
path = self.args.base_dir / 'agent_models'/ self.name
path = path or OAITrainer.get_model_path(base_dir=self.args.base_dir,
exp_folder=self.args.exp_dir,
model_name=self.name)

tag = tag or self.args.exp_name
save_path = path / tag / 'trainer_file'
@@ -495,11 +495,9 @@ def save_agents(self, path: Union[Path, None] = None, tag: Union[str, None] = No
@staticmethod
def load_agents(args, tag, name: str=None, path: Union[Path, None] = None):
''' Loads each agent that the trainer is training '''
if not path:
if args.exp_dir:
path = args.base_dir / 'agent_models' / args.exp_dir / name
else:
path = args.base_dir / 'agent_models'/ name
path = path or OAITrainer.get_model_path(base_dir=args.base_dir,
exp_folder=args.exp_dir,
model_name=name)

tag = tag or args.exp_name
load_path = path / tag / 'trainer_file'
@@ -519,3 +517,68 @@ def load_agents(args, tag, name: str=None, path: Union[Path, None] = None):
env_info = pkl.load(f)

return agents, env_info, saved_variables

@staticmethod
def list_agent_checked_tags(args, name: str=None, path: Union[Path, None] = None) -> List[str]:
'''
Lists only tags that start with KeyCheckpoints.CHECKED_MODEL_PREFIX, followed by an integer.
If the integer is greater than 0, it must be followed by KeyCheckpoints.REWARD_SUBSTR and a floating-point number.

Parameters:
- args: Experiment arguments containing base directory info and experiment directory info.
- name: The name of the agent, for which tags should be listed.
- path: Optional. If provided, it overrides the default path to the agents directory.

Returns:
- A list of tags (directories) that match the specified pattern.
'''
path = path or OAITrainer.get_model_path(base_dir=args.base_dir,
exp_folder=args.exp_dir,
model_name=name)

handler = CheckedModelNameHandler()
return handler.get_all_checked_tags(path=path)

@staticmethod
def get_most_recent_checkpoint(args, name: str) -> str:
path = OAITrainer.get_model_path(
base_dir=args.base_dir,
exp_folder=args.exp_dir,
model_name=name
)
if not path.exists():
print(f"Warning: The directory {path} does not exist.")
return None
ckpts = [name for name in os.listdir(path) if name.startswith(KeyCheckpoints.CHECKED_MODEL_PREFIX)]
if not ckpts:
print(f"Warning: No checkpoints found in {path} with prefix '{KeyCheckpoints.CHECKED_MODEL_PREFIX}'.")
return None
ckpts_nums = [int(c.split('_')[1]) for c in ckpts]
last_ckpt_num = max(ckpts_nums)
return [c for c in ckpts if c.startswith(f"{KeyCheckpoints.CHECKED_MODEL_PREFIX}{last_ckpt_num}")][0]

@staticmethod
def get_model_path(base_dir: Union[str, Path], exp_folder: Optional[str], model_name: str) -> Path:
"""
Constructs a path for saving or loading an agent model.

Parameters:
base_dir (str or Path): The base directory where models are stored.
exp_folder (str or None): The experiment folder name, or None if not applicable.
model_name (str): The name of the model.

Returns:
Path: A Path object representing the constructed path.
"""
# Ensure base_dir is a Path object
base_dir = Path(base_dir) if isinstance(base_dir, str) else base_dir

experiment_name = OAITrainer.get_experiment_name(exp_folder=exp_folder, model_name=model_name)

path = base_dir / 'agent_models' /experiment_name

return path

@staticmethod
def get_experiment_name(exp_folder: Optional[str], model_name: str):
return f"{exp_folder}/{model_name}" if exp_folder else model_name
4 changes: 2 additions & 2 deletions oai_agents/agents/il.py
@@ -184,11 +184,11 @@ def run_epoch(self, agent_idx):
self.agents[agent_idx].eval()
return np.mean(losses)

def train_agents(self, epochs=100, exp_name=None):
def train_agents(self, epochs=100):
""" Training routine """
if self.datasets is None:
self.setup_datasets()
exp_name = exp_name or self.args.exp_name
exp_name = self.args.exp_name
run = wandb.init(project="overcooked_ai", entity=self.args.wandb_ent,
dir=str(self.args.base_dir / 'wandb'),
reinit=True, name=exp_name + '_' + self.name, mode=self.args.wandb_mode)
104 changes: 61 additions & 43 deletions oai_agents/agents/rl.py
@@ -4,6 +4,7 @@
from oai_agents.common.state_encodings import ENCODING_SCHEMES
from oai_agents.common.tags import AgentPerformance, TeamType, TeammatesCollection, KeyCheckpoints
from oai_agents.gym_environments.base_overcooked_env import OvercookedGymEnv
from oai_agents.common.checked_model_name_handler import CheckedModelNameHandler

import numpy as np
import random
@@ -13,24 +14,28 @@
from sb3_contrib import RecurrentPPO, MaskablePPO
import wandb
import os
from typing import Optional

VEC_ENV_CLS = DummyVecEnv #

class RLAgentTrainer(OAITrainer):
''' Train an RL agent to play with a teammates_collection of agents.'''
def __init__(self, teammates_collection, args,
agent, epoch_timesteps, n_envs,
seed, learner_type,
train_types=[], eval_types=[],
curriculum=None, num_layers=2, hidden_dim=256,
checkpoint_rate=None, name=None, env=None, eval_envs=None,
use_cnn=False, use_lstm=False, use_frame_stack=False,
taper_layers=False, use_policy_clone=False, deterministic=False, start_step: int=0, start_timestep: int=0):
def __init__(
self, teammates_collection, args,
agent, epoch_timesteps, n_envs,
seed, learner_type,
train_types=[], eval_types=[],
curriculum=None, num_layers=2, hidden_dim=256,
checkpoint_rate=None, name=None, env=None, eval_envs=None,
use_cnn=False, use_lstm=False, use_frame_stack=False,
taper_layers=False, use_policy_clone=False, deterministic=False, start_step: int=0, start_timestep: int=0
):


name = name or 'rl_agent'
super(RLAgentTrainer, self).__init__(name, args, seed=seed)


self.args = args
self.device = args.device
self.teammates_len = self.args.teammates_len
@@ -62,20 +67,23 @@ def __init__(self, teammates_collection, args,
self.start_timestep = start_timestep

self.learning_agent, self.agents = self.get_learning_agent(agent)
self.teammates_collection, self.eval_teammates_collection = self.get_teammates_collection(_tms_clctn = teammates_collection,
learning_agent = self.learning_agent,
train_types = train_types,
eval_types = eval_types)
self.teammates_collection, self.eval_teammates_collection = self.get_teammates_collection(
_tms_clctn = teammates_collection,
learning_agent = self.learning_agent,
train_types = train_types,
eval_types = eval_types
)
self.best_score, self.best_training_rew = -1, float('-inf')

@classmethod
def generate_randomly_initialized_agent(cls,
args,
learner_type:str,
name:str,
seed:int,
hidden_dim:int,
) -> OAIAgent:
def generate_randomly_initialized_agent(
cls,
args,
learner_type:str,
name:str,
seed:int,
hidden_dim:int,
) -> OAIAgent:
'''
Generate a randomly initialized learning agent using the RLAgentTrainer class
This function does not perform any learning
@@ -84,16 +92,17 @@ def generate_randomly_initialized_agent(cls,
:param seed: Random seed
:returns: An untrained, randomly inititalized RL agent
'''
trainer = cls(name=name,
args=args,
agent=None,
teammates_collection={},
epoch_timesteps=args.epoch_timesteps,
n_envs=args.n_envs,
seed=seed,
hidden_dim=hidden_dim,
learner_type=learner_type,
)
trainer = cls(
name=name,
args=args,
agent=None,
teammates_collection={},
epoch_timesteps=args.epoch_timesteps,
n_envs=args.n_envs,
seed=seed,
hidden_dim=hidden_dim,
learner_type=learner_type,
)

learning_agent, _ = trainer.get_learning_agent(None)
return learning_agent
@@ -261,19 +270,16 @@ def wrap_agent(self, sb3_agent, name):
return SB3LSTMWrapper(sb3_agent, name, self.args)
return SB3Wrapper(sb3_agent, name, self.args)

def get_experiment_name(self, exp_name):
return exp_name or str(self.args.exp_dir) + '/' + self.name


def should_evaluate(self, steps):
mean_training_rew = np.mean([ep_info["r"] for ep_info in self.learning_agent.agent.ep_info_buffer])
self.best_training_rew *= 0.98
self.best_training_rew *= 1.00

steps_divisable_by_15 = (steps + 1) % 15 == 0
steps_divisible_by_x = (steps + 1) % 15 == 0
mean_rew_greater_than_best = mean_training_rew > self.best_training_rew and self.learning_agent.num_timesteps >= 5e6
checkpoint_rate_reached = self.checkpoint_rate and self.learning_agent.num_timesteps // self.checkpoint_rate > (len(self.ck_list) - 1)

return steps_divisable_by_15 or mean_rew_greater_than_best or checkpoint_rate_reached
return steps_divisible_by_x or mean_rew_greater_than_best or checkpoint_rate_reached

def log_details(self, experiment_name, total_train_timesteps):
print("Training agent: " + self.name + ", for experiment: " + experiment_name)
@@ -292,8 +298,8 @@ def log_details(self, experiment_name, total_train_timesteps):
print("Final sparse reward ratio: ", self.args.final_sparse_r_ratio)


def train_agents(self, total_train_timesteps, tag_for_returning_agent, exp_name=None, resume_ck_list=None):
experiment_name = self.get_experiment_name(exp_name)
def train_agents(self, total_train_timesteps, tag_for_returning_agent, resume_ck_list=None):
experiment_name = RLAgentTrainer.get_experiment_name(exp_folder=self.args.exp_dir, model_name=self.name)
run = wandb.init(project="overcooked_ai", entity=self.args.wandb_ent, dir=str(self.args.base_dir / 'wandb'),
reinit=True, name=experiment_name, mode=self.args.wandb_mode,
resume="allow")
@@ -302,23 +308,35 @@ def train_agents(self, total_train_timesteps, tag_for_returning_agent, exp_name=

if self.checkpoint_rate is not None:
if self.args.resume:
path = self.args.base_dir / 'agent_models' / experiment_name

ckpts = [name for name in os.listdir(path) if name.startswith("ck")]
path = RLAgentTrainer.get_model_path(
base_dir=self.args.base_dir,
exp_folder=self.args.exp_dir,
model_name=self.name
)
if not path.exists():
print(f"Warning: The directory {path} does not exist.")
return None
ckpts = [name for name in os.listdir(path) if name.startswith(KeyCheckpoints.CHECKED_MODEL_PREFIX)]
if not ckpts:
print(f"Warning: No checkpoints found in {path} with prefix '{KeyCheckpoints.CHECKED_MODEL_PREFIX}'.")
return None
ckpts_nums = [int(c.split('_')[1]) for c in ckpts]
sorted_idxs = np.argsort(ckpts_nums)
ckpts = [ckpts[i] for i in sorted_idxs]
self.ck_list = [(c[0], path, c[2]) for c in resume_ck_list] if resume_ck_list else [({k: 0 for k in self.args.layout_names}, path, ck) for ck in ckpts]
self.ck_list = [(c[0], path, c[2]) for c in resume_ck_list] if resume_ck_list else [
({k: 0 for k in self.args.layout_names}, path, ck) for ck in ckpts]
else:
self.ck_list = []
path, tag = self.save_agents(tag=f'ck_{len(self.ck_list)}')
path, tag = self.save_agents(tag=f'{KeyCheckpoints.CHECKED_MODEL_PREFIX}{len(self.ck_list)}')
self.ck_list.append(({k: 0 for k in self.args.layout_names}, path, tag))


best_path, best_tag = None, None

self.steps = self.start_step
curr_timesteps = self.start_timestep
prev_timesteps = self.learning_agent.num_timesteps
ck_name_handler = CheckedModelNameHandler()

while curr_timesteps < total_train_timesteps:
self.curriculum.update(current_step=self.steps)
@@ -347,7 +365,7 @@ def train_agents(self, total_train_timesteps, tag_for_returning_agent, exp_name=

if self.checkpoint_rate:
if self.learning_agent.num_timesteps // self.checkpoint_rate > (len(self.ck_list) - 1):
path, tag = self.save_agents(tag=f'ck_{len(self.ck_list)}_rew_{mean_reward}')
path, tag = self.save_agents(tag=ck_name_handler.generate_tag(id=len(self.ck_list), mean_reward=mean_reward))
self.ck_list.append((rew_per_layout, path, tag))

if mean_reward >= self.best_score:
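
Checkpoint bookkeeping in train_agents now goes through KeyCheckpoints.CHECKED_MODEL_PREFIX and CheckedModelNameHandler.generate_tag instead of hard-coded 'ck_' f-strings. The handler itself is not shown in this diff, so the sketch below only illustrates the assumed tag format ('ck_<id>' for the initial save, then 'ck_<id>_rew_<mean_reward>') and the split('_')[1] parsing that both get_most_recent_checkpoint and the resume branch rely on; the constant values are stand-ins:

# Stand-ins for the constants referenced in the diff; the real values live in
# oai_agents.common.tags.KeyCheckpoints and may differ.
CHECKED_MODEL_PREFIX = 'ck_'
REWARD_SUBSTR = '_rew_'

def generate_tag(ckpt_id, mean_reward=None):
    # Assumed behavior of CheckedModelNameHandler.generate_tag, inferred from the
    # f-strings this PR replaces ('ck_{len(self.ck_list)}' and 'ck_..._rew_...').
    if mean_reward is None:
        return f"{CHECKED_MODEL_PREFIX}{ckpt_id}"
    return f"{CHECKED_MODEL_PREFIX}{ckpt_id}{REWARD_SUBSTR}{mean_reward}"

def most_recent_checkpoint(tags):
    # Same parsing as the resume code: the checkpoint number sits between the
    # first and second underscore, e.g. 'ck_2_rew_181.5'.split('_')[1] == '2'.
    nums = [int(t.split('_')[1]) for t in tags]
    return tags[nums.index(max(nums))]

tags = [generate_tag(0), generate_tag(1, 152.0), generate_tag(2, 181.5)]
print(most_recent_checkpoint(tags))  # ck_2_rew_181.5

Note that this parsing assumes the prefix contains exactly one underscore; if KeyCheckpoints.CHECKED_MODEL_PREFIX ever changes shape, the split('_')[1] lookups would need to change with it.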
3 changes: 3 additions & 0 deletions oai_agents/common/arguments.py
@@ -77,6 +77,9 @@ def get_arguments(additional_args=[]):
parser.add_argument("--num-of-ckpoints", type=int, default=10)
parser.add_argument("--resume", action="store_true", default=False, help="Restart from last checkpoint for population training only")

parser.add_argument("--for-evaluation", action="store_true", default=False, help="The trained agents are used for evaluating other agents. Please note that seeds and h_dim are different when agents are trained for evaluating others.)")
parser.add_argument("--num-of-training-variants", type=int, default=4)

for parser_arg, parser_kwargs in additional_args:
parser.add_argument(parser_arg, **parser_kwargs)

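
The two new flags feed the population-training entry points: per the help text, --for-evaluation marks agents that are trained to evaluate others (the generate_hdim_and_seed commits suggest these draw from a separate pool of seeds and hidden dims), and --num-of-training-variants sets how many training variants are produced. A minimal argparse reproduction of just these two flags, with a hypothetical invocation; everything else in get_arguments is omitted:

import argparse

parser = argparse.ArgumentParser()
# Defaults copied from the diff above.
parser.add_argument("--for-evaluation", action="store_true", default=False)
parser.add_argument("--num-of-training-variants", type=int, default=4)

# Hypothetical command line, for illustration only.
args = parser.parse_args(["--for-evaluation", "--num-of-training-variants", "6"])
print(args.for_evaluation, args.num_of_training_variants)  # True 6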