[WIP, Policy] Docstring and refactoring on top of PR 327 #335

alexhernandezgarcia
Owner

This PR adds docstrings and refactors the code on top of PR #327.

Summary of refactoring:

  • Got rid of the method parse_config() and included its content in __init__() to make the handling of all the attributes of a policy more explicit.
  • Got rid of the method instantiate() and combined it with the make_*() methods (e.g. make_mlp(), make_cnn()) into a single method common to all policies, make_model(), which instantiates and returns the corresponding model.
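
The two points above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the class names, the constructor arguments (`state_dim`, `n_actions`) and the config keys other than `n_layers` are assumptions, and the "model" is a plain Python object rather than a real network so that the sketch has no dependencies.

```python
from typing import Callable, List, Tuple


class Policy:
    """Hypothetical base class: every attribute is set explicitly in __init__()."""

    def __init__(self, config: dict, state_dim: int, n_actions: int):
        # Previously handled by parse_config(); now visible in one place.
        self.n_hid = config.get("n_hid", 8)
        self.n_layers = config.get("n_layers", 2)
        self.state_dim = state_dim
        self.n_actions = n_actions
        # Previously instantiate() plus make_mlp()/make_cnn(); now one method.
        self.model = self.make_model()

    def make_model(self):
        """Instantiates and returns the model; overridden by each policy."""
        raise NotImplementedError


class UniformPolicy(Policy):
    def make_model(self) -> Callable[[List[list]], List[List[float]]]:
        n_actions = self.n_actions
        # One row of equal logits per input state.
        return lambda states: [[0.0] * n_actions for _ in states]


class MLPPolicy(Policy):
    def make_model(self) -> List[Tuple[int, int]]:
        # Stand-in for a real network: the (in, out) sizes of each linear layer.
        dims = [self.state_dim] + [self.n_hid] * self.n_layers + [self.n_actions]
        return list(zip(dims[:-1], dims[1:]))


mlp = MLPPolicy({"n_hid": 16, "n_layers": 2}, state_dim=4, n_actions=3)
print(mlp.model)
```

With this layout, a subclass only overrides make_model(), and everything a policy depends on can be read off __init__().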

🚧 Work in progress 🚧

self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)
# Environment
# TODO: rethink whether storing the whole environment is needed
self.env = env
Collaborator

We might not need to store the env most of the time, actually. But sometimes one might need to know more about the environment configuration to build an inductive bias into the policy model. In the CNN policy, we used to access the grid dimensions via self.env.height and self.env.width, but that is not crucial.

Collaborator

+1, I think storing the whole env is likely to duplicate a lot of info unnecessarily, especially if you're trying to do multiprocessing stuff down the line.
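
One way to act on the TODO, sketched under assumptions: copy only the attributes the model needs (here the grid height and width mentioned above) instead of keeping a reference to the whole environment. `GridEnv`, `CNNPolicy` and `output_shape()` are hypothetical names for illustration, and the size calculation assumes convolutions without padding or dilation.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class GridEnv:
    """Hypothetical stand-in for an environment with a 2D grid."""

    height: int
    width: int


class CNNPolicy:
    def __init__(self, config: dict, env: GridEnv):
        self.n_layers = config.get("n_layers", 2)
        self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
        self.strides = config.get("strides", [(1, 1)] * self.n_layers)
        # Copy only what the model needs instead of keeping self.env.
        self.height = env.height
        self.width = env.width

    def output_shape(self) -> Tuple[int, int]:
        """Spatial size after all conv layers, assuming no padding or dilation."""
        h, w = self.height, self.width
        for (kh, kw), (sh, sw) in zip(self.kernel_sizes, self.strides):
            h = (h - kh) // sh + 1
            w = (w - kw) // sw + 1
        return h, w


policy = CNNPolicy({"n_layers": 2}, GridEnv(height=8, width=8))
print(policy.output_shape())
```

Because the policy holds two integers rather than the environment object, pickling it for worker processes does not drag the environment along.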

Base Policy class for GFlowNet policy models.
"""

from typing import Union
Collaborator

I think you also need to import Tuple here

Owner Author

Indeed.
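
For illustration only, the import with both names might look like the snippet below. The helper `to_pair()` is hypothetical and is not part of this PR; it just shows an annotation that needs both `Tuple` and `Union`.

```python
from typing import Tuple, Union


def to_pair(size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Returns (size, size) for an int, or the tuple unchanged."""
    if isinstance(size, int):
        return (size, size)
    return size
```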

@engmubarak48
Collaborator

@josephdviviano finalizing this PR will hopefully close both #327 and #335

Collaborator

@josephdviviano josephdviviano left a comment

lgtm

@josephdviviano josephdviviano merged commit c9ec03f into 293-flexible-policy-definition Sep 18, 2024
1 check passed