diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index e308938..95827c9 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -124,6 +124,7 @@ class Losses: @dataclass class CleanPuffeRL: + exp_name: str config: argparse.Namespace vecenv: pufferlib.vector.Serial | pufferlib.vector.Multiprocessing policy: nn.Module diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index 5a77a14..164d006 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -200,6 +200,7 @@ def train( args.train.env = "Pokemon Red" with CleanPuffeRL( + exp_name=args.exp_name, config=args.train, vecenv=vecenv, policy=policy,