Change callback for AdversarialTrainer
#626
base: master
```diff
@@ -2,13 +2,14 @@
 import abc
 import dataclasses
 import logging
-from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
+from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload

 import numpy as np
 import torch as th
 import torch.utils.tensorboard as thboard
 import tqdm
 from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
+from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F
```
```diff
@@ -421,7 +422,7 @@ def train_gen(
     def train(
         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: Optional[List[BaseCallback]] = None,
     ) -> None:
         """Alternates between training the generator and discriminator.
```

Review comment: Do we want to change the semantics of the argument here, or should we rather deprecate the feature (and introduce a different parameter for additional `gen_callback`)? I think the suggestion in the original issue was to add a new […]

Review comment: The last part of the description […]
```diff
@@ -434,10 +435,15 @@ def train(
         Args:
             total_timesteps: An upper bound on the number of transitions to sample
                 from the environment during training.
-            callback: A function called at the end of every round which takes in a
-                single argument, the round number. Round numbers are in
-                `range(total_timesteps // self.gen_train_timesteps)`.
+            callback: List of stable_baselines3 callbacks to be passed to the policy
+                learning function.
         """
+        if callback is not None:
+            if self.gen_callback is None:
+                self.gen_callback = callback
+            else:
+                self.gen_callback = callback + [self.gen_callback]
+
         n_rounds = total_timesteps // self.gen_train_timesteps
         assert n_rounds >= 1, (
             "No updates (need at least "
```

Review comment: Can someone abuse the API by calling […]? Perhaps it would be better to add an optional callback argument to […]

Review comment: Also, can the […]
```diff
@@ -450,8 +456,6 @@ def train(
             with networks.training(self.reward_train):
                 # switch to training mode (affects dropout, normalization)
                 self.train_disc()
-            if callback:
-                callback(r)
             self.logger.dump(self._global_step)

     @overload
```
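For contrast, the behavior this hunk removes can be sketched as a simplified standalone loop (an assumption for illustration, not the library's actual `train`): under the old semantics, `callback` was a single function invoked once per round with the round index.

```python
# Sketch of the removed per-round callback semantics. old_train_loop is a
# made-up simplification; the real method also runs generator/discriminator
# updates each round.
from typing import Callable, Optional


def old_train_loop(
    total_timesteps: int,
    gen_train_timesteps: int,
    callback: Optional[Callable[[int], None]] = None,
) -> None:
    n_rounds = total_timesteps // gen_train_timesteps
    for r in range(n_rounds):
        # ... generator and discriminator training would happen here ...
        if callback:
            callback(r)  # old semantics: called at the end of every round


rounds_seen = []
old_train_loop(10, 3, callback=rounds_seen.append)
print(rounds_seen)  # [0, 1, 2]
```

Under the new signature this per-round hook disappears, so code relying on round numbers would need an SB3-style callback object instead.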
Review comment: One more thing: if you change the arguments, an update of `training_adversarial.py` will also be needed.