diff --git a/.github/workflows/push_code_check.yml b/.github/workflows/push_code_check.yml index 199f6d87f..cb367e7d6 100644 --- a/.github/workflows/push_code_check.yml +++ b/.github/workflows/push_code_check.yml @@ -27,6 +27,9 @@ jobs: with: python-version: "3.10" + - name: Display Python version + run: python -V + - name: Install code dependencies run: bash ./prereq_ci.sh pip install --upgrade @@ -36,6 +39,12 @@ jobs: - name: Install Pytest and Isort run: pip install pytest isort + - name: Display pip list + run: pip list + + - name: Display system information + run: uname -a + - name: Validate import format in main source code run: isort --profile black ./gflownet/ --check-only diff --git a/.gitignore b/.gitignore index 64a683331..cc4f3ca2d 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ playground/ !docs/requirements-docs.txt .DS_Store docs/_build/ +logs diff --git a/config/env/tetris.yaml b/config/env/tetris.yaml index 95ae0503c..bd46fc32a 100644 --- a/config/env/tetris.yaml +++ b/config/env/tetris.yaml @@ -11,6 +11,8 @@ height: 20 pieces: ["I", "J", "L", "O", "S", "T", "Z"] # Allowed roations rotations: [0, 90, 180, 270] +# Don't flatten if using CNN +flatten: True # Other config allow_redundant_rotations: False allow_eos_before_full: False diff --git a/config/policy/cnn.yaml b/config/policy/cnn.yaml new file mode 100644 index 000000000..98818bd8e --- /dev/null +++ b/config/policy/cnn.yaml @@ -0,0 +1,16 @@ +_target_: gflownet.policy.cnn.CNNPolicy + +shared: null + +forward: + n_layers: 2 + channels: [16, 32] + kernel_sizes: [[3, 3], [2, 2]] # Each tuple represents (height, width) + strides: [[1, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride) + checkpoint: null + reload_ckpt: False + +backward: + shared_weights: True + checkpoint: null + reload_ckpt: False diff --git a/config/policy/mlp.yaml b/config/policy/mlp.yaml index a9da46af7..b9e2f9a6c 100644 --- a/config/policy/mlp.yaml +++ b/config/policy/mlp.yaml @@ -1,9 +1,8 @@ -_target_: gflownet.policy.base.Policy +_target_: gflownet.policy.mlp.MLPPolicy shared: null forward: - type: mlp n_hid: 128 n_layers: 2 checkpoint: null diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index bf927de2b..4fa98764d 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -75,6 +75,7 @@ def __init__( height: int = 20, pieces: List = ["I", "J", "L", "O", "S", "T", "Z"], rotations: List = [0, 90, 180, 270], + flatten: bool = True, allow_redundant_rotations: bool = False, allow_eos_before_full: bool = False, **kwargs, @@ -87,6 +88,7 @@ def __init__( self.height = height self.pieces = pieces self.rotations = rotations + self.flatten = flatten self.allow_redundant_rotations = allow_redundant_rotations self.allow_eos_before_full = allow_eos_before_full self.max_pieces_per_type = 100 @@ -307,7 +309,9 @@ def states2policy( A tensor containing all the states in the batch. """ states = tint(states, device=self.device, int_type=self.int) - return self.states2proxy(states).flatten(start_dim=1).to(self.float) + if self.flatten: + return self.states2proxy(states).flatten(start_dim=1).to(self.float) + return self.states2proxy(states).to(self.float) def state2readable(self, state: Optional[TensorType["height", "width"]] = None): """ @@ -581,7 +585,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2): linewidth : int The width of the separation between cells, in pixels. """ - board = board.clone().numpy() + board = board.clone().cpu().numpy() height = board.shape[0] * cellsize width = board.shape[1] * cellsize board_img = 128 * np.ones( diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index adc633c73..aa2d91c86 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -278,7 +278,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None): if "corr_prob_traj_rewards" in metrics: lp_metrics["corr_prob_traj_rewards"] = np.corrcoef( - np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt + np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt.cpu().numpy() )[0, 1] if "var_logrewards_logp" in metrics: diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index f84b335fc..29a6c6100 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -1,14 +1,44 @@ -from abc import ABC, abstractmethod +""" +Base Policy class for GFlowNet policy models. +""" + +from typing import Tuple, Union import torch from omegaconf import OmegaConf -from torch import nn +from omegaconf.dictconfig import DictConfig +from gflownet.envs.base import GFlowNetEnv from gflownet.utils.common import set_device, set_float_precision -class ModelBase(ABC): - def __init__(self, config, env, device, float_precision, base=None): +class Policy: + def __init__( + self, + config: Union[dict, DictConfig], + env: GFlowNetEnv, + device: Union[str, torch.device], + float_precision: [int, torch.dtype], + base=None, + ): + """ + Base Policy class for a :class:`GFlowNetAgent`. + + Parameters + ---------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. + env : GFlowNetEnv + The environment used to train the :class:`GFlowNetAgent`, used to extract + needed properties. + device : str or torch.device + The device to be passed to torch tensors. + float_precision : int or torch.dtype + The floating point precision to be passed to torch tensors. + base: Policy (optional) + A base policy to be used as backbone for the backward policy. + """ + config = self._get_config(config) # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) @@ -19,98 +49,60 @@ def __init__(self, config, env, device, float_precision, base=None): self.output_dim = len(self.fixed_output) # Optional base model self.base = base + # Policy type, defaults to uniform + self.type = config.get("type", "uniform") + # Checkpoint, defaults to None + self.checkpoint = config.get("checkpoint", None) + # TODO: This could be done better? We could store this only when using CNN policy. e.g. self.type could be "cnn" + if hasattr(env, "height"): + self.height = env.height + if hasattr(env, "width"): + self.width = env.width + # Instantiate the model + self.model, self.is_model = self.make_model() + + @staticmethod + def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]: + """ + Returns a configuration dictionary, even if the input is None. - self.parse_config(config) - - def parse_config(self, config): - # If config is null, default to uniform + Parameters + ---------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. It may be None, in + which an empty config is created and the defaults will be used. + + Returns + ------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. + """ if config is None: config = OmegaConf.create() - config.type = "uniform" - self.checkpoint = config.get("checkpoint", None) - self.shared_weights = config.get("shared_weights", False) - self.n_hid = config.get("n_hid", None) - self.n_layers = config.get("n_layers", None) - self.tail = config.get("tail", []) - if "type" in config: - self.type = config.type - elif self.shared_weights: - self.type = self.base.type - else: - raise "Policy type must be defined if shared_weights is False" - - @abstractmethod - def instantiate(self): - pass + return config - def __call__(self, states): - return self.model(states) - - def make_mlp(self, activation): + def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]: """ - Defines an MLP with no top layer activation - If share_weight == True, - baseModel (the model with which weights are to be shared) must be provided - Args - ---- - layers_dim : list - Dimensionality of each layer - activation : Activation - Activation function + Instantiates the model of the policy. + + Returns + ------- + model : torch.tensor or torch.nn.Module + A tensor representing the output of the policy or a torch model. + is_model : bool + True if the policy is a model (for example, a neural network) and False if + it is a fixed tensor (for example to make a uniform distribution). """ - if self.shared_weights == True and self.base is not None: - mlp = nn.Sequential( - self.base.model[:-1], - nn.Linear( - self.base.model[-1].in_features, self.base.model[-1].out_features - ), - ) - return mlp - elif self.shared_weights == False: - layers_dim = ( - [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] - ) - mlp = nn.Sequential( - *( - sum( - [ - [nn.Linear(idim, odim)] - + ([activation] if n < len(layers_dim) - 2 else []) - for n, (idim, odim) in enumerate( - zip(layers_dim, layers_dim[1:]) - ) - ], - [], - ) - + self.tail - ) - ) - return mlp - else: - raise ValueError( - "Base Model must be provided when shared_weights is set to True" - ) - - -class Policy(ModelBase): - def __init__(self, config, env, device, float_precision, base=None): - super().__init__(config, env, device, float_precision, base) - - self.instantiate() - - def instantiate(self): if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False + return self.fixed_distribution, False elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False - elif self.type == "mlp": - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) - self.is_model = True + return self.uniform_distribution, False else: raise "Policy model type not defined" + def __call__(self, states): + return self.model(states) + def fixed_distribution(self, states): """ Returns the fixed distribution specified by the environment. diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py new file mode 100644 index 000000000..96f465b20 --- /dev/null +++ b/gflownet/policy/cnn.py @@ -0,0 +1,90 @@ +import torch +from omegaconf import OmegaConf +from torch import nn + +from gflownet.policy.base import Policy + + +class CNNPolicy(Policy): + def __init__(self, **kwargs): + config = self._get_config(kwargs["config"]) + # Shared weights, defaults to False + self.shared_weights = config.get("shared_weights", False) + # Reload checkpoint, defaults to False + self.reload_ckpt = config.get("reload_ckpt", False) + # CNN features: number of layers, number of channels, kernel sizes, strides + self.n_layers = config.get("n_layers", 3) + self.channels = config.get("channels", [16] * self.n_layers) + self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) + self.strides = config.get("strides", [(1, 1)] * self.n_layers) + # Base init + super().__init__(**kwargs) + + def make_model(self): + """ + Instantiates a CNN with no top layer activation. + + Returns + ------- + model : torch.nn.Module + A torch model containing the CNN. + is_model : bool + True because a CNN is a model. + """ + if self.shared_weights and self.base is not None: + layers = list(self.base.model.children())[:-1] + last_layer = nn.Linear( + self.base.model[-1].in_features, self.base.model[-1].out_features + ) + + model = nn.Sequential(*layers, last_layer).to(self.device) + return model, True + + current_channels = 1 + conv_module = nn.Sequential() + + if len(self.kernel_sizes) != self.n_layers: + raise ValueError( + f"Inconsistent dimensions kernel_sizes != n_layers, " + "{len(self.kernel_sizes)} != {self.n_layers}" + ) + + for i in range(self.n_layers): + conv_module.add_module( + f"conv_{i}", + nn.Conv2d( + in_channels=current_channels, + out_channels=self.channels[i], + kernel_size=tuple(self.kernel_sizes[i]), + stride=tuple(self.strides[i]), + padding=0, + padding_mode="zeros", # Constant zero padding + ), + ) + conv_module.add_module(f"relu_{i}", nn.ReLU()) + current_channels = self.channels[i] + + dummy_input = torch.ones( + (1, 1, self.height, self.width) + ) # (batch_size, channels, height, width) + try: + in_channels = conv_module(dummy_input).numel() + if in_channels >= 500_000: # TODO: this could better be handled + raise RuntimeWarning( + "Input channels for the dense layer are too big, this will " + "increase number of parameters" + ) + except RuntimeError as e: + raise RuntimeError( + "Failed during convolution operation. Ensure that the kernel sizes " + "and strides are appropriate for the input dimensions." + ) from e + + model = nn.Sequential( + conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim) + ) + return model.to(self.device), True + + def __call__(self, states): + states = states.unsqueeze(1) # (batch_size, channels, height, width) + return self.model(states) diff --git a/gflownet/policy/gnn.py b/gflownet/policy/gnn.py new file mode 100644 index 000000000..e69de29bb diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py new file mode 100644 index 000000000..aacea2044 --- /dev/null +++ b/gflownet/policy/mlp.py @@ -0,0 +1,75 @@ +from omegaconf import OmegaConf +from torch import nn + +from gflownet.policy.base import Policy + + +class MLPPolicy(Policy): + def __init__(self, **kwargs): + config = self._get_config(kwargs["config"]) + # Shared weights, defaults to False + self.shared_weights = config.get("shared_weights", False) + # Reload checkpoint, defaults to False + self.reload_ckpt = config.get("reload_ckpt", False) + # MLP features: number of layers, number of hidden units, tail, etc. + self.n_layers = config.get("n_layers", 2) + self.n_hid = config.get("n_hid", 128) + self.tail = config.get("tail", []) + # Base init + super().__init__(**kwargs) + + def make_model(self, activation: nn.Module = nn.LeakyReLU()): + """ + Instantiates an MLP with no top layer activation as the policy model. + + If self.shared_weights is True, the base model with which weights are to be + shared must be provided. + + Parameters + ---------- + activation : nn.Module + Activation function of the MLP layers + + Returns + ------- + model : torch.tensor or torch.nn.Module + A torch model containing the MLP. + is_model : bool + True because an MLP is a model. + """ + + if self.shared_weights == True and self.base is not None: + mlp = nn.Sequential( + self.base.model[:-1], + nn.Linear( + self.base.model[-1].in_features, self.base.model[-1].out_features + ), + ) + return mlp, True + elif self.shared_weights == False: + layers_dim = ( + [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] + ) + mlp = nn.Sequential( + *( + sum( + [ + [nn.Linear(idim, odim)] + + ([activation] if n < len(layers_dim) - 2 else []) + for n, (idim, odim) in enumerate( + zip(layers_dim, layers_dim[1:]) + ) + ], + [], + ) + + self.tail + ) + ) + return mlp.to(self.device), True + else: + raise ValueError( + "Base Model must be provided when shared_weights is set to True" + ) + + def __call__(self, states): + return self.model(states) diff --git a/playground/botorch/mes_exact_deepKernel.py b/playground/botorch/mes_exact_deepKernel.py index b77bb2e89..743b68256 100644 --- a/playground/botorch/mes_exact_deepKernel.py +++ b/playground/botorch/mes_exact_deepKernel.py @@ -8,7 +8,6 @@ from math import floor import gpytorch - # import tqdm import torch from botorch.test_functions import Hartmann diff --git a/playground/botorch/mes_gp.py b/playground/botorch/mes_gp.py index b51df0ce6..8afde5dc8 100644 --- a/playground/botorch/mes_gp.py +++ b/playground/botorch/mes_gp.py @@ -6,7 +6,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Branin, Hartmann diff --git a/playground/botorch/mes_gp_debug.py b/playground/botorch/mes_gp_debug.py index 06c5a3ed6..76af6ff00 100644 --- a/playground/botorch/mes_gp_debug.py +++ b/playground/botorch/mes_gp_debug.py @@ -3,7 +3,6 @@ import gpytorch import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -50,8 +49,10 @@ def forward(self, x): from botorch.models.utils import add_output_dim from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal -from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) +from gpytorch.likelihoods.gaussian_likelihood import \ + FixedNoiseGaussianLikelihood class myGPModel(SingleTaskGP): diff --git a/playground/botorch/mes_nn_bao_fix.py b/playground/botorch/mes_nn_bao_fix.py index c4f7de6d0..861268aeb 100644 --- a/playground/botorch/mes_nn_bao_fix.py +++ b/playground/botorch/mes_nn_bao_fix.py @@ -2,7 +2,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -56,8 +55,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal - +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) # from botorch.posteriors. from torch.distributions import Normal diff --git a/playground/botorch/mes_nn_hardcode_gpVal.py b/playground/botorch/mes_nn_hardcode_gpVal.py index 6320d4f05..42dcbb9b4 100644 --- a/playground/botorch/mes_nn_hardcode_gpVal.py +++ b/playground/botorch/mes_nn_hardcode_gpVal.py @@ -2,7 +2,6 @@ import numpy as np import torch - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -57,7 +56,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) class NN_Model(Model): diff --git a/playground/botorch/mes_nn_like_gp.py b/playground/botorch/mes_nn_like_gp.py index d0664a342..0b15c98be 100644 --- a/playground/botorch/mes_nn_like_gp.py +++ b/playground/botorch/mes_nn_like_gp.py @@ -3,15 +3,14 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) from gpytorch.mlls import ExactMarginalLogLikelihood - # from botorch.posteriors. from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py index 2c75fd6a4..1d6626b33 100644 --- a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py +++ b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py @@ -3,15 +3,14 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy - # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.distributions import (MultitaskMultivariateNormal, + MultivariateNormal) from gpytorch.mlls import ExactMarginalLogLikelihood - # from botorch.posteriors. from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_var_deepKernel.py b/playground/botorch/mes_var_deepKernel.py index f712eaaf0..989af46c4 100644 --- a/playground/botorch/mes_var_deepKernel.py +++ b/playground/botorch/mes_var_deepKernel.py @@ -10,7 +10,6 @@ from math import floor import gpytorch - # import tqdm import torch from botorch.test_functions import Hartmann @@ -215,9 +214,7 @@ def posterior( from botorch.acquisition.max_value_entropy_search import ( - qLowerBoundMaxValueEntropy, - qMaxValueEntropy, -) + qLowerBoundMaxValueEntropy, qMaxValueEntropy) proxy = myGPModel(model, train_x, train_y.unsqueeze(-1)) qMES = qLowerBoundMaxValueEntropy(proxy, candidate_set=train_x, use_gumbel=True) diff --git a/prereq_ci.sh b/prereq_ci.sh index 2aec7a616..4ca1afb19 100755 --- a/prereq_ci.sh +++ b/prereq_ci.sh @@ -4,4 +4,4 @@ python -m pip install --ignore-installed six appdirs python -m pip install torch==2.0.1 python -m pip install dgl==1.1.3+cu118 python -m pip install dglgo==0.0.2 -python -m pip install pyg_lib==0.3.1 torch_scatter==2.1.2 torch_sparse==0.6.18 torch_geometric==2.4.0 torch_cluster==1.6.3 torch_spline_conv==1.2.2 -f https://data.pyg.org/whl/torch-2.0.0+cpu.html +python -m pip install pyg_lib==0.3.1 torch_scatter==2.1.2 torch_sparse==0.6.18 torch_geometric==2.4.0 torch_cluster==1.6.3 torch_spline_conv==1.2.2 -f https://data.pyg.org/whl/torch-2.0.0+cpu.html \ No newline at end of file diff --git a/scripts/crystal/eval_crystalgflownet.py b/scripts/crystal/eval_crystalgflownet.py index beae9a130..4eb77ec3c 100644 --- a/scripts/crystal/eval_crystalgflownet.py +++ b/scripts/crystal/eval_crystalgflownet.py @@ -1,6 +1,7 @@ """ Computes evaluation metrics and plots from a pre-trained GFlowNet model. """ + import pickle import shutil import sys @@ -14,6 +15,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent)) from crystalrandom import generate_random_crystals + from gflownet.gflownet import GFlowNetAgent from gflownet.utils.common import load_gflow_net_from_run_path from gflownet.utils.policy import parse_policy_config diff --git a/scripts/crystal/eval_gflownet.py b/scripts/crystal/eval_gflownet.py index e8857e5b5..f90d051bb 100644 --- a/scripts/crystal/eval_gflownet.py +++ b/scripts/crystal/eval_gflownet.py @@ -8,8 +8,8 @@ from argparse import ArgumentParser from pathlib import Path -import pandas as pd import numpy as np +import pandas as pd import torch from tqdm import tqdm @@ -229,30 +229,30 @@ def main(args): env.proxy.is_bandgap = False # Test -# samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]] -# energies = env.proxy(env.states2proxy(samples)) -# df = pd.DataFrame( -# { -# "readable": gflownet.buffer.test["samples"], -# "energies": energies.tolist(), -# } -# ) -# df.to_csv(output_dir / f"val.csv") -# dct = {"x": samples, "energy": energies.tolist()} -# pickle.dump(dct, open(output_dir / f"val.pkl", "wb")) -# -# # Train -# samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]] -# energies = env.proxy(env.states2proxy(samples)) -# df = pd.DataFrame( -# { -# "readable": gflownet.buffer.train["samples"], -# "energies": energies.tolist(), -# } -# ) -# df.to_csv(output_dir / f"train.csv") -# dct = {"x": samples, "energy": energies.tolist()} -# pickle.dump(dct, open(output_dir / f"train.pkl", "wb")) + # samples = [env.readable2state(readable) for readable in gflownet.buffer.test["samples"]] + # energies = env.proxy(env.states2proxy(samples)) + # df = pd.DataFrame( + # { + # "readable": gflownet.buffer.test["samples"], + # "energies": energies.tolist(), + # } + # ) + # df.to_csv(output_dir / f"val.csv") + # dct = {"x": samples, "energy": energies.tolist()} + # pickle.dump(dct, open(output_dir / f"val.pkl", "wb")) + # + # # Train + # samples = [env.readable2state(readable) for readable in gflownet.buffer.train["samples"]] + # energies = env.proxy(env.states2proxy(samples)) + # df = pd.DataFrame( + # { + # "readable": gflownet.buffer.train["samples"], + # "energies": energies.tolist(), + # } + # ) + # df.to_csv(output_dir / f"train.csv") + # dct = {"x": samples, "energy": energies.tolist()} + # pickle.dump(dct, open(output_dir / f"train.pkl", "wb")) if args.n_samples > 0 and args.n_samples <= 1e5 and not args.random_only: print( diff --git a/scripts/crystal/sample_uniform_with_rewards.py b/scripts/crystal/sample_uniform_with_rewards.py index c1d715f4f..e078791d2 100644 --- a/scripts/crystal/sample_uniform_with_rewards.py +++ b/scripts/crystal/sample_uniform_with_rewards.py @@ -3,16 +3,17 @@ should be run with the same config as main.py, e.g. python sample_uniform_with_rewards.py +experiments=crystals/albatross_sg_first logger.do.online=False user=sasha """ + import pickle import sys import hydra import pandas as pd +from crystalrandom import generate_random_crystals_uniform + from gflownet.utils.common import chdir_random_subdir from gflownet.utils.policy import parse_policy_config -from crystalrandom import generate_random_crystals_uniform - @hydra.main(config_path="../../config", config_name="main", version_base="1.1") def main(config): diff --git a/scripts/pyxtal/compatibility_sg_n_atoms.py b/scripts/pyxtal/compatibility_sg_n_atoms.py index b0d08c5bb..a01272d16 100644 --- a/scripts/pyxtal/compatibility_sg_n_atoms.py +++ b/scripts/pyxtal/compatibility_sg_n_atoms.py @@ -3,6 +3,7 @@ combinations spanned by the N_SYMMETRY_GROUPS, N_SPECIES and MAX_N_ATOMS. The results are printed to stdout. """ + import itertools import time diff --git a/scripts/pyxtal/get_n_compatible_for_sg.py b/scripts/pyxtal/get_n_compatible_for_sg.py index 443a72a4a..6e1c0d335 100644 --- a/scripts/pyxtal/get_n_compatible_for_sg.py +++ b/scripts/pyxtal/get_n_compatible_for_sg.py @@ -3,6 +3,7 @@ spanned by the --max_n_atoms and --max_n_species. The results are written to a file in --output_dir. """ + import itertools import time from argparse import ArgumentParser diff --git a/scripts/pyxtal/pyxtal_vs_pymatgen.py b/scripts/pyxtal/pyxtal_vs_pymatgen.py index 62a226ae7..4a1e326be 100644 --- a/scripts/pyxtal/pyxtal_vs_pymatgen.py +++ b/scripts/pyxtal/pyxtal_vs_pymatgen.py @@ -2,6 +2,7 @@ A simple script to determine which space group symbols are different in pyxtal and pymatgen. """ + from argparse import ArgumentParser from pymatgen.symmetry.groups import (