Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make adding new Policy Models flexible #327

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
98e4d7e
initial commit: split the base policy and the architectures
engmubarak48 Jun 18, 2024
1388756
ignore git logs
engmubarak48 Jun 21, 2024
7c8a517
refactor base policy class and move models into differrent file
engmubarak48 Jun 21, 2024
4ecd865
handle when config none gracefully
engmubarak48 Jun 21, 2024
d50aeea
black formatting
engmubarak48 Jun 21, 2024
1508e4b
further formatting with isort
engmubarak48 Jun 21, 2024
9cdf18e
formatting black + isort
engmubarak48 Jun 28, 2024
8af7f82
added flatten flag and device movement handling
engmubarak48 Jun 28, 2024
9474150
bug fix: use .cpu() before .numpy()
engmubarak48 Jun 28, 2024
83a4904
added cnn policy, and flatten flag should be set to false when using …
engmubarak48 Jun 28, 2024
e254cf2
black formatting
engmubarak48 Jun 28, 2024
d30c3f2
smaller cnn config like kernel size etc
engmubarak48 Jul 9, 2024
2ac9af5
minor refactor on parse_config
engmubarak48 Jul 9, 2024
7e02200
move self.is_model to instantiate and add super().parse_config(config…
engmubarak48 Jul 9, 2024
e6a14ef
Add docstring and typing to __init__ of policy base.
alexhernandezgarcia Jul 10, 2024
57b0b13
Use kwargs instead of listing parameters explicitly
alexhernandezgarcia Jul 10, 2024
178a08e
Policy MLP: docstring and typing.
alexhernandezgarcia Jul 10, 2024
ace8a28
Get rid of parse_config and include its content in __init__
alexhernandezgarcia Jul 10, 2024
8e6f03d
Combine instantiate and make_* into a single method make_model()
alexhernandezgarcia Jul 10, 2024
774c411
Missing import
alexhernandezgarcia Jul 10, 2024
9520315
Fix config issue by implementing _get_config()
alexhernandezgarcia Jul 10, 2024
910a948
Docstring for base argument
alexhernandezgarcia Jul 10, 2024
c9ec03f
Merge pull request #335 from alexhernandezgarcia/ahg/293-flexible-pol…
josephdviviano Sep 18, 2024
9fa9381
remove env from the cnn policy
engmubarak48 Sep 23, 2024
2bf438a
init the cnn env's height and width in the policy
engmubarak48 Sep 23, 2024
9395e61
add mlp to device
engmubarak48 Sep 23, 2024
595db80
formatting
engmubarak48 Sep 23, 2024
6797bcc
debug: add to print environment information when running GitHub Actio…
engmubarak48 Sep 23, 2024
95258a0
Add the specific versions of pymatgen and spglib
engmubarak48 Sep 25, 2024
61f1e1d
Merge branch 'main' into 293-flexible-policy-definition
engmubarak48 Sep 25, 2024
b7f33d8
revert version downgrade of pymatgen
engmubarak48 Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ playground/
!docs/requirements-docs.txt
.DS_Store
docs/_build/
logs
2 changes: 2 additions & 0 deletions config/env/tetris.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ height: 20
pieces: ["I", "J", "L", "O", "S", "T", "Z"]
# Allowed roations
rotations: [0, 90, 180, 270]
# Don't flatten if using CNN
flatten: True
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved
# Other config
allow_redundant_rotations: False
allow_eos_before_full: False
Expand Down
16 changes: 16 additions & 0 deletions config/policy/cnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
_target_: gflownet.policy.cnn.CNNPolicy

shared: null

forward:
n_layers: 1
channels: [16]
kernel_sizes: [[3, 3], [2, 2], [1, 1]] # Each tuple represents (height, width)
strides: [[2, 2], [2, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride)
checkpoint: null
reload_ckpt: False

backward:
shared_weights: True
checkpoint: null
reload_ckpt: False
3 changes: 1 addition & 2 deletions config/policy/mlp.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
_target_: gflownet.policy.base.Policy
_target_: gflownet.policy.mlp.MLPPolicy

shared: null

forward:
type: mlp
n_hid: 128
n_layers: 2
checkpoint: null
Expand Down
8 changes: 6 additions & 2 deletions gflownet/envs/tetris.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
height: int = 20,
pieces: List = ["I", "J", "L", "O", "S", "T", "Z"],
rotations: List = [0, 90, 180, 270],
flatten: bool = True,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we move the flattening from the environment to the policy, then we don't need this.

allow_redundant_rotations: bool = False,
allow_eos_before_full: bool = False,
**kwargs,
Expand All @@ -87,6 +88,7 @@ def __init__(
self.height = height
self.pieces = pieces
self.rotations = rotations
self.flatten = flatten
self.allow_redundant_rotations = allow_redundant_rotations
self.allow_eos_before_full = allow_eos_before_full
self.max_pieces_per_type = 100
Expand Down Expand Up @@ -307,7 +309,9 @@ def states2policy(
A tensor containing all the states in the batch.
"""
states = tint(states, device=self.device, int_type=self.int)
return self.states2proxy(states).flatten(start_dim=1).to(self.float)
if self.flatten:
return self.states2proxy(states).flatten(start_dim=1).to(self.float)
return self.states2proxy(states).to(self.float)
Comment on lines +312 to +314
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexhernandezgarcia This is a temporary solution to make the CNN policy work on Tetris env. But normally the flattening should happen inside the model but not in the environment (see my other comments)

if you are okay with that, then I can update.


def state2readable(self, state: Optional[TensorType["height", "width"]] = None):
"""
Expand Down Expand Up @@ -581,7 +585,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2):
linewidth : int
The width of the separation between cells, in pixels.
"""
board = board.clone().numpy()
board = board.clone().cpu().numpy()
height = board.shape[0] * cellsize
width = board.shape[1] * cellsize
board_img = 128 * np.ones(
Expand Down
2 changes: 1 addition & 1 deletion gflownet/evaluator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None):

if "corr_prob_traj_rewards" in metrics:
lp_metrics["corr_prob_traj_rewards"] = np.corrcoef(
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt.cpu().numpy()
)[0, 1]

if "var_logrewards_logp" in metrics:
Expand Down
77 changes: 6 additions & 71 deletions gflownet/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.utils.common import set_device, set_float_precision


class ModelBase(ABC):
class Policy(ABC):
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, config, env, device, float_precision, base=None):
# Device and float precision
self.device = set_device(device)
Expand All @@ -21,82 +20,18 @@ def __init__(self, config, env, device, float_precision, base=None):
self.base = base

self.parse_config(config)
self.instantiate()

def parse_config(self, config):
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved
# If config is null, default to uniform
if config is None:
config = OmegaConf.create()
config.type = "uniform"
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", None)
self.n_layers = config.get("n_layers", None)
self.tail = config.get("tail", [])
if "type" in config:
self.type = config.type
elif self.shared_weights:
self.type = self.base.type
else:
raise "Policy type must be defined if shared_weights is False"

@abstractmethod
def instantiate(self):
pass

def __call__(self, states):
return self.model(states)

def make_mlp(self, activation):
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
"""
if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
)
mlp = nn.Sequential(
*(
sum(
[
[nn.Linear(idim, odim)]
+ ([activation] if n < len(layers_dim) - 2 else [])
for n, (idim, odim) in enumerate(
zip(layers_dim, layers_dim[1:])
)
],
[],
)
+ self.tail
)
)
return mlp
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)


class Policy(ModelBase):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(config, env, device, float_precision, base)

self.instantiate()
self.type = "uniform"

def instantiate(self):
if self.type == "fixed":
Expand All @@ -105,12 +40,12 @@ def instantiate(self):
elif self.type == "uniform":
self.model = self.uniform_distribution
self.is_model = False
elif self.type == "mlp":
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True
else:
raise "Policy model type not defined"

def __call__(self, states):
return self.model(states)

def fixed_distribution(self, states):
"""
Returns the fixed distribution specified by the environment.
Expand Down
91 changes: 91 additions & 0 deletions gflownet/policy/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class CNNPolicy(Policy):
def __init__(self, config, env, device, float_precision, base=None):
self.env = env
super().__init__(
config=config,
env=env,
device=device,
float_precision=float_precision,
base=base,
)
self.is_model = True
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved

def make_cnn(self):
"""
Defines an CNN with no top layer activation
"""
if self.shared_weights and self.base is not None:
layers = list(self.base.model.children())[:-1]
last_layer = nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
)

model = nn.Sequential(*layers, last_layer).to(self.device)
return model

current_channels = 1
conv_module = nn.Sequential()

if len(self.kernel_sizes) != self.n_layers:
raise ValueError(
f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}"
)

for i in range(self.n_layers):
conv_module.add_module(
f"conv_{i}",
nn.Conv2d(
in_channels=current_channels,
out_channels=self.channels[i],
kernel_size=tuple(self.kernel_sizes[i]),
stride=tuple(self.strides[i]),
padding=0,
padding_mode="zeros", # Constant zero padding
),
)
conv_module.add_module(f"relu_{i}", nn.ReLU())
current_channels = self.channels[i]

dummy_input = torch.ones(
(1, 1, self.env.height, self.env.width)
) # (batch_size, channels, height, width)
try:
in_channels = conv_module(dummy_input).numel()
if in_channels >= 500_000: # TODO: this could better be handled
raise RuntimeWarning(
"Input channels for the dense layer are too big, this will increase number of parameters"
)
except RuntimeError as e:
raise RuntimeError(
"Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions."
) from e

model = nn.Sequential(
conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim)
)
return model.to(self.device)

def parse_config(self, config):
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved
if config is None:
config = OmegaConf.create()
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.reload_ckpt = config.get("reload_ckpt", False)
self.n_layers = config.get("n_layers", 3)
self.channels = config.get("channels", [16] * self.n_layers)
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)

def instantiate(self):
self.model = self.make_cnn()

def __call__(self, states):
states = states.unsqueeze(1) # (batch_size, channels, height, width)
return self.model(states)
Empty file added gflownet/policy/gnn.py
Empty file.
78 changes: 78 additions & 0 deletions gflownet/policy/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class MLPPolicy(Policy):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(
config=config,
env=env,
device=device,
float_precision=float_precision,
base=base,
)
self.is_model = True

def make_mlp(self, activation):
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
"""
if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
)
mlp = nn.Sequential(
*(
sum(
[
[nn.Linear(idim, odim)]
+ ([activation] if n < len(layers_dim) - 2 else [])
for n, (idim, odim) in enumerate(
zip(layers_dim, layers_dim[1:])
)
],
[],
)
+ self.tail
)
)
return mlp
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)

def parse_config(self, config):
if config is None:
config = OmegaConf.create()
config.type = "mlp"
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", 128)
self.n_layers = config.get("n_layers", 2)
self.tail = config.get("tail", [])
self.reload_ckpt = config.get("reload_ckpt", False)

def instantiate(self):
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)

def __call__(self, states):
engmubarak48 marked this conversation as resolved.
Show resolved Hide resolved
return self.model(states)
1 change: 0 additions & 1 deletion playground/botorch/mes_exact_deepKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from math import floor

import gpytorch

# import tqdm
import torch
from botorch.test_functions import Hartmann
Expand Down
1 change: 0 additions & 1 deletion playground/botorch/mes_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Branin, Hartmann
Expand Down
Loading
Loading