[WIP, Policy] Docstring and refactoring on top of PR 327 #335

66 changes: 51 additions & 15 deletions gflownet/policy/base.py
@@ -1,13 +1,44 @@
from abc import ABC, abstractmethod
"""
Base Policy class for GFlowNet policy models.
"""

from typing import Union

Collaborator

I think you also need to import Tuple here.

Owner Author

Indeed.
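
A minimal sketch of the fix suggested in this thread, assuming the Tuple[...] return annotation introduced below is kept:

# Extend the existing typing import so that the Tuple annotation resolves
from typing import Tuple, Union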

import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import set_device, set_float_precision


class Policy:
def __init__(self, config, env, device, float_precision, base=None):
def __init__(
self,
config: Union[dict, DictConfig],
env: GFlowNetEnv,
device: Union[str, torch.device],
float_precision: Union[int, torch.dtype],
base=None,
):
"""
Base Policy class for a :class:`GFlowNetAgent`.

Parameters
----------
config : dict or DictConfig
The configuration dictionary to set up the policy model.
env : GFlowNetEnv
The environment used to train the :class:`GFlowNetAgent`, used to extract
needed properties.
device : str or torch.device
The device to be passed to torch tensors.
float_precision : int or torch.dtype
The floating point precision to be passed to torch tensors.
base : Policy, optional
A base policy model whose weights may be shared, depending on the
configuration of the subclass.
"""
# If config is None, instantiate an empty config (defaults will be used)
if config is None:
config = OmegaConf.create()
# Device and float precision
self.device = set_device(device)
self.float = set_float_precision(float_precision)
@@ -18,24 +49,29 @@ def __init__(self, config, env, device, float_precision, base=None):
self.output_dim = len(self.fixed_output)
# Optional base model
self.base = base

self.parse_config(config)
self.instantiate()

def parse_config(self, config):
# If config is null, default to uniform
if config is None:
config = OmegaConf.create()
# Policy type, defaults to uniform
self.type = config.get("type", "uniform")
# Checkpoint, defaults to None
self.checkpoint = config.get("checkpoint", None)
# Instantiate the model
self.model, self.is_model = self.make_model()

def instantiate(self):
def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]:
"""
Instantiates the model of the policy.

Returns
-------
model : torch.Tensor or torch.nn.Module
A tensor representing the output of the policy or a torch model.
is_model : bool
True if the policy is a model (for example, a neural network) and False if
it is a fixed tensor (for example, one that encodes a uniform distribution).
"""
if self.type == "fixed":
self.model = self.fixed_distribution
self.is_model = False
return self.fixed_distribution, False
elif self.type == "uniform":
self.model = self.uniform_distribution
self.is_model = False
return self.uniform_distribution, False
else:
raise "Policy model type not defined"

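For context, a hedged sketch of the contract that the refactored make_model() sets for subclasses: each policy returns a (model, is_model) pair, which the base __init__ stores as self.model and self.is_model. The subclass names below are illustrative and not part of this PR; self.fixed_output, self.state_dim, self.output_dim and self.device are assumed to be set by the base class from the environment, as in the diff, with fixed_output assumed to already be a torch tensor.

from torch import nn

from gflownet.policy.base import Policy


class ConstantPolicy(Policy):
    """Toy policy that returns a fixed tensor, mirroring the 'fixed' branch."""

    def make_model(self):
        # Not a trainable model, so is_model is False
        return self.fixed_output, False


class LinearPolicy(Policy):
    """Toy policy backed by a single linear layer, mirroring MLPPolicy."""

    def make_model(self):
        model = nn.Linear(self.state_dim, self.output_dim).to(self.device)
        # A torch.nn.Module, so is_model is True
        return model, True
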
62 changes: 31 additions & 31 deletions gflownet/policy/cnn.py
@@ -6,19 +6,32 @@


class CNNPolicy(Policy):
def __init__(self, config, env, device, float_precision, base=None):
def __init__(self, **kwargs):
# Shared weights, defaults to False
self.shared_weights = config.get("shared_weights", False)
# Reload checkpoint, defaults to False
self.reload_ckpt = config.get("reload_ckpt", False)
# CNN features: number of layers, number of channels, kernel sizes, strides
self.n_layers = config.get("n_layers", 3)
self.channels = config.get("channels", [16] * self.n_layers)
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)
# Environment
# TODO: rethink whether storing the whole environment is needed
self.env = env
Collaborator

We might not need to store the env most of the time, actually. But sometimes one might need to know more about the environment configuration to induce an inductive bias in their policy model. In the CNN policy, we used to access the grid dimensions like self.env.height and self.env.width, but it is not crucial.

Collaborator

+1, I think storing the whole env is likely to duplicate a lot of info unnecessarily, especially if you're trying to do multiprocessing stuff down the line.
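
A hedged sketch of the alternative discussed here: copy only the environment attributes the CNN needs instead of holding a reference to the whole env. The attribute names height and width follow the grid-environment usage mentioned above and are assumptions about the env interface; config parsing is omitted for brevity.

class CNNPolicy(Policy):
    def __init__(self, **kwargs):
        env = kwargs["env"]
        # Keep only the scalars needed to build the layers, not the env itself
        self.grid_height = env.height
        self.grid_width = env.width
        super().__init__(**kwargs)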

super().__init__(
config=config,
env=env,
device=device,
float_precision=float_precision,
base=base,
)
# Base init
super().__init__(**kwargs)

def make_cnn(self):
def make_model(self):
"""
Defines an CNN with no top layer activation
Instantiates a CNN with no top layer activation.

Returns
-------
model : torch.nn.Module
A torch model containing the CNN.
is_model : bool
True because a CNN is a model.
"""
if self.shared_weights and self.base is not None:
layers = list(self.base.model.children())[:-1]
@@ -27,14 +40,15 @@ def make_cnn(self):
)

model = nn.Sequential(*layers, last_layer).to(self.device)
return model
return model, True

current_channels = 1
conv_module = nn.Sequential()

if len(self.kernel_sizes) != self.n_layers:
raise ValueError(
f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}"
f"Inconsistent dimensions kernel_sizes != n_layers, "
"{len(self.kernel_sizes)} != {self.n_layers}"
)

for i in range(self.n_layers):
@@ -59,33 +73,19 @@ def make_cnn(self):
in_channels = conv_module(dummy_input).numel()
if in_channels >= 500_000: # TODO: this could better be handled
raise RuntimeWarning(
"Input channels for the dense layer are too big, this will increase number of parameters"
"Input channels for the dense layer are too big, this will "
"increase number of parameters"
)
except RuntimeError as e:
raise RuntimeError(
"Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions."
"Failed during convolution operation. Ensure that the kernel sizes "
"and strides are appropriate for the input dimensions."
) from e

model = nn.Sequential(
conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim)
)
return model.to(self.device)

def parse_config(self, config):
super().parse_config(config)
if config is None:
config = OmegaConf.create()
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.reload_ckpt = config.get("reload_ckpt", False)
self.n_layers = config.get("n_layers", 3)
self.channels = config.get("channels", [16] * self.n_layers)
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)

def instantiate(self):
self.model = self.make_cnn()
self.is_model = True
return model.to(self.device), True

def __call__(self, states):
states = states.unsqueeze(1) # (batch_size, channels, height, width)
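A standalone sketch, independent of the repository, of the shape-inference trick used in make_model() above: pass a dummy tensor through the convolutional stack to determine how many flattened features the final Linear layer must accept. The layer sizes here are illustrative.

import torch
from torch import nn

conv_module = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=1),
)

height, width = 8, 8
with torch.no_grad():
    dummy_input = torch.zeros(1, 1, height, width)
    in_features = conv_module(dummy_input).numel()  # flattened feature count

model = nn.Sequential(conv_module, nn.Flatten(), nn.Linear(in_features, 10))
assert model(dummy_input).shape == (1, 10)
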
67 changes: 32 additions & 35 deletions gflownet/policy/mlp.py
@@ -5,35 +5,47 @@


class MLPPolicy(Policy):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(
config=config,
env=env,
device=device,
float_precision=float_precision,
base=base,
)
def __init__(self, **kwargs):
# Shared weights, defaults to False
self.shared_weights = config.get("shared_weights", False)
# Reload checkpoint, defaults to False
self.reload_ckpt = config.get("reload_ckpt", False)
# MLP features: number of layers, number of hidden units, tail, etc.
self.n_layers = config.get("n_layers", 2)
self.n_hid = config.get("n_hid", 128)
self.tail = config.get("tail", [])
# Base init
super().__init__(**kwargs)

def make_mlp(self, activation):
def make_model(self, activation: nn.Module = nn.LeakyReLU()):
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
Instantiates an MLP with no top layer activation as the policy model.

If self.shared_weights is True, the base model with which weights are to be
shared must be provided.

Parameters
----------
activation : nn.Module
Activation function of the MLP layers

Returns
-------
model : torch.nn.Module
A torch model containing the MLP.
is_model : bool
True because an MLP is a model.
"""
activation.to(self.device)

if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
return mlp, True
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
@@ -53,26 +65,11 @@ def make_mlp(self, activation):
+ self.tail
)
)
return mlp
return mlp, True
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)

def parse_config(self, config):
super().parse_config(config)
if config is None:
config = OmegaConf.create()
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", 128)
self.n_layers = config.get("n_layers", 2)
self.tail = config.get("tail", [])
self.reload_ckpt = config.get("reload_ckpt", False)

def instantiate(self):
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True

def __call__(self, states):
return self.model(states)
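
A standalone sketch of the weight-sharing construction in make_model() above: reuse every layer of a base MLP except the last and attach a freshly initialised top layer, so only the trunk parameters are shared. The sizes are illustrative.

import torch
from torch import nn

base_model = nn.Sequential(nn.Linear(8, 128), nn.LeakyReLU(), nn.Linear(128, 4))

shared = nn.Sequential(
    base_model[:-1],  # shared trunk: the very same parameter objects
    nn.Linear(base_model[-1].in_features, base_model[-1].out_features),
)

x = torch.randn(5, 8)
assert shared(x).shape == (5, 4)
# The trunk parameters are shared; only the top layer is new
assert shared[0][0].weight is base_model[0].weight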