From 98e4d7ea8b5c3fe15d300fef0f38968aeedbff80 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Tue, 18 Jun 2024 14:55:47 -0400 Subject: [PATCH 01/29] initial commit: split the base policy and the architectures --- gflownet/policy/base.py | 60 +++++++++++------------------------------ gflownet/policy/cnn.py | 0 gflownet/policy/gnn.py | 0 gflownet/policy/mlp.py | 60 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 45 deletions(-) create mode 100644 gflownet/policy/cnn.py create mode 100644 gflownet/policy/gnn.py create mode 100644 gflownet/policy/mlp.py diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index f84b335fc..38630fd88 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -46,51 +46,6 @@ def instantiate(self): def __call__(self, states): return self.model(states) - def make_mlp(self, activation): - """ - Defines an MLP with no top layer activation - If share_weight == True, - baseModel (the model with which weights are to be shared) must be provided - Args - ---- - layers_dim : list - Dimensionality of each layer - activation : Activation - Activation function - """ - if self.shared_weights == True and self.base is not None: - mlp = nn.Sequential( - self.base.model[:-1], - nn.Linear( - self.base.model[-1].in_features, self.base.model[-1].out_features - ), - ) - return mlp - elif self.shared_weights == False: - layers_dim = ( - [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] - ) - mlp = nn.Sequential( - *( - sum( - [ - [nn.Linear(idim, odim)] - + ([activation] if n < len(layers_dim) - 2 else []) - for n, (idim, odim) in enumerate( - zip(layers_dim, layers_dim[1:]) - ) - ], - [], - ) - + self.tail - ) - ) - return mlp - else: - raise ValueError( - "Base Model must be provided when shared_weights is set to True" - ) - class Policy(ModelBase): def __init__(self, config, env, device, float_precision, base=None): @@ -111,6 +66,21 @@ def instantiate(self): else: raise "Policy model type not defined" + def instantiate(self): + if self.type == "fixed": + self.model = self.fixed_distribution + self.is_model = False + elif self.type == "uniform": + self.model = self.uniform_distribution + self.is_model = False + elif self.type == "mlp": + from policy.mlp import MLPPolicy + mlp_policy = MLPPolicy(self.config, self.env, self.device, self.float_precision, self.base) + self.model = mlp_policy.model + self.is_model = mlp_policy.is_model + else: + raise "Policy model type not defined" + def fixed_distribution(self, states): """ Returns the fixed distribution specified by the environment. 
diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py new file mode 100644 index 000000000..e69de29bb diff --git a/gflownet/policy/gnn.py b/gflownet/policy/gnn.py new file mode 100644 index 000000000..e69de29bb diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py new file mode 100644 index 000000000..f5bd35b43 --- /dev/null +++ b/gflownet/policy/mlp.py @@ -0,0 +1,60 @@ +# policy_models/mlp_policy.py + +import torch +from torch import nn +from gflownet.policy.base import ModelBase + + +class MLPPolicy(ModelBase): + def __init__(self, config, env, device, float_precision, base=None): + super().__init__(config, env, device, float_precision, base) + self.instantiate() + + def instantiate(self): + self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + self.is_model = True + + def make_mlp(self, activation): + """ + Defines an MLP with no top layer activation + If share_weight == True, + baseModel (the model with which weights are to be shared) must be provided + Args + ---- + layers_dim : list + Dimensionality of each layer + activation : Activation + Activation function + """ + if self.shared_weights == True and self.base is not None: + mlp = nn.Sequential( + self.base.model[:-1], + nn.Linear( + self.base.model[-1].in_features, self.base.model[-1].out_features + ), + ) + return mlp + elif self.shared_weights == False: + layers_dim = ( + [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] + ) + mlp = nn.Sequential( + *( + sum( + [ + [nn.Linear(idim, odim)] + + ([activation] if n < len(layers_dim) - 2 else []) + for n, (idim, odim) in enumerate( + zip(layers_dim, layers_dim[1:]) + ) + ], + [], + ) + + self.tail + ) + ) + return mlp + else: + raise ValueError( + "Base Model must be provided when shared_weights is set to True" + ) \ No newline at end of file From 1388756cd8aa41ccfbcc7aa69c4f9aae5ca5ab4a Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Thu, 20 Jun 2024 20:46:53 -0400 Subject: [PATCH 02/29] ignore git logs --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 64a683331..cc4f3ca2d 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ playground/ !docs/requirements-docs.txt .DS_Store docs/_build/ +logs From 7c8a51724b67f26dc350e5ab88ae819d6fefeb81 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Thu, 20 Jun 2024 20:48:29 -0400 Subject: [PATCH 03/29] refactor base policy class and move models into differrent file --- config/policy/mlp.yaml | 3 +-- gflownet/policy/base.py | 45 +++++------------------------------------ gflownet/policy/mlp.py | 34 +++++++++++++++++++++---------- 3 files changed, 29 insertions(+), 53 deletions(-) diff --git a/config/policy/mlp.yaml b/config/policy/mlp.yaml index a9da46af7..b9e2f9a6c 100644 --- a/config/policy/mlp.yaml +++ b/config/policy/mlp.yaml @@ -1,9 +1,8 @@ -_target_: gflownet.policy.base.Policy +_target_: gflownet.policy.mlp.MLPPolicy shared: null forward: - type: mlp n_hid: 128 n_layers: 2 checkpoint: null diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 38630fd88..0aecc9142 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -2,12 +2,11 @@ import torch from omegaconf import OmegaConf -from torch import nn from gflownet.utils.common import set_device, set_float_precision -class ModelBase(ABC): +class Policy(ABC): def __init__(self, config, env, device, float_precision, base=None): # Device and float precision self.device = set_device(device) @@ -21,6 +20,7 @@ def __init__(self, config, env, device, 
float_precision, base=None): self.base = base self.parse_config(config) + self.instantiate() def parse_config(self, config): # If config is null, default to uniform @@ -28,30 +28,10 @@ def parse_config(self, config): config = OmegaConf.create() config.type = "uniform" self.checkpoint = config.get("checkpoint", None) - self.shared_weights = config.get("shared_weights", False) - self.n_hid = config.get("n_hid", None) - self.n_layers = config.get("n_layers", None) - self.tail = config.get("tail", []) if "type" in config: self.type = config.type - elif self.shared_weights: - self.type = self.base.type else: - raise "Policy type must be defined if shared_weights is False" - - @abstractmethod - def instantiate(self): - pass - - def __call__(self, states): - return self.model(states) - - -class Policy(ModelBase): - def __init__(self, config, env, device, float_precision, base=None): - super().__init__(config, env, device, float_precision, base) - - self.instantiate() + self.type = "uniform" def instantiate(self): if self.type == "fixed": @@ -60,26 +40,11 @@ def instantiate(self): elif self.type == "uniform": self.model = self.uniform_distribution self.is_model = False - elif self.type == "mlp": - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) - self.is_model = True else: raise "Policy model type not defined" - def instantiate(self): - if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False - elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False - elif self.type == "mlp": - from policy.mlp import MLPPolicy - mlp_policy = MLPPolicy(self.config, self.env, self.device, self.float_precision, self.base) - self.model = mlp_policy.model - self.is_model = mlp_policy.is_model - else: - raise "Policy model type not defined" + def __call__(self, states): + return self.model(states) def fixed_distribution(self, states): """ diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index f5bd35b43..ec7237448 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -1,17 +1,16 @@ -# policy_models/mlp_policy.py - -import torch from torch import nn -from gflownet.policy.base import ModelBase +from gflownet.policy.base import Policy -class MLPPolicy(ModelBase): +class MLPPolicy(Policy): def __init__(self, config, env, device, float_precision, base=None): - super().__init__(config, env, device, float_precision, base) - self.instantiate() - - def instantiate(self): - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + super().__init__( + config=config, + env=env, + device=device, + float_precision=float_precision, + base=base, + ) self.is_model = True def make_mlp(self, activation): @@ -57,4 +56,17 @@ def make_mlp(self, activation): else: raise ValueError( "Base Model must be provided when shared_weights is set to True" - ) \ No newline at end of file + ) + + def parse_config(self, config): + self.checkpoint = config.get("checkpoint", False) + self.shared_weights = config.get("shared_weights", False) + self.n_hid = config.get("n_hid", None) + self.n_layers = config.get("n_layers", None) + self.tail = config.get("tail", []) + + def instantiate(self): + self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + + def __call__(self, states): + return self.model(states) From 4ecd865a778dd61cc175e22641447b3e1de8e8d0 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Fri, 21 Jun 2024 01:45:40 -0400 Subject: [PATCH 04/29] handle when config none gracefully --- gflownet/policy/mlp.py | 11 ++++++++--- 1 file 
changed, 8 insertions(+), 3 deletions(-) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index ec7237448..37379b74a 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -1,4 +1,5 @@ from torch import nn +from omegaconf import OmegaConf from gflownet.policy.base import Policy @@ -59,11 +60,15 @@ def make_mlp(self, activation): ) def parse_config(self, config): - self.checkpoint = config.get("checkpoint", False) + if config is None: + config = OmegaConf.create() + config.type = "mlp" + self.checkpoint = config.get("checkpoint", None) self.shared_weights = config.get("shared_weights", False) - self.n_hid = config.get("n_hid", None) - self.n_layers = config.get("n_layers", None) + self.n_hid = config.get("n_hid", 128) + self.n_layers = config.get("n_layers", 2) self.tail = config.get("tail", []) + self.reload_ckpt = config.get("reload_ckpt", False) def instantiate(self): self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) From d50aeeae0e26b127762b2bb9e9b4e70e45f21ed9 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Fri, 21 Jun 2024 01:51:44 -0400 Subject: [PATCH 05/29] black formatting --- gflownet/policy/mlp.py | 2 +- scripts/crystal/eval_crystalgflownet.py | 1 + scripts/crystal/eval_gflownet.py | 48 +++++++++---------- .../crystal/sample_uniform_with_rewards.py | 1 + scripts/pyxtal/compatibility_sg_n_atoms.py | 1 + scripts/pyxtal/get_n_compatible_for_sg.py | 1 + scripts/pyxtal/pyxtal_vs_pymatgen.py | 1 + 7 files changed, 30 insertions(+), 25 deletions(-) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index 37379b74a..583760b67 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -62,7 +62,7 @@ def make_mlp(self, activation): def parse_config(self, config): if config is None: config = OmegaConf.create() - config.type = "mlp" + config.type = "mlp" self.checkpoint = config.get("checkpoint", None) self.shared_weights = config.get("shared_weights", False) self.n_hid = config.get("n_hid", 128) diff --git a/scripts/crystal/eval_crystalgflownet.py b/scripts/crystal/eval_crystalgflownet.py index beae9a130..c4f0f3973 100644 --- a/scripts/crystal/eval_crystalgflownet.py +++ b/scripts/crystal/eval_crystalgflownet.py @@ -1,6 +1,7 @@ """ Computes evaluation metrics and plots from a pre-trained GFlowNet model. 
""" + import pickle import shutil import sys diff --git a/scripts/crystal/eval_gflownet.py b/scripts/crystal/eval_gflownet.py index e8857e5b5..085622d2c 100644 --- a/scripts/crystal/eval_gflownet.py +++ b/scripts/crystal/eval_gflownet.py @@ -229,30 +229,30 @@ def main(args): env.proxy.is_bandgap = False # Test -# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]] -# energies = env.proxy(env.states2proxy(samples)) -# df = pd.DataFrame( -# { -# "readable": gflownet.buffer.test["samples"], -# "energies": energies.tolist(), -# } -# ) -# df.to_csv(output_dir / f"val.csv") -# dct = {"x": samples, "energy": energies.tolist()} -# pickle.dump(dct, open(output_dir / f"val.pkl", "wb")) -# -# # Train -# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]] -# energies = env.proxy(env.states2proxy(samples)) -# df = pd.DataFrame( -# { -# "readable": gflownet.buffer.train["samples"], -# "energies": energies.tolist(), -# } -# ) -# df.to_csv(output_dir / f"train.csv") -# dct = {"x": samples, "energy": energies.tolist()} -# pickle.dump(dct, open(output_dir / f"train.pkl", "wb")) + # samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]] + # energies = env.proxy(env.states2proxy(samples)) + # df = pd.DataFrame( + # { + # "readable": gflownet.buffer.test["samples"], + # "energies": energies.tolist(), + # } + # ) + # df.to_csv(output_dir / f"val.csv") + # dct = {"x": samples, "energy": energies.tolist()} + # pickle.dump(dct, open(output_dir / f"val.pkl", "wb")) + # + # # Train + # samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]] + # energies = env.proxy(env.states2proxy(samples)) + # df = pd.DataFrame( + # { + # "readable": gflownet.buffer.train["samples"], + # "energies": energies.tolist(), + # } + # ) + # df.to_csv(output_dir / f"train.csv") + # dct = {"x": samples, "energy": energies.tolist()} + # pickle.dump(dct, open(output_dir / f"train.pkl", "wb")) if args.n_samples > 0 and args.n_samples <= 1e5 and not args.random_only: print( diff --git a/scripts/crystal/sample_uniform_with_rewards.py b/scripts/crystal/sample_uniform_with_rewards.py index c1d715f4f..02cadd717 100644 --- a/scripts/crystal/sample_uniform_with_rewards.py +++ b/scripts/crystal/sample_uniform_with_rewards.py @@ -3,6 +3,7 @@ should be run with the same config as main.py, e.g. python sample_uniform_with_rewards.py +experiments=crystals/albatross_sg_first logger.do.online=False user=sasha """ + import pickle import sys diff --git a/scripts/pyxtal/compatibility_sg_n_atoms.py b/scripts/pyxtal/compatibility_sg_n_atoms.py index b0d08c5bb..a01272d16 100644 --- a/scripts/pyxtal/compatibility_sg_n_atoms.py +++ b/scripts/pyxtal/compatibility_sg_n_atoms.py @@ -3,6 +3,7 @@ combinations spanned by the N_SYMMETRY_GROUPS, N_SPECIES and MAX_N_ATOMS. The results are printed to stdout. """ + import itertools import time diff --git a/scripts/pyxtal/get_n_compatible_for_sg.py b/scripts/pyxtal/get_n_compatible_for_sg.py index 443a72a4a..6e1c0d335 100644 --- a/scripts/pyxtal/get_n_compatible_for_sg.py +++ b/scripts/pyxtal/get_n_compatible_for_sg.py @@ -3,6 +3,7 @@ spanned by the --max_n_atoms and --max_n_species. The results are written to a file in --output_dir. 
""" + import itertools import time from argparse import ArgumentParser diff --git a/scripts/pyxtal/pyxtal_vs_pymatgen.py b/scripts/pyxtal/pyxtal_vs_pymatgen.py index 62a226ae7..4a1e326be 100644 --- a/scripts/pyxtal/pyxtal_vs_pymatgen.py +++ b/scripts/pyxtal/pyxtal_vs_pymatgen.py @@ -2,6 +2,7 @@ A simple script to determine which space group symbols are different in pyxtal and pymatgen. """ + from argparse import ArgumentParser from pymatgen.symmetry.groups import ( From 1508e4b2d1ef6f482430657fd2c5e164f2321865 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Fri, 21 Jun 2024 01:56:43 -0400 Subject: [PATCH 06/29] further formatting with isort --- gflownet/policy/mlp.py | 3 ++- scripts/crystal/eval_crystalgflownet.py | 1 + scripts/crystal/eval_gflownet.py | 2 +- scripts/crystal/sample_uniform_with_rewards.py | 4 ++-- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index 583760b67..07d6528ec 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -1,5 +1,6 @@ -from torch import nn from omegaconf import OmegaConf +from torch import nn + from gflownet.policy.base import Policy diff --git a/scripts/crystal/eval_crystalgflownet.py b/scripts/crystal/eval_crystalgflownet.py index c4f0f3973..4eb77ec3c 100644 --- a/scripts/crystal/eval_crystalgflownet.py +++ b/scripts/crystal/eval_crystalgflownet.py @@ -15,6 +15,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent)) from crystalrandom import generate_random_crystals + from gflownet.gflownet import GFlowNetAgent from gflownet.utils.common import load_gflow_net_from_run_path from gflownet.utils.policy import parse_policy_config diff --git a/scripts/crystal/eval_gflownet.py b/scripts/crystal/eval_gflownet.py index 085622d2c..f90d051bb 100644 --- a/scripts/crystal/eval_gflownet.py +++ b/scripts/crystal/eval_gflownet.py @@ -8,8 +8,8 @@ from argparse import ArgumentParser from pathlib import Path -import pandas as pd import numpy as np +import pandas as pd import torch from tqdm import tqdm diff --git a/scripts/crystal/sample_uniform_with_rewards.py b/scripts/crystal/sample_uniform_with_rewards.py index 02cadd717..e078791d2 100644 --- a/scripts/crystal/sample_uniform_with_rewards.py +++ b/scripts/crystal/sample_uniform_with_rewards.py @@ -9,11 +9,11 @@ import hydra import pandas as pd +from crystalrandom import generate_random_crystals_uniform + from gflownet.utils.common import chdir_random_subdir from gflownet.utils.policy import parse_policy_config -from crystalrandom import generate_random_crystals_uniform - @hydra.main(config_path="../../config", config_name="main", version_base="1.1") def main(config): From 9cdf18e46dc7033e26cbfdbbf1f55af4a8da9ac6 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Thu, 27 Jun 2024 23:43:50 -0400 Subject: [PATCH 07/29] formatting black + isort --- gflownet/envs/base.py | 3 ++- gflownet/envs/crystals/composition.py | 7 ++----- gflownet/envs/crystals/crystal.py | 3 ++- gflownet/envs/crystals/lattice_parameters.py | 12 +++--------- gflownet/envs/ctorus.py | 3 ++- gflownet/gflownet.py | 10 ++-------- gflownet/utils/batch.py | 12 ++---------- .../utils/crystals/build_lattice_dicts.py | 19 ++++++------------- playground/botorch/mes_exact_deepKernel.py | 1 - playground/botorch/mes_gp.py | 1 - playground/botorch/mes_gp_debug.py | 7 ++++--- playground/botorch/mes_nn_bao_fix.py | 5 ++--- playground/botorch/mes_nn_hardcode_gpVal.py | 4 ++-- playground/botorch/mes_nn_like_gp.py | 5 ++--- .../mes_nn_like_gp_nondiagonalcovar.py | 5 
++--- playground/botorch/mes_var_deepKernel.py | 5 +---- scripts/crystal/eval_gflownet.py | 3 ++- scripts/crystal_eval/eval_CGFN.py | 3 ++- scripts/crystal_eval/metrics.py | 3 ++- scripts/dav_mp20_stats.py | 3 ++- scripts/eval_gflownet.py | 3 ++- scripts/pyxtal/pyxtal_vs_pymatgen.py | 8 ++------ .../gflownet/envs/test_lattice_parameters.py | 17 ++++++----------- tests/gflownet/envs/test_tree.py | 11 ++--------- .../policy/test_multihead_tree_policy.py | 11 ++++------- tests/gflownet/utils/test_batch.py | 13 +++---------- 26 files changed, 61 insertions(+), 116 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 3a1d0a3f9..7591e1dbd 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -16,7 +16,8 @@ from torch.distributions import Categorical from torchtyping import TensorType -from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat +from gflownet.utils.common import (copy, set_device, set_float_precision, + tbool, tfloat) CMAP = mpl.colormaps["cividis"] """ diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 65a67a9d6..9fceefeb9 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -16,11 +16,8 @@ from gflownet.utils.common import tfloat, tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES from gflownet.utils.crystals.pyxtal_cache import ( - get_space_group, - space_group_check_compatible, - space_group_lowest_free_wp_multiplicity, - space_group_wyckoff_gcd, -) + get_space_group, space_group_check_compatible, + space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd) N_ELEMENTS_ORACLE = 94 diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index fb527ff25..322150bef 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -14,7 +14,8 @@ from tqdm import tqdm from gflownet.envs.crystals.composition import Composition -from gflownet.envs.crystals.lattice_parameters import PARAMETER_NAMES, LatticeParameters +from gflownet.envs.crystals.lattice_parameters import (PARAMETER_NAMES, + LatticeParameters) from gflownet.envs.crystals.spacegroup import SpaceGroup from gflownet.envs.stack import Stack from gflownet.utils.crystals.constants import TRICLINIC diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index 8983a0aee..52dac848d 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -15,15 +15,9 @@ from gflownet.envs.cube import ContinuousCube from gflownet.utils.common import copy, tfloat -from gflownet.utils.crystals.constants import ( - CUBIC, - HEXAGONAL, - MONOCLINIC, - ORTHORHOMBIC, - RHOMBOHEDRAL, - TETRAGONAL, - TRICLINIC, -) +from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL, MONOCLINIC, + ORTHORHOMBIC, RHOMBOHEDRAL, + TETRAGONAL, TRICLINIC) LENGTH_PARAMETER_NAMES = ("a", "b", "c") ANGLE_PARAMETER_NAMES = ("alpha", "beta", "gamma") diff --git a/gflownet/envs/ctorus.py b/gflownet/envs/ctorus.py index 261a0bd6b..5cb590907 100644 --- a/gflownet/envs/ctorus.py +++ b/gflownet/envs/ctorus.py @@ -9,7 +9,8 @@ import numpy.typing as npt import pandas as pd import torch -from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises +from torch.distributions import (Categorical, MixtureSameFamily, Uniform, + VonMises) from torchtyping import TensorType from gflownet.envs.htorus import HybridTorus 
diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 81c881076..2fd1b5ba8 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -24,14 +24,8 @@ from gflownet.proxy.base import Proxy from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer -from gflownet.utils.common import ( - bootstrap_samples, - set_device, - set_float_precision, - tbool, - tfloat, - tlong, -) +from gflownet.utils.common import (bootstrap_samples, set_device, + set_float_precision, tbool, tfloat, tlong) class GFlowNetAgent: diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index adbe869f2..6e7d784f9 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -8,16 +8,8 @@ from gflownet.envs.base import GFlowNetEnv from gflownet.proxy.base import Proxy -from gflownet.utils.common import ( - concat_items, - copy, - extend, - set_device, - set_float_precision, - tbool, - tfloat, - tlong, -) +from gflownet.utils.common import (concat_items, copy, extend, set_device, + set_float_precision, tbool, tfloat, tlong) class Batch: diff --git a/gflownet/utils/crystals/build_lattice_dicts.py b/gflownet/utils/crystals/build_lattice_dicts.py index 8bdaac227..d2589ae97 100644 --- a/gflownet/utils/crystals/build_lattice_dicts.py +++ b/gflownet/utils/crystals/build_lattice_dicts.py @@ -9,19 +9,12 @@ import numpy as np import yaml -from lattice_constants import ( - CRYSTAL_CLASSES_WIKIPEDIA, - CRYSTAL_LATTICE_SYSTEMS, - CRYSTAL_SYSTEMS, - POINT_SYMMETRIES, - RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA, -) -from pymatgen.symmetry.groups import ( - PointGroup, - SpaceGroup, - SymmetryGroup, - sg_symbol_from_int_number, -) +from lattice_constants import (CRYSTAL_CLASSES_WIKIPEDIA, + CRYSTAL_LATTICE_SYSTEMS, CRYSTAL_SYSTEMS, + POINT_SYMMETRIES, + RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA) +from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, + sg_symbol_from_int_number) N_SPACE_GROUPS = 230 diff --git a/playground/botorch/mes_exact_deepKernel.py b/playground/botorch/mes_exact_deepKernel.py index b77bb2e89..743b68256 100644 --- a/playground/botorch/mes_exact_deepKernel.py +++ b/playground/botorch/mes_exact_deepKernel.py @@ -8,7 +8,6 @@ from math import floor import gpytorch - # import tqdm import torch from botorch.test_functions import Hartmann diff --git a/playground/botorch/mes_gp.py b/playground/botorch/mes_gp.py index b51df0ce6..8afde5dc8 100644 --- a/playground/botorch/mes_gp.py +++ b/playground/botorch/mes_gp.py @@ -6,7 +6,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Branin, Hartmann diff --git a/playground/botorch/mes_gp_debug.py b/playground/botorch/mes_gp_debug.py index 06c5a3ed6..76af6ff00 100644 --- a/playground/botorch/mes_gp_debug.py +++ b/playground/botorch/mes_gp_debug.py @@ -3,7 +3,6 @@ import gpytorch import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -50,8 +49,10 @@ def forward(self, x): from botorch.models.utils import add_output_dim from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal -from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) +from gpytorch.likelihoods.gaussian_likelihood import \ + FixedNoiseGaussianLikelihood 
class myGPModel(SingleTaskGP): diff --git a/playground/botorch/mes_nn_bao_fix.py b/playground/botorch/mes_nn_bao_fix.py index c4f7de6d0..861268aeb 100644 --- a/playground/botorch/mes_nn_bao_fix.py +++ b/playground/botorch/mes_nn_bao_fix.py @@ -2,7 +2,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -56,8 +55,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal - +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) # from botorch.posteriors. from torch.distributions import Normal diff --git a/playground/botorch/mes_nn_hardcode_gpVal.py b/playground/botorch/mes_nn_hardcode_gpVal.py index 6320d4f05..42dcbb9b4 100644 --- a/playground/botorch/mes_nn_hardcode_gpVal.py +++ b/playground/botorch/mes_nn_hardcode_gpVal.py @@ -2,7 +2,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -57,7 +56,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) class NN_Model(Model): diff --git a/playground/botorch/mes_nn_like_gp.py b/playground/botorch/mes_nn_like_gp.py index d0664a342..0b15c98be 100644 --- a/playground/botorch/mes_nn_like_gp.py +++ b/playground/botorch/mes_nn_like_gp.py @@ -3,15 +3,14 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) from gpytorch.mlls import ExactMarginalLogLikelihood - # from botorch.posteriors. from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py index 2c75fd6a4..1d6626b33 100644 --- a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py +++ b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py @@ -3,15 +3,14 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) from gpytorch.mlls import ExactMarginalLogLikelihood - # from botorch.posteriors. 
from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_var_deepKernel.py b/playground/botorch/mes_var_deepKernel.py index f712eaaf0..989af46c4 100644 --- a/playground/botorch/mes_var_deepKernel.py +++ b/playground/botorch/mes_var_deepKernel.py @@ -10,7 +10,6 @@ from math import floor import gpytorch - # import tqdm import torch from botorch.test_functions import Hartmann @@ -215,9 +214,7 @@ def posterior( from botorch.acquisition.max_value_entropy_search import ( - qLowerBoundMaxValueEntropy, - qMaxValueEntropy, -) + qLowerBoundMaxValueEntropy, qMaxValueEntropy) proxy = myGPModel(model, train_x, train_y.unsqueeze(-1)) qMES = qLowerBoundMaxValueEntropy(proxy, candidate_set=train_x, use_gumbel=True) diff --git a/scripts/crystal/eval_gflownet.py b/scripts/crystal/eval_gflownet.py index f90d051bb..e3018acd9 100644 --- a/scripts/crystal/eval_gflownet.py +++ b/scripts/crystal/eval_gflownet.py @@ -19,7 +19,8 @@ from hydra.utils import instantiate from gflownet.gflownet import GFlowNetAgent -from gflownet.utils.common import load_gflow_net_from_run_path, read_hydra_config +from gflownet.utils.common import (load_gflow_net_from_run_path, + read_hydra_config) from gflownet.utils.policy import parse_policy_config diff --git a/scripts/crystal_eval/eval_CGFN.py b/scripts/crystal_eval/eval_CGFN.py index 6c38a8792..e95968bbc 100644 --- a/scripts/crystal_eval/eval_CGFN.py +++ b/scripts/crystal_eval/eval_CGFN.py @@ -11,7 +11,8 @@ import numpy as np import pandas as pd -from metrics import SG2LP, SMACT, Comp2SG, Eform, Ehull, NumberOfElements, Rediscovery +from metrics import (SG2LP, SMACT, Comp2SG, Eform, Ehull, NumberOfElements, + Rediscovery) # put all metrics to be computed here: diff --git a/scripts/crystal_eval/metrics.py b/scripts/crystal_eval/metrics.py index 245d8b75b..faad9a8df 100644 --- a/scripts/crystal_eval/metrics.py +++ b/scripts/crystal_eval/metrics.py @@ -15,7 +15,8 @@ from pymatgen.core import Composition, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from smact import Element -from smact.data_loader import lookup_element_oxidation_states_custom as oxi_custom +from smact.data_loader import \ + lookup_element_oxidation_states_custom as oxi_custom from smact.screening import pauling_test, smact_filter from tqdm import tqdm diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 9c0a27630..ac0f3a957 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -19,7 +19,8 @@ from collections import Counter -from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders +from external.repos.ActiveLearningMaterials.dave.utils.loaders import \ + make_loaders from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index fcb509d4e..d07ac5fc9 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -15,7 +15,8 @@ sys.path.append(str(Path(__file__).resolve().parent.parent)) from gflownet.gflownet import GFlowNetAgent -from gflownet.utils.common import load_gflow_net_from_run_path, read_hydra_config +from gflownet.utils.common import (load_gflow_net_from_run_path, + read_hydra_config) from gflownet.utils.policy import parse_policy_config diff --git a/scripts/pyxtal/pyxtal_vs_pymatgen.py b/scripts/pyxtal/pyxtal_vs_pymatgen.py index 4a1e326be..7ca544489 100644 --- a/scripts/pyxtal/pyxtal_vs_pymatgen.py 
+++ b/scripts/pyxtal/pyxtal_vs_pymatgen.py @@ -5,12 +5,8 @@ from argparse import ArgumentParser -from pymatgen.symmetry.groups import ( - PointGroup, - SpaceGroup, - SymmetryGroup, - sg_symbol_from_int_number, -) +from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, + sg_symbol_from_int_number) from pyxtal.symmetry import Group N_SYMMETRY_GROUPS = 230 diff --git a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index e2dfb0398..9bbb5b514 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -9,18 +9,13 @@ import pytest import torch -from gflownet.envs.crystals.lattice_parameters import PARAMETER_NAMES, LatticeParameters +from gflownet.envs.crystals.lattice_parameters import (PARAMETER_NAMES, + LatticeParameters) from gflownet.utils.common import tfloat -from gflownet.utils.crystals.constants import ( - CUBIC, - HEXAGONAL, - LATTICE_SYSTEMS, - MONOCLINIC, - ORTHORHOMBIC, - RHOMBOHEDRAL, - TETRAGONAL, - TRICLINIC, -) +from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL, + LATTICE_SYSTEMS, MONOCLINIC, + ORTHORHOMBIC, RHOMBOHEDRAL, + TETRAGONAL, TRICLINIC) N_REPETITIONS = 100 diff --git a/tests/gflownet/envs/test_tree.py b/tests/gflownet/envs/test_tree.py index 5ddf5a384..1acd98217 100644 --- a/tests/gflownet/envs/test_tree.py +++ b/tests/gflownet/envs/test_tree.py @@ -5,15 +5,8 @@ import pytest import torch -from gflownet.envs.tree import ( - ActionType, - Attribute, - NodeType, - Operator, - Stage, - Status, - Tree, -) +from gflownet.envs.tree import (ActionType, Attribute, NodeType, Operator, + Stage, Status, Tree) from gflownet.utils.common import tfloat NAN = float("NaN") diff --git a/tests/gflownet/policy/test_multihead_tree_policy.py b/tests/gflownet/policy/test_multihead_tree_policy.py index a28570a53..5d5448099 100644 --- a/tests/gflownet/policy/test_multihead_tree_policy.py +++ b/tests/gflownet/policy/test_multihead_tree_policy.py @@ -4,13 +4,10 @@ from torch_geometric.data import Batch from gflownet.envs.tree import Attribute, Operator, Tree -from gflownet.policy.multihead_tree import ( - Backbone, - FeatureSelectionHead, - LeafSelectionHead, - OperatorSelectionHead, - ThresholdSelectionHead, -) +from gflownet.policy.multihead_tree import (Backbone, FeatureSelectionHead, + LeafSelectionHead, + OperatorSelectionHead, + ThresholdSelectionHead) N_OBSERVATIONS = 17 N_FEATURES = 5 diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 5c0ca8628..afdf7ce7e 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -8,16 +8,9 @@ from gflownet.proxy.box.corners import Corners from gflownet.proxy.tetris import Tetris as TetrisScore from gflownet.utils.batch import Batch -from gflownet.utils.common import ( - concat_items, - copy, - set_device, - set_float_precision, - tbool, - tfloat, - tint, - tlong, -) +from gflownet.utils.common import (concat_items, copy, set_device, + set_float_precision, tbool, tfloat, tint, + tlong) # Sets the number of repetitions for the tests. Please increase to ~10 after # introducing changes to the Batch class and decrease again to 1 when passed. 
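For reference, with the defaults in config/policy/mlp.yaml (n_hid: 128, n_layers: 2) and shared_weights=False, the make_mlp method moved into gflownet/policy/mlp.py in the earlier patches builds a plain nn.Sequential of Linear/LeakyReLU pairs with no activation on the top layer. Below is a small self-contained sketch of the equivalent construction using only torch; the input and output dimensions are illustrative, since in the repository they come from the environment's policy input/output dimensions.

    # Standalone sketch of the MLP that make_mlp builds for n_hid=128, n_layers=2
    # and shared_weights=False. state_dim and output_dim are illustrative; in the
    # repository they are taken from the environment's policy input/output dims.
    import torch
    from torch import nn

    state_dim, n_hid, n_layers, output_dim = 16, 128, 2, 8
    layers_dim = [state_dim] + [n_hid] * n_layers + [output_dim]

    mlp = nn.Sequential(
        *sum(
            [
                [nn.Linear(idim, odim)]
                + ([nn.LeakyReLU()] if n < len(layers_dim) - 2 else [])
                for n, (idim, odim) in enumerate(zip(layers_dim, layers_dim[1:]))
            ],
            [],
        )
    )

    print(mlp)  # Linear(16->128), LeakyReLU, Linear(128->128), LeakyReLU, Linear(128->8)
    print(mlp(torch.rand(4, state_dim)).shape)  # torch.Size([4, 8])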
From 8af7f82a56a321fc6242ed5110a5b598c7612f20 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Thu, 27 Jun 2024 23:45:13 -0400 Subject: [PATCH 08/29] added flatten flag and device movement handling --- gflownet/envs/tetris.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index bf927de2b..4fa98764d 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -75,6 +75,7 @@ def __init__( height: int = 20, pieces: List = ["I", "J", "L", "O", "S", "T", "Z"], rotations: List = [0, 90, 180, 270], + flatten: bool = True, allow_redundant_rotations: bool = False, allow_eos_before_full: bool = False, **kwargs, @@ -87,6 +88,7 @@ def __init__( self.height = height self.pieces = pieces self.rotations = rotations + self.flatten = flatten self.allow_redundant_rotations = allow_redundant_rotations self.allow_eos_before_full = allow_eos_before_full self.max_pieces_per_type = 100 @@ -307,7 +309,9 @@ def states2policy( A tensor containing all the states in the batch. """ states = tint(states, device=self.device, int_type=self.int) - return self.states2proxy(states).flatten(start_dim=1).to(self.float) + if self.flatten: + return self.states2proxy(states).flatten(start_dim=1).to(self.float) + return self.states2proxy(states).to(self.float) def state2readable(self, state: Optional[TensorType["height", "width"]] = None): """ @@ -581,7 +585,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2): linewidth : int The width of the separation between cells, in pixels. """ - board = board.clone().numpy() + board = board.clone().cpu().numpy() height = board.shape[0] * cellsize width = board.shape[1] * cellsize board_img = 128 * np.ones( From 94741509595d4905b02999eeb63d23150168b743 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Thu, 27 Jun 2024 23:47:03 -0400 Subject: [PATCH 09/29] bug fix: use .cpu() before .numpy() --- gflownet/evaluator/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index adc633c73..aa2d91c86 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -278,7 +278,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): if "corr_prob_traj_rewards" in metrics: lp_metrics["corr_prob_traj_rewards"] = np.corrcoef( - np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt + np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt.cpu().numpy() )[0, 1] if "var_logrewards_logp" in metrics: From 83a49043cca27fe95f69783660ffa4179c7f0932 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Thu, 27 Jun 2024 23:47:54 -0400 Subject: [PATCH 10/29] added cnn policy, and flatten flag should be set to false when using cnn policy --- config/env/tetris.yaml | 2 + config/policy/cnn.yaml | 16 ++++++++ gflownet/policy/cnn.py | 91 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 config/policy/cnn.yaml diff --git a/config/env/tetris.yaml b/config/env/tetris.yaml index 95ae0503c..bd46fc32a 100644 --- a/config/env/tetris.yaml +++ b/config/env/tetris.yaml @@ -11,6 +11,8 @@ height: 20 pieces: ["I", "J", "L", "O", "S", "T", "Z"] # Allowed roations rotations: [0, 90, 180, 270] +# Don't flatten if using CNN +flatten: True # Other config allow_redundant_rotations: False allow_eos_before_full: False diff --git a/config/policy/cnn.yaml b/config/policy/cnn.yaml new file mode 100644 index 000000000..6c7956925 --- /dev/null +++ b/config/policy/cnn.yaml @@ -0,0 +1,16 @@ 
+_target_: gflownet.policy.cnn.CNNPolicy + +shared: null + +forward: + n_layers: 1 + channels: [16] + kernel_sizes: [[3, 3], [2, 2], [1, 1]] # Each tuple represents (height, width) + strides: [[2, 2], [2, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride) + checkpoint: null + reload_ckpt: False + +backward: + shared_weights: True + checkpoint: null + reload_ckpt: False diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index e69de29bb..6e4cedf1d 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -0,0 +1,91 @@ +import torch +from omegaconf import OmegaConf +from torch import nn + +from gflownet.policy.base import Policy + + +class CNNPolicy(Policy): + def __init__(self, config, env, device, float_precision, base=None): + self.env = env + super().__init__( + config=config, + env=env, + device=device, + float_precision=float_precision, + base=base, + ) + self.is_model = True + + def make_cnn(self): + """ + Defines an CNN with no top layer activation + """ + if self.shared_weights and self.base is not None: + layers = list(self.base.model.children())[:-1] + last_layer = nn.Linear( + self.base.model[-1].in_features, self.base.model[-1].out_features + ) + + model = nn.Sequential(*layers, last_layer).to(self.device) + return model + + current_channels = 1 + conv_module = nn.Sequential() + + if len(self.kernel_sizes) != self.n_layers: + raise ValueError( + f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}" + ) + + for i in range(self.n_layers): + conv_module.add_module( + f"conv_{i}", + nn.Conv2d( + in_channels=current_channels, + out_channels=self.channels[i], + kernel_size=tuple(self.kernel_sizes[i]), + stride=tuple(self.strides[i]), + padding=0, + padding_mode="zeros", # Constant zero padding + ), + ) + conv_module.add_module(f"relu_{i}", nn.ReLU()) + current_channels = self.channels[i] + + dummy_input = torch.ones( + (1, 1, self.env.height, self.env.width) + ) # (batch_size, channels, height, width) + try: + in_channels = conv_module(dummy_input).numel() + if in_channels >= 500_000: # TODO: this could better be handled + raise RuntimeWarning( + "Input channels for the dense layer are too big, this will increase number of parameters" + ) + except RuntimeError as e: + raise RuntimeError( + "Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions." 
+ ) from e + + model = nn.Sequential( + conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim) + ) + return model.to(self.device) + + def parse_config(self, config): + if config is None: + config = OmegaConf.create() + self.checkpoint = config.get("checkpoint", None) + self.shared_weights = config.get("shared_weights", False) + self.reload_ckpt = config.get("reload_ckpt", False) + self.n_layers = config.get("n_layers", 3) + self.channels = config.get("channels", [16] * self.n_layers) + self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) + self.strides = config.get("strides", [(1, 1)] * self.n_layers) + + def instantiate(self): + self.model = self.make_cnn() + + def __call__(self, states): + states = states.unsqueeze(1) # (batch_size, channels, height, width) + return self.model(states) From e254cf2b5a2d334382758af7af7a5b981155a118 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Thu, 27 Jun 2024 23:51:53 -0400 Subject: [PATCH 11/29] black formatting --- gflownet/envs/base.py | 3 +-- gflownet/envs/crystals/composition.py | 7 +++++-- gflownet/envs/crystals/crystal.py | 3 +-- gflownet/envs/crystals/lattice_parameters.py | 12 +++++++++--- gflownet/envs/ctorus.py | 3 +-- gflownet/gflownet.py | 10 ++++++++-- gflownet/utils/batch.py | 12 ++++++++++-- .../utils/crystals/build_lattice_dicts.py | 19 +++++++++++++------ scripts/crystal/eval_gflownet.py | 3 +-- scripts/crystal_eval/eval_CGFN.py | 3 +-- scripts/crystal_eval/metrics.py | 3 +-- scripts/dav_mp20_stats.py | 3 +-- scripts/eval_gflownet.py | 3 +-- scripts/pyxtal/pyxtal_vs_pymatgen.py | 8 ++++++-- .../gflownet/envs/test_lattice_parameters.py | 17 +++++++++++------ tests/gflownet/envs/test_tree.py | 11 +++++++++-- .../policy/test_multihead_tree_policy.py | 11 +++++++---- tests/gflownet/utils/test_batch.py | 13 ++++++++++--- 18 files changed, 96 insertions(+), 48 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 7591e1dbd..3a1d0a3f9 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -16,8 +16,7 @@ from torch.distributions import Categorical from torchtyping import TensorType -from gflownet.utils.common import (copy, set_device, set_float_precision, - tbool, tfloat) +from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat CMAP = mpl.colormaps["cividis"] """ diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 9fceefeb9..65a67a9d6 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -16,8 +16,11 @@ from gflownet.utils.common import tfloat, tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES from gflownet.utils.crystals.pyxtal_cache import ( - get_space_group, space_group_check_compatible, - space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd) + get_space_group, + space_group_check_compatible, + space_group_lowest_free_wp_multiplicity, + space_group_wyckoff_gcd, +) N_ELEMENTS_ORACLE = 94 diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index 322150bef..fb527ff25 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -14,8 +14,7 @@ from tqdm import tqdm from gflownet.envs.crystals.composition import Composition -from gflownet.envs.crystals.lattice_parameters import (PARAMETER_NAMES, - LatticeParameters) +from gflownet.envs.crystals.lattice_parameters import PARAMETER_NAMES, LatticeParameters from gflownet.envs.crystals.spacegroup 
import SpaceGroup from gflownet.envs.stack import Stack from gflownet.utils.crystals.constants import TRICLINIC diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index 52dac848d..8983a0aee 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -15,9 +15,15 @@ from gflownet.envs.cube import ContinuousCube from gflownet.utils.common import copy, tfloat -from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL, MONOCLINIC, - ORTHORHOMBIC, RHOMBOHEDRAL, - TETRAGONAL, TRICLINIC) +from gflownet.utils.crystals.constants import ( + CUBIC, + HEXAGONAL, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, +) LENGTH_PARAMETER_NAMES = ("a", "b", "c") ANGLE_PARAMETER_NAMES = ("alpha", "beta", "gamma") diff --git a/gflownet/envs/ctorus.py b/gflownet/envs/ctorus.py index 5cb590907..261a0bd6b 100644 --- a/gflownet/envs/ctorus.py +++ b/gflownet/envs/ctorus.py @@ -9,8 +9,7 @@ import numpy.typing as npt import pandas as pd import torch -from torch.distributions import (Categorical, MixtureSameFamily, Uniform, - VonMises) +from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises from torchtyping import TensorType from gflownet.envs.htorus import HybridTorus diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 2fd1b5ba8..81c881076 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -24,8 +24,14 @@ from gflownet.proxy.base import Proxy from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer -from gflownet.utils.common import (bootstrap_samples, set_device, - set_float_precision, tbool, tfloat, tlong) +from gflownet.utils.common import ( + bootstrap_samples, + set_device, + set_float_precision, + tbool, + tfloat, + tlong, +) class GFlowNetAgent: diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 6e7d784f9..adbe869f2 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -8,8 +8,16 @@ from gflownet.envs.base import GFlowNetEnv from gflownet.proxy.base import Proxy -from gflownet.utils.common import (concat_items, copy, extend, set_device, - set_float_precision, tbool, tfloat, tlong) +from gflownet.utils.common import ( + concat_items, + copy, + extend, + set_device, + set_float_precision, + tbool, + tfloat, + tlong, +) class Batch: diff --git a/gflownet/utils/crystals/build_lattice_dicts.py b/gflownet/utils/crystals/build_lattice_dicts.py index d2589ae97..8bdaac227 100644 --- a/gflownet/utils/crystals/build_lattice_dicts.py +++ b/gflownet/utils/crystals/build_lattice_dicts.py @@ -9,12 +9,19 @@ import numpy as np import yaml -from lattice_constants import (CRYSTAL_CLASSES_WIKIPEDIA, - CRYSTAL_LATTICE_SYSTEMS, CRYSTAL_SYSTEMS, - POINT_SYMMETRIES, - RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA) -from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, - sg_symbol_from_int_number) +from lattice_constants import ( + CRYSTAL_CLASSES_WIKIPEDIA, + CRYSTAL_LATTICE_SYSTEMS, + CRYSTAL_SYSTEMS, + POINT_SYMMETRIES, + RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA, +) +from pymatgen.symmetry.groups import ( + PointGroup, + SpaceGroup, + SymmetryGroup, + sg_symbol_from_int_number, +) N_SPACE_GROUPS = 230 diff --git a/scripts/crystal/eval_gflownet.py b/scripts/crystal/eval_gflownet.py index e3018acd9..f90d051bb 100644 --- a/scripts/crystal/eval_gflownet.py +++ b/scripts/crystal/eval_gflownet.py @@ -19,8 +19,7 @@ from hydra.utils import instantiate from gflownet.gflownet import GFlowNetAgent 
-from gflownet.utils.common import (load_gflow_net_from_run_path, - read_hydra_config) +from gflownet.utils.common import load_gflow_net_from_run_path, read_hydra_config from gflownet.utils.policy import parse_policy_config diff --git a/scripts/crystal_eval/eval_CGFN.py b/scripts/crystal_eval/eval_CGFN.py index e95968bbc..6c38a8792 100644 --- a/scripts/crystal_eval/eval_CGFN.py +++ b/scripts/crystal_eval/eval_CGFN.py @@ -11,8 +11,7 @@ import numpy as np import pandas as pd -from metrics import (SG2LP, SMACT, Comp2SG, Eform, Ehull, NumberOfElements, - Rediscovery) +from metrics import SG2LP, SMACT, Comp2SG, Eform, Ehull, NumberOfElements, Rediscovery # put all metrics to be computed here: diff --git a/scripts/crystal_eval/metrics.py b/scripts/crystal_eval/metrics.py index faad9a8df..245d8b75b 100644 --- a/scripts/crystal_eval/metrics.py +++ b/scripts/crystal_eval/metrics.py @@ -15,8 +15,7 @@ from pymatgen.core import Composition, Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from smact import Element -from smact.data_loader import \ - lookup_element_oxidation_states_custom as oxi_custom +from smact.data_loader import lookup_element_oxidation_states_custom as oxi_custom from smact.screening import pauling_test, smact_filter from tqdm import tqdm diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index ac0f3a957..9c0a27630 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -19,8 +19,7 @@ from collections import Counter -from external.repos.ActiveLearningMaterials.dave.utils.loaders import \ - make_loaders +from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index d07ac5fc9..fcb509d4e 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -15,8 +15,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent)) from gflownet.gflownet import GFlowNetAgent -from gflownet.utils.common import (load_gflow_net_from_run_path, - read_hydra_config) +from gflownet.utils.common import load_gflow_net_from_run_path, read_hydra_config from gflownet.utils.policy import parse_policy_config diff --git a/scripts/pyxtal/pyxtal_vs_pymatgen.py b/scripts/pyxtal/pyxtal_vs_pymatgen.py index 7ca544489..4a1e326be 100644 --- a/scripts/pyxtal/pyxtal_vs_pymatgen.py +++ b/scripts/pyxtal/pyxtal_vs_pymatgen.py @@ -5,8 +5,12 @@ from argparse import ArgumentParser -from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, - sg_symbol_from_int_number) +from pymatgen.symmetry.groups import ( + PointGroup, + SpaceGroup, + SymmetryGroup, + sg_symbol_from_int_number, +) from pyxtal.symmetry import Group N_SYMMETRY_GROUPS = 230 diff --git a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index 9bbb5b514..e2dfb0398 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -9,13 +9,18 @@ import pytest import torch -from gflownet.envs.crystals.lattice_parameters import (PARAMETER_NAMES, - LatticeParameters) +from gflownet.envs.crystals.lattice_parameters import PARAMETER_NAMES, LatticeParameters from gflownet.utils.common import tfloat -from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL, - LATTICE_SYSTEMS, MONOCLINIC, - ORTHORHOMBIC, RHOMBOHEDRAL, - TETRAGONAL, TRICLINIC) +from 
gflownet.utils.crystals.constants import ( + CUBIC, + HEXAGONAL, + LATTICE_SYSTEMS, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, +) N_REPETITIONS = 100 diff --git a/tests/gflownet/envs/test_tree.py b/tests/gflownet/envs/test_tree.py index 1acd98217..5ddf5a384 100644 --- a/tests/gflownet/envs/test_tree.py +++ b/tests/gflownet/envs/test_tree.py @@ -5,8 +5,15 @@ import pytest import torch -from gflownet.envs.tree import (ActionType, Attribute, NodeType, Operator, - Stage, Status, Tree) +from gflownet.envs.tree import ( + ActionType, + Attribute, + NodeType, + Operator, + Stage, + Status, + Tree, +) from gflownet.utils.common import tfloat NAN = float("NaN") diff --git a/tests/gflownet/policy/test_multihead_tree_policy.py b/tests/gflownet/policy/test_multihead_tree_policy.py index 5d5448099..a28570a53 100644 --- a/tests/gflownet/policy/test_multihead_tree_policy.py +++ b/tests/gflownet/policy/test_multihead_tree_policy.py @@ -4,10 +4,13 @@ from torch_geometric.data import Batch from gflownet.envs.tree import Attribute, Operator, Tree -from gflownet.policy.multihead_tree import (Backbone, FeatureSelectionHead, - LeafSelectionHead, - OperatorSelectionHead, - ThresholdSelectionHead) +from gflownet.policy.multihead_tree import ( + Backbone, + FeatureSelectionHead, + LeafSelectionHead, + OperatorSelectionHead, + ThresholdSelectionHead, +) N_OBSERVATIONS = 17 N_FEATURES = 5 diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index afdf7ce7e..5c0ca8628 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -8,9 +8,16 @@ from gflownet.proxy.box.corners import Corners from gflownet.proxy.tetris import Tetris as TetrisScore from gflownet.utils.batch import Batch -from gflownet.utils.common import (concat_items, copy, set_device, - set_float_precision, tbool, tfloat, tint, - tlong) +from gflownet.utils.common import ( + concat_items, + copy, + set_device, + set_float_precision, + tbool, + tfloat, + tint, + tlong, +) # Sets the number of repetitions for the tests. Please increase to ~10 after # introducing changes to the Batch class and decrease again to 1 when passed. 
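The flatten flag added to the Tetris environment above exists so that the CNN policy can receive unflattened (height, width) boards instead of the flat vectors used by the MLP. Below is a minimal sketch of how the two pieces fit together, assuming the constructor signatures shown in the diffs; in practice the policy is instantiated from config/policy/cnn.yaml via Hydra rather than by hand, and the illustrative layer settings are chosen so that n_layers matches the number of kernel sizes and strides, which make_cnn requires.

    # Minimal sketch (not part of the patches): pairing the new CNN policy with
    # an unflattened Tetris environment. Constructor arguments follow the
    # signatures in the diffs above; layer settings are illustrative and must be
    # mutually consistent in length.
    from omegaconf import OmegaConf

    from gflownet.envs.tetris import Tetris
    from gflownet.policy.cnn import CNNPolicy

    # flatten=False keeps states as (height, width) boards for Conv2d
    env = Tetris(width=10, height=20, flatten=False)

    config = OmegaConf.create(
        {
            "n_layers": 2,
            "channels": [16, 32],
            "kernel_sizes": [[3, 3], [2, 2]],
            "strides": [[1, 1], [1, 1]],
        }
    )
    forward_policy = CNNPolicy(
        config=config, env=env, device="cpu", float_precision=32, base=None
    )
    # CNNPolicy.__call__ unsqueezes a channel dimension, so the policy can be
    # fed the (batch, height, width) tensors produced by states2policy.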
From d30c3f2a26bfb1d6147b3a649b0405b7bd4d267b Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Tue, 9 Jul 2024 18:35:08 -0400 Subject: [PATCH 12/29] smaller cnn config like kernel size etc --- config/policy/cnn.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config/policy/cnn.yaml b/config/policy/cnn.yaml index 6c7956925..98818bd8e 100644 --- a/config/policy/cnn.yaml +++ b/config/policy/cnn.yaml @@ -3,10 +3,10 @@ _target_: gflownet.policy.cnn.CNNPolicy shared: null forward: - n_layers: 1 - channels: [16] - kernel_sizes: [[3, 3], [2, 2], [1, 1]] # Each tuple represents (height, width) - strides: [[2, 2], [2, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride) + n_layers: 2 + channels: [16, 32] + kernel_sizes: [[3, 3], [2, 2]] # Each tuple represents (height, width) + strides: [[1, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride) checkpoint: null reload_ckpt: False From 2ac9af595deeb16291360bd7b16927c5acc9c8cb Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Tue, 9 Jul 2024 18:36:08 -0400 Subject: [PATCH 13/29] minor refactor on parse_config --- gflownet/policy/base.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 0aecc9142..eb98e9bdb 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -6,7 +6,7 @@ from gflownet.utils.common import set_device, set_float_precision -class Policy(ABC): +class Policy: def __init__(self, config, env, device, float_precision, base=None): # Device and float precision self.device = set_device(device) @@ -26,12 +26,8 @@ def parse_config(self, config): # If config is null, default to uniform if config is None: config = OmegaConf.create() - config.type = "uniform" + self.type = config.get("type", "uniform") self.checkpoint = config.get("checkpoint", None) - if "type" in config: - self.type = config.type - else: - self.type = "uniform" def instantiate(self): if self.type == "fixed": From 7e02200b5daa5565e138d9627fbc9af0c4a8f84c Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Tue, 9 Jul 2024 18:37:17 -0400 Subject: [PATCH 14/29] move self.is_model to instantiate and add super().parse_config(config) in the parse_config --- gflownet/policy/cnn.py | 3 ++- gflownet/policy/mlp.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index 6e4cedf1d..f8343d88b 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -15,7 +15,6 @@ def __init__(self, config, env, device, float_precision, base=None): float_precision=float_precision, base=base, ) - self.is_model = True def make_cnn(self): """ @@ -73,6 +72,7 @@ def make_cnn(self): return model.to(self.device) def parse_config(self, config): + super().parse_config(config) if config is None: config = OmegaConf.create() self.checkpoint = config.get("checkpoint", None) @@ -85,6 +85,7 @@ def parse_config(self, config): def instantiate(self): self.model = self.make_cnn() + self.is_model = True def __call__(self, states): states = states.unsqueeze(1) # (batch_size, channels, height, width) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index 07d6528ec..b90f6e52f 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -13,7 +13,6 @@ def __init__(self, config, env, device, float_precision, base=None): float_precision=float_precision, base=base, ) - self.is_model = True def make_mlp(self, activation): """ @@ -61,9 +60,9 @@ def 
make_mlp(self, activation): ) def parse_config(self, config): + super().parse_config(config) if config is None: config = OmegaConf.create() - config.type = "mlp" self.checkpoint = config.get("checkpoint", None) self.shared_weights = config.get("shared_weights", False) self.n_hid = config.get("n_hid", 128) @@ -73,6 +72,7 @@ def parse_config(self, config): def instantiate(self): self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + self.is_model = True def __call__(self, states): return self.model(states) From e6a14efa1ef546f39f07b55bbcf68a606d93f304 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 18:01:12 +0200 Subject: [PATCH 15/29] Add docstring and typing to __init__ of policy base. --- gflownet/policy/base.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index eb98e9bdb..e6bc5d6b2 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -1,13 +1,41 @@ -from abc import ABC, abstractmethod +""" +Base Policy class for GFlowNet policy models. +""" + +from typing import Union import torch from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from gflownet.envs.base import GFlowNetEnv from gflownet.utils.common import set_device, set_float_precision class Policy: - def __init__(self, config, env, device, float_precision, base=None): + def __init__( + self, + config: Union[dict, DictConfig], + env: GFlowNetEnv, + device: Union[str, torch.device], + float_precision: [int, torch.dtype], + base=None, + ): + """ + Base Policy class for a :class:`GFlowNetAgent`. + + Parameters + ---------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. + env : GFlowNetEnv + The environment used to train the :class:`GFlowNetAgent`, used to extract + needed properties. + device : str or torch.device + The device to be passed to torch tensors. + float_precision : int or torch.dtype + The floating point precision to be passed to torch tensors. 
+ """ # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) From 57b0b136bcc936e61b60f3b715dec86bd71fdf6f Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 18:04:47 +0200 Subject: [PATCH 16/29] Use kwargs instead of listing parameters explicitly --- gflownet/policy/cnn.py | 10 ++-------- gflownet/policy/mlp.py | 10 ++-------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index f8343d88b..e501ab688 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -6,15 +6,9 @@ class CNNPolicy(Policy): - def __init__(self, config, env, device, float_precision, base=None): + def __init__(self, **kwargs): + super().__init__(**kwargs) self.env = env - super().__init__( - config=config, - env=env, - device=device, - float_precision=float_precision, - base=base, - ) def make_cnn(self): """ diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index b90f6e52f..aa332a5dd 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -5,14 +5,8 @@ class MLPPolicy(Policy): - def __init__(self, config, env, device, float_precision, base=None): - super().__init__( - config=config, - env=env, - device=device, - float_precision=float_precision, - base=base, - ) + def __init__(self, **kwargs): + super().__init__(**kwargs) def make_mlp(self, activation): """ From 178a08e57606334219a0a89cc84d0b980617613d Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 18:39:05 +0200 Subject: [PATCH 17/29] Policy MLP: docstring and typing. --- gflownet/policy/mlp.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index aa332a5dd..a1d3dc32f 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -8,17 +8,17 @@ class MLPPolicy(Policy): def __init__(self, **kwargs): super().__init__(**kwargs) - def make_mlp(self, activation): + def make_mlp(self, activation: nn.Module): """ Defines an MLP with no top layer activation - If share_weight == True, - baseModel (the model with which weights are to be shared) must be provided - Args - ---- - layers_dim : list - Dimensionality of each layer - activation : Activation - Activation function + + If config.share_weights is True, the base model with which weights are to be + shared must be provided. + + Parameters + ---------- + activation : nn.Module + Activation function of the MLP layers """ if self.shared_weights == True and self.base is not None: mlp = nn.Sequential( From ace8a28f3e2176428ac260e5b4519dc7b6f70f87 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 19:05:25 +0200 Subject: [PATCH 18/29] Get rid of parse_config and include its content in __init__ --- gflownet/policy/base.py | 15 +++++++-------- gflownet/policy/cnn.py | 26 +++++++++++++------------- gflownet/policy/mlp.py | 22 ++++++++++------------ 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index e6bc5d6b2..68fb489fb 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -36,6 +36,9 @@ def __init__( float_precision : int or torch.dtype The floating point precision to be passed to torch tensors. 
""" + # If config is None, instantiate an empty config (defaults will be used) + if config is None: + config = OmegaConf.create() # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) @@ -46,16 +49,12 @@ def __init__( self.output_dim = len(self.fixed_output) # Optional base model self.base = base - - self.parse_config(config) - self.instantiate() - - def parse_config(self, config): - # If config is null, default to uniform - if config is None: - config = OmegaConf.create() + # Policy type, defaults to uniform self.type = config.get("type", "uniform") + # Checkpoint, defaults to None self.checkpoint = config.get("checkpoint", None) + # Instantiate the model + self.instantiate() def instantiate(self): if self.type == "fixed": diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index e501ab688..7d80d1270 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -7,8 +7,20 @@ class CNNPolicy(Policy): def __init__(self, **kwargs): - super().__init__(**kwargs) + # Shared weights, defaults to False + self.shared_weights = config.get("shared_weights", False) + # Reload checkpoint, defaults to False + self.reload_ckpt = config.get("reload_ckpt", False) + # CNN features: number of layers, number of channels, kernel sizes, strides + self.n_layers = config.get("n_layers", 3) + self.channels = config.get("channels", [16] * self.n_layers) + self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) + self.strides = config.get("strides", [(1, 1)] * self.n_layers) + # Environment + # TODO: rethink whether storing the whole environment is needed self.env = env + # Base init + super().__init__(**kwargs) def make_cnn(self): """ @@ -65,18 +77,6 @@ def make_cnn(self): ) return model.to(self.device) - def parse_config(self, config): - super().parse_config(config) - if config is None: - config = OmegaConf.create() - self.checkpoint = config.get("checkpoint", None) - self.shared_weights = config.get("shared_weights", False) - self.reload_ckpt = config.get("reload_ckpt", False) - self.n_layers = config.get("n_layers", 3) - self.channels = config.get("channels", [16] * self.n_layers) - self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) - self.strides = config.get("strides", [(1, 1)] * self.n_layers) - def instantiate(self): self.model = self.make_cnn() self.is_model = True diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index a1d3dc32f..bf012e1b6 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -6,13 +6,22 @@ class MLPPolicy(Policy): def __init__(self, **kwargs): + # Shared weights, defaults to False + self.shared_weights = config.get("shared_weights", False) + # Reload checkpoint, defaults to False + self.reload_ckpt = config.get("reload_ckpt", False) + # MLP features: number of layers, number of hidden units, tail, etc. + self.n_layers = config.get("n_layers", 2) + self.n_hid = config.get("n_hid", 128) + self.tail = config.get("tail", []) + # Base init super().__init__(**kwargs) def make_mlp(self, activation: nn.Module): """ Defines an MLP with no top layer activation - If config.share_weights is True, the base model with which weights are to be + If self.shared_weights is True, the base model with which weights are to be shared must be provided. 
Parameters @@ -53,17 +62,6 @@ def make_mlp(self, activation: nn.Module): "Base Model must be provided when shared_weights is set to True" ) - def parse_config(self, config): - super().parse_config(config) - if config is None: - config = OmegaConf.create() - self.checkpoint = config.get("checkpoint", None) - self.shared_weights = config.get("shared_weights", False) - self.n_hid = config.get("n_hid", 128) - self.n_layers = config.get("n_layers", 2) - self.tail = config.get("tail", []) - self.reload_ckpt = config.get("reload_ckpt", False) - def instantiate(self): self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) self.is_model = True From 8e6f03d5f37a7272e7a98a9bec516cfaafd8e5cc Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 19:33:36 +0200 Subject: [PATCH 19/29] Combine instantiate and make_* into a single method make_model() --- gflownet/policy/base.py | 21 +++++++++++++++------ gflownet/policy/cnn.py | 28 +++++++++++++++++----------- gflownet/policy/mlp.py | 21 +++++++++++++-------- 3 files changed, 45 insertions(+), 25 deletions(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 68fb489fb..fb5e86152 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -54,15 +54,24 @@ def __init__( # Checkpoint, defaults to None self.checkpoint = config.get("checkpoint", None) # Instantiate the model - self.instantiate() + self.model, self.is_model = self.make_model() - def instantiate(self): + def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]: + """ + Instantiates the model of the policy. + + Returns + ------- + model : torch.tensor or torch.nn.Module + A tensor representing the output of the policy or a torch model. + is_model : bool + True if the policy is a model (for example, a neural network) and False if + it is a fixed tensor (for example to make a uniform distribution). + """ if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False + return self.fixed_distribution, False elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False + return self.uniform_distribution, False else: raise "Policy model type not defined" diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index 7d80d1270..1a7d191b5 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -22,9 +22,16 @@ def __init__(self, **kwargs): # Base init super().__init__(**kwargs) - def make_cnn(self): + def make_model(self): """ - Defines an CNN with no top layer activation + Instantiates a CNN with no top layer activation. + + Returns + ------- + model : torch.nn.Module + A torch model containing the CNN. + is_model : bool + True because a CNN is a model. 
""" if self.shared_weights and self.base is not None: layers = list(self.base.model.children())[:-1] @@ -33,14 +40,15 @@ def make_cnn(self): ) model = nn.Sequential(*layers, last_layer).to(self.device) - return model + return model, True current_channels = 1 conv_module = nn.Sequential() if len(self.kernel_sizes) != self.n_layers: raise ValueError( - f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}" + f"Inconsistent dimensions kernel_sizes != n_layers, " + "{len(self.kernel_sizes)} != {self.n_layers}" ) for i in range(self.n_layers): @@ -65,21 +73,19 @@ def make_cnn(self): in_channels = conv_module(dummy_input).numel() if in_channels >= 500_000: # TODO: this could better be handled raise RuntimeWarning( - "Input channels for the dense layer are too big, this will increase number of parameters" + "Input channels for the dense layer are too big, this will " + "increase number of parameters" ) except RuntimeError as e: raise RuntimeError( - "Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions." + "Failed during convolution operation. Ensure that the kernel sizes " + "and strides are appropriate for the input dimensions." ) from e model = nn.Sequential( conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim) ) - return model.to(self.device) - - def instantiate(self): - self.model = self.make_cnn() - self.is_model = True + return model.to(self.device), True def __call__(self, states): states = states.unsqueeze(1) # (batch_size, channels, height, width) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index bf012e1b6..534ba3691 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -17,9 +17,9 @@ def __init__(self, **kwargs): # Base init super().__init__(**kwargs) - def make_mlp(self, activation: nn.Module): + def make_model(self, activation: nn.Module = nn.LeakyReLU()): """ - Defines an MLP with no top layer activation + Instantiates an MLP with no top layer activation as the policy model. If self.shared_weights is True, the base model with which weights are to be shared must be provided. @@ -28,7 +28,16 @@ def make_mlp(self, activation: nn.Module): ---------- activation : nn.Module Activation function of the MLP layers + + Returns + ------- + model : torch.tensor or torch.nn.Module + A torch model containing the MLP. + is_model : bool + True because an MLP is a model. 
""" + activation.to(self.device) + if self.shared_weights == True and self.base is not None: mlp = nn.Sequential( self.base.model[:-1], @@ -36,7 +45,7 @@ def make_mlp(self, activation: nn.Module): self.base.model[-1].in_features, self.base.model[-1].out_features ), ) - return mlp + return mlp, True elif self.shared_weights == False: layers_dim = ( [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] @@ -56,15 +65,11 @@ def make_mlp(self, activation: nn.Module): + self.tail ) ) - return mlp + return mlp, True else: raise ValueError( "Base Model must be provided when shared_weights is set to True" ) - def instantiate(self): - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) - self.is_model = True - def __call__(self, states): return self.model(states) From 774c411a69fcc733b94b7cae2e1a29669cd5640d Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 20:02:45 +0200 Subject: [PATCH 20/29] Missing import --- gflownet/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index fb5e86152..81f0d4b3c 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -2,7 +2,7 @@ Base Policy class for GFlowNet policy models. """ -from typing import Union +from typing import Tuple, Union import torch from omegaconf import OmegaConf From 9520315e1609b77cd333ea59a1977a7601f3e16d Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 20:13:56 +0200 Subject: [PATCH 21/29] Fix config issue by implementing _get_config() --- gflownet/policy/base.py | 24 +++++++++++++++++++++--- gflownet/policy/cnn.py | 1 + gflownet/policy/mlp.py | 1 + 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 81f0d4b3c..733293823 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -36,9 +36,7 @@ def __init__( float_precision : int or torch.dtype The floating point precision to be passed to torch tensors. """ - # If config is None, instantiate an empty config (defaults will be used) - if config is None: - config = OmegaConf.create() + config = self._get_config(config) # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) @@ -56,6 +54,26 @@ def __init__( # Instantiate the model self.model, self.is_model = self.make_model() + @staticmethod + def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]: + """ + Returns a configuration dictionary, even if the input is None. + + Parameters + ---------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. It may be None, in + which an empty config is created and the defaults will be used. + + Returns + ------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. + """ + if config is None: + config = OmegaConf.create() + return config + def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]: """ Instantiates the model of the policy. 
diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index 1a7d191b5..52693fc37 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -7,6 +7,7 @@ class CNNPolicy(Policy): def __init__(self, **kwargs): + config = self._get_config(kwargs["config"]) # Shared weights, defaults to False self.shared_weights = config.get("shared_weights", False) # Reload checkpoint, defaults to False diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index 534ba3691..8f4fbc801 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -6,6 +6,7 @@ class MLPPolicy(Policy): def __init__(self, **kwargs): + config = self._get_config(kwargs["config"]) # Shared weights, defaults to False self.shared_weights = config.get("shared_weights", False) # Reload checkpoint, defaults to False From 910a9480f69730b392d1edc9f0483dc4d03936bf Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 20:22:45 +0200 Subject: [PATCH 22/29] Docstring for base argument --- gflownet/policy/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 733293823..5d26e0c14 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -35,6 +35,8 @@ def __init__( The device to be passed to torch tensors. float_precision : int or torch.dtype The floating point precision to be passed to torch tensors. + base: Policy (optional) + A base policy to be used as backbone for the backward policy. """ config = self._get_config(config) # Device and float precision From 9fa9381c60cbee3e3ee886eb325624d76bccba37 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Mon, 23 Sep 2024 14:54:59 -0400 Subject: [PATCH 23/29] remove env from the cnn policy --- gflownet/policy/cnn.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index 52693fc37..96f465b20 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -17,9 +17,6 @@ def __init__(self, **kwargs): self.channels = config.get("channels", [16] * self.n_layers) self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) self.strides = config.get("strides", [(1, 1)] * self.n_layers) - # Environment - # TODO: rethink whether storing the whole environment is needed - self.env = env # Base init super().__init__(**kwargs) @@ -68,7 +65,7 @@ def make_model(self): current_channels = self.channels[i] dummy_input = torch.ones( - (1, 1, self.env.height, self.env.width) + (1, 1, self.height, self.width) ) # (batch_size, channels, height, width) try: in_channels = conv_module(dummy_input).numel() From 2bf438a2e76e3b5ba6f57047e05908fe380f83a8 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Mon, 23 Sep 2024 14:56:20 -0400 Subject: [PATCH 24/29] init the cnn env's height and width in the policy --- gflownet/policy/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 5d26e0c14..053eee51c 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -53,6 +53,11 @@ def __init__( self.type = config.get("type", "uniform") # Checkpoint, defaults to None self.checkpoint = config.get("checkpoint", None) + # TODO: This could be done better? We could store this only when using CNN policy. e.g. 
self.type could be "cnn" + if hasattr(env, 'height'): + self.height = env.height + if hasattr(env, 'width'): + self.width = env.width # Instantiate the model self.model, self.is_model = self.make_model() From 9395e610e01b8705de4ee9436797f52944a9af30 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Mon, 23 Sep 2024 14:57:00 -0400 Subject: [PATCH 25/29] add mlp to device --- gflownet/policy/mlp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index 8f4fbc801..aacea2044 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -37,7 +37,6 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()): is_model : bool True because an MLP is a model. """ - activation.to(self.device) if self.shared_weights == True and self.base is not None: mlp = nn.Sequential( @@ -66,7 +65,7 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()): + self.tail ) ) - return mlp, True + return mlp.to(self.device), True else: raise ValueError( "Base Model must be provided when shared_weights is set to True" From 595db808173c54d53591f2aa7169a86ba8183f66 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Mon, 23 Sep 2024 15:00:49 -0400 Subject: [PATCH 26/29] formatting --- gflownet/policy/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 053eee51c..29a6c6100 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -54,9 +54,9 @@ def __init__( # Checkpoint, defaults to None self.checkpoint = config.get("checkpoint", None) # TODO: This could be done better? We could store this only when using CNN policy. e.g. self.type could be "cnn" - if hasattr(env, 'height'): + if hasattr(env, "height"): self.height = env.height - if hasattr(env, 'width'): + if hasattr(env, "width"): self.width = env.width # Instantiate the model self.model, self.is_model = self.make_model() From 6797bccebc5c483f2b9eee598b9937fb90739a04 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Mon, 23 Sep 2024 19:55:34 -0400 Subject: [PATCH 27/29] debug: add to print environment information when running GitHub Actions workflow --- .github/workflows/push_code_check.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/push_code_check.yml b/.github/workflows/push_code_check.yml index 199f6d87f..cb367e7d6 100644 --- a/.github/workflows/push_code_check.yml +++ b/.github/workflows/push_code_check.yml @@ -27,6 +27,9 @@ jobs: with: python-version: "3.10" + - name: Display Python version + run: python -V + - name: Install code dependencies run: bash ./prereq_ci.sh pip install --upgrade @@ -36,6 +39,12 @@ jobs: - name: Install Pytest and Isort run: pip install pytest isort + - name: Display pip list + run: pip list + + - name: Display system information + run: uname -a + - name: Validate import format in main source code run: isort --profile black ./gflownet/ --check-only From 95258a07146ba03cee435241e3fbdc19a35da45e Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Wed, 25 Sep 2024 13:34:04 -0400 Subject: [PATCH 28/29] Add the specific versions of pymatgen and spglib --- prereq_ci.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/prereq_ci.sh b/prereq_ci.sh index 2aec7a616..0590333a5 100755 --- a/prereq_ci.sh +++ b/prereq_ci.sh @@ -5,3 +5,4 @@ python -m pip install torch==2.0.1 python -m pip install dgl==1.1.3+cu118 python -m pip install dglgo==0.0.2 python -m pip install pyg_lib==0.3.1 torch_scatter==2.1.2 
torch_sparse==0.6.18 torch_geometric==2.4.0 torch_cluster==1.6.3 torch_spline_conv==1.2.2 -f https://data.pyg.org/whl/torch-2.0.0+cpu.html +python -m pip install pymatgen==2024.5.1 spglib==2.4.0 \ No newline at end of file From b7f33d852226187c443073be7d057bfa5adf44d1 Mon Sep 17 00:00:00 2001 From: Jama Hussein Mohamud Date: Wed, 25 Sep 2024 16:51:55 -0400 Subject: [PATCH 29/29] revert version downgrade of pymatgen --- prereq_ci.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/prereq_ci.sh b/prereq_ci.sh index 0590333a5..4ca1afb19 100755 --- a/prereq_ci.sh +++ b/prereq_ci.sh @@ -4,5 +4,4 @@ python -m pip install --ignore-installed six appdirs python -m pip install torch==2.0.1 python -m pip install dgl==1.1.3+cu118 python -m pip install dglgo==0.0.2 -python -m pip install pyg_lib==0.3.1 torch_scatter==2.1.2 torch_sparse==0.6.18 torch_geometric==2.4.0 torch_cluster==1.6.3 torch_spline_conv==1.2.2 -f https://data.pyg.org/whl/torch-2.0.0+cpu.html -python -m pip install pymatgen==2024.5.1 spglib==2.4.0 \ No newline at end of file +python -m pip install pyg_lib==0.3.1 torch_scatter==2.1.2 torch_sparse==0.6.18 torch_geometric==2.4.0 torch_cluster==1.6.3 torch_spline_conv==1.2.2 -f https://data.pyg.org/whl/torch-2.0.0+cpu.html \ No newline at end of file
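[Editor's note, closing the series] The CNN side of the refactor (PATCHES 23 and 24) stops storing the environment on CNNPolicy and instead lets the base Policy copy env.height and env.width when the environment defines them. The sketch below shows how the CNN policy is then constructed, with config keys mirroring config/policy/cnn.yaml from PATCH 12; the Tetris environment and its constructor arguments are illustrative assumptions.

import torch
from omegaconf import OmegaConf

from gflownet.envs.tetris import Tetris    # assumed environment exposing height and width
from gflownet.policy.cnn import CNNPolicy

env = Tetris(width=6, height=10)           # constructor arguments are an assumption
config = OmegaConf.create(
    {
        "n_layers": 2,
        "channels": [16, 32],
        "kernel_sizes": [[3, 3], [2, 2]],  # must have n_layers entries, else make_model() raises ValueError
        "strides": [[1, 1], [1, 1]],
    }
)

policy = CNNPolicy(config=config, env=env, device="cpu", float_precision=32, base=None)
states = torch.ones((8, env.height, env.width))  # dummy batch of board-shaped states
logits = policy(states)                    # __call__ adds the channel dimension before the convolutions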