-
Notifications
You must be signed in to change notification settings - Fork 249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change callback for AdversarialTrainer
#626
base: master
Are you sure you want to change the base?
Conversation
@@ -421,7 +422,7 @@ def train_gen( | |||
def train( | |||
self, | |||
total_timesteps: int, | |||
callback: Optional[Callable[[int], None]] = None, | |||
callback: Optional[List[BaseCallback]] = None | |||
) -> None: | |||
"""Alternates between training the generator and discriminator. | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The last part of description and finally a call to callback(round)
is probably misleading now.
if self.gen_callback is None: | ||
self.gen_callback = callback | ||
else: | ||
self.gen_callback = callback + [self.gen_callback] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can someone abuse the API by calling train()
multiple times? If so, the value of self.gen_callback
would contain nested list, which is not correct. Generally, the value of gen_callback is currently Optional[BaseCallback]
and we shouldn't change the type to a list at runtime.
Perhaps it would be better to add an optional callback argument to train_gen()
, merge callbacks there, and avoid the stateful change here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, can the learn_kwargs
argument from train_gen()
be removed, as discussed in the original issue #607 ?
@@ -421,7 +422,7 @@ def train_gen( | |||
def train( | |||
self, | |||
total_timesteps: int, | |||
callback: Optional[Callable[[int], None]] = None, | |||
callback: Optional[List[BaseCallback]] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to change the semantics of the argument here, or should we rather deprecate the feature (and introduce a different parameter for additional gen_callback)?
I think the suggestion in the original issue was to add a new gen_callback
argument. (Btw, stable-baselines supports both CallbackList and list of callbacks if we wanted to be fancy)
@@ -421,7 +422,7 @@ def train_gen( | |||
def train( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more thing - if you change the arguments, update of training_adversarial.py will also be needed
Changing the callback mechanism of
AdversarialTrainer
such that we can insertsb3.EvalCallback
. See #607.