diff --git a/config/env/aptamers.yaml b/config/env/aptamers.yaml
deleted file mode 100644
index 5de209fa3..000000000
--- a/config/env/aptamers.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-defaults:
-  - base
-
-_target_: gflownet.envs.aptamers.AptamerSeq
-
-id: aptamers
-func: nupack energy
-# Minimum and maximum length for the sequences
-min_seq_length: 30
-max_seq_length: 30
-# Number of letters in alphabet
-n_alphabet: 4
-# Minimum and maximum number of steps in the action space
-min_word_len: 1
-max_word_len: 1
-
diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py
index 76b725e3b..04b8e39b8 100644
--- a/gflownet/envs/alaninedipeptide.py
+++ b/gflownet/envs/alaninedipeptide.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -40,25 +40,34 @@ def sync_conformer_with_state(self, state: List = None):
             self.conformer.set_torsion_angle(ta, state[idx])
         return self.conformer
 
-    def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray:
+    # TODO: are the conversions to oracle relevant?
+    def states2proxy(
+        self, states: Union[List[List], TensorType["batch", "state_dim"]]
+    ) -> npt.NDArray:
         """
-        Prepares a batch of states in torch "GFlowNet format" for the oracle.
-        """
-        device = states.device
-        if device == torch.device("cpu"):
-            np_states = states.numpy()
-        else:
-            np_states = states.cpu().numpy()
-        return np_states[:, :-1]
-
-    def statebatch2proxy(self, states: List[List]) -> npt.NDArray:
-        """
-        Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where
-        each state is a row of length n_dim with an angle in radians. The n_actions
+        Prepares a batch of states in "environment format" for the proxy: each state is
+        a vector of length n_dim where each value is an angle in radians. The n_actions
         item is removed.
+
+        Important: this method returns a numpy array, unlike in most other
+        environments.
+
+        Args
+        ----
+        states : list or tensor
+            A batch of states in environment format, either as a list of states or as a
+            single tensor.
+
+        Returns
+        -------
+        A numpy array containing all the states in the batch.
         """
-        return np.array(states)[:, :-1]
+        if torch.is_tensor(states[0]):
+            return states.cpu().numpy()[:, :-1]
+        else:
+            return np.array(states)[:, :-1]
 
+    # TODO: need to keep?
     def statetorch2oracle(
         self, states: TensorType["batch", "state_dim"]
     ) -> List[Tuple[npt.NDArray, npt.NDArray]]:
@@ -73,6 +82,7 @@ def statetorch2oracle(
         result = self.statebatch2oracle(np_states)
         return result
 
+    # TODO: need to keep?
def statebatch2oracle( self, states: List[List] ) -> List[Tuple[npt.NDArray, npt.NDArray]]: diff --git a/gflownet/envs/aptamers.py b/gflownet/envs/aptamers.py deleted file mode 100644 index 425a2eb91..000000000 --- a/gflownet/envs/aptamers.py +++ /dev/null @@ -1,431 +0,0 @@ -""" -Classes to represent aptamers environments -""" -import itertools -import time -from typing import List - -import numpy as np -import numpy.typing as npt -import pandas as pd - -from gflownet.envs.base import GFlowNetEnv - - -class AptamerSeq(GFlowNetEnv): - """ - Aptamer sequence environment - - Attributes - ---------- - max_seq_length : int - Maximum length of the sequences - - min_seq_length : int - Minimum length of the sequences - - n_alphabet : int - Number of letters in the alphabet - - state : list - Representation of a sequence (state), as a list of length max_seq_length where - each element is the index of a letter in the alphabet, from 0 to (n_alphabet - - 1). - - done : bool - True if the sequence has reached a terminal state (maximum length, or stop - action executed. - - func : str - Name of the reward function - - n_actions : int - Number of actions applied to the sequence - - proxy : lambda - Proxy model - """ - - def __init__( - self, - max_seq_length=42, - min_seq_length=1, - n_alphabet=4, - min_word_len=1, - max_word_len=1, - **kwargs, - ): - super().__init__() - self.source = [] - self.min_seq_length = min_seq_length - self.max_seq_length = max_seq_length - self.n_alphabet = n_alphabet - self.min_word_len = min_word_len - self.max_word_len = max_word_len - self.action_space = self.get_action_space() - self.eos = self.action_space_dim - self.reset() - self.fixed_policy_output = self.get_fixed_policy_output() - self.random_policy_output = self.get_fixed_policy_output() - self.policy_output_dim = len(self.fixed_policy_output) - self.policy_input_dim = len(self.state2policy()) - self.max_traj_len = self.get_max_traj_length() - # Set up proxy - self.setup_proxy() - - def get_action_space(self): - """ - Constructs list with all possible actions - """ - assert self.max_word_len >= self.min_word_len - valid_wordlens = np.arange(self.min_word_len, self.max_word_len + 1) - alphabet = [a for a in range(self.n_alphabet)] - actions = [] - for r in valid_wordlens: - actions_r = [el for el in itertools.product(alphabet, repeat=r)] - actions += actions_r - return actions - - def get_max_traj_length( - self, - ): - return self.max_seq_length / self.min_word_len + 1 - - def reward_arbitrary_i(self, state): - if len(state) > 0: - return (state[-1] + 1) * len(state) - else: - return 0 - - def statebatch2oracle(self, states: List[List]): - """ - Prepares a batch of sequence states for the oracles. - - Args - ---- - states : list of lists - List of sequences. - """ - queries = [s + [-1] * (self.max_seq_length - len(s)) for s in states] - queries = np.array(queries, dtype=int) - if queries.ndim == 1: - queries = queries[np.newaxis, ...] - queries += 1 - if queries.shape[1] == 1: - import ipdb - - ipdb.set_trace() - queries = np.column_stack((queries, np.zeros(queries.shape[0]))) - return queries - - def state2policy(self, state: List = None) -> List: - """ - Transforms the sequence (state) given as argument (or self.state if None) into a - one-hot encoding. The output is a list of length n_alphabet * max_seq_length, - where each n-th successive block of n_alphabet elements is a one-hot encoding of - the letter in the n-th position. 
- - Example: - - Sequence: AATGC - - state: [0, 1, 3, 2] - A, T, G, C - - state2policy(state): [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | T | G | C | - - If max_seq_length > len(s), the last (max_seq_length - len(s)) blocks are all - 0s. - """ - if state is None: - state = self.state.copy() - state_policy = np.zeros(self.n_alphabet * self.max_seq_length, dtype=np.float32) - if len(state) > 0: - state_policy[(np.arange(len(state)) * self.n_alphabet + state)] = 1 - return state_policy - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into the policy model format. The output is a numpy - array of shape [n_states, n_angles * n_dim + 1]. - - See state2policy(). - """ - cols, lengths = zip( - *[ - (state + np.arange(len(state)) * self.n_alphabet, len(state)) - for state in states - ] - ) - rows = np.repeat(np.arange(len(states)), lengths) - state_policy = np.zeros( - (len(states), self.n_alphabet * self.max_seq_length), dtype=np.float32 - ) - state_policy[rows, np.concatenate(cols)] = 1.0 - return state_policy - - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a sequence (state) given as argument - into a a sequence of letter indices. - - Example: - - Sequence: AATGC - - state_policy: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | A | T | G | C | - - state: [0, 0, 1, 3, 2] - A, A, T, G, C - """ - return np.where( - np.reshape(state_policy, (self.max_seq_length, self.n_alphabet)) - )[1].tolist() - - def state2readable(self, state, alphabet={0: "A", 1: "T", 2: "C", 3: "G"}): - """ - Transforms a sequence given as a list of indices into a sequence of letters - according to an alphabet. - """ - return [alphabet[el] for el in state] - - def readable2state(self, letters, alphabet={0: "A", 1: "T", 2: "C", 3: "G"}): - """ - Transforms a sequence given as a list of indices into a sequence of letters - according to an alphabet. - """ - alphabet = {v: k for k, v in alphabet.items()} - return [alphabet[el] for el in letters] - - def reset(self, env_id=None): - """ - Resets the environment. - """ - self.state = [] - self.n_actions = 0 - self.done = False - self.id = env_id - return self - - def get_parents(self, state=None, done=None, action=None): - """ - Determines all parents and actions that lead to sequence state - - Args - ---- - state : list - Representation of a sequence (state), as a list of length max_seq_length - where each element is the index of a letter in the alphabet, from 0 to - (n_alphabet - 1). - - done : bool - Whether the trajectory is done. If None, done is taken from instance. - - action : None - Ignored - - Returns - ------- - parents : list - List of parents in state format - - actions : list - List of actions that lead to state for each parent in parents - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [state], [self.eos] - else: - parents = [] - actions = [] - for idx, a in enumerate(self.action_space): - is_parent = state[-len(a) :] == list(a) - if not isinstance(is_parent, bool): - is_parent = all(is_parent) - if is_parent: - parents.append(state[: -len(a)]) - actions.append(idx) - return parents, actions - - def step(self, action_idx): - """ - Executes step given an action index - - If action_idx is smaller than eos (no stop), add action to next - position. 
- - See: step_daug() - See: step_chain() - - Args - ---- - action_idx : int - Index of action in the action space. a == eos indicates "stop action" - - Returns - ------- - self.state : list - The sequence after executing the action - - valid : bool - False, if the action is not allowed for the current state, e.g. stop at the - root state - """ - # If only possible action is eos, then force eos - if len(self.state) == self.max_seq_length: - self.done = True - self.n_actions += 1 - return self.state, [self.eos], True - # If action is not eos, then perform action - if action_idx != self.eos: - action = self.action_space[action_idx] - state_next = self.state + list(action) - if len(state_next) > self.max_seq_length: - valid = False - else: - self.state = state_next - valid = True - self.n_actions += 1 - return self.state, action_idx, valid - # If action is eos, then perform eos - else: - if len(self.state) < self.min_seq_length: - valid = False - else: - self.done = True - valid = True - self.n_actions += 1 - return self.state, self.eos, valid - - def get_mask_invalid_actions_forward(self, state=None, done=None): - """ - Returns a vector of length the action space + 1: True if action is invalid - given the current state, False otherwise. - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [True for _ in range(self.action_space_dim + 1)] - mask = [False for _ in range(self.action_space_dim + 1)] - seq_length = len(state) - if seq_length < self.min_seq_length: - mask[self.eos] = True - for idx, a in enumerate(self.action_space): - if seq_length + len(a) > self.max_seq_length: - mask[idx] = True - return mask - - def make_train_set( - self, - ntrain, - oracle=None, - seed=168, - output_csv=None, - ): - """ - Constructs a randomly sampled train set. - - Args - ---- - ntest : int - Number of test samples. - - seed : int - Random seed. - - output_csv: str - Optional path to store the test set as CSV. - """ - samples_dict = oracle.initializeDataset( - save=False, returnData=True, customSize=ntrain, custom_seed=seed - ) - energies = samples_dict["energies"] - samples_mat = samples_dict["samples"] - state_letters = oracle.numbers2letters(samples_mat) - state_ints = [ - "".join([str(el) for el in state if el > 0]) for state in samples_mat - ] - if isinstance(energies, dict): - energies.update({"samples": state_letters, "indices": state_ints}) - df_train = pd.DataFrame(energies) - else: - df_train = pd.DataFrame( - {"samples": state_letters, "indices": state_ints, "energies": energies} - ) - if output_csv: - df_train.to_csv(output_csv) - return df_train - - # TODO: improve approximation of uniform - def make_test_set( - self, - path_base_dataset, - ntest, - min_length=0, - max_length=np.inf, - seed=167, - output_csv=None, - ): - """ - Constructs an approximately uniformly distributed (on the score) set, by - selecting samples from a larger base set. - - Args - ---- - path_base_dataset : str - Path to a CSV file containing the base data set. - - ntest : int - Number of test samples. - - seed : int - Random seed. - - dask : bool - If True, use dask to efficiently read a large base file. - - output_csv: str - Optional path to store the test set as CSV. 
- """ - if path_base_dataset is None: - return None, None - times = { - "all": 0.0, - "indices": 0.0, - } - t0_all = time.time() - if seed: - np.random.seed(seed) - df_base = pd.read_csv(path_base_dataset, index_col=0) - df_base = df_base.loc[ - (df_base["samples"].map(len) >= min_length) - & (df_base["samples"].map(len) <= max_length) - ] - energies_base = df_base["energies"].values - min_base = energies_base.min() - max_base = energies_base.max() - distr_unif = np.random.uniform(low=min_base, high=max_base, size=ntest) - # Get minimum distance samples without duplicates - t0_indices = time.time() - idx_samples = [] - for idx in tqdm(range(ntest)): - dist = np.abs(energies_base - distr_unif[idx]) - idx_min = np.argmin(dist) - if idx_min in idx_samples: - idx_sort = np.argsort(dist) - for idx_next in idx_sort: - if idx_next not in idx_samples: - idx_samples.append(idx_next) - break - else: - idx_samples.append(idx_min) - t1_indices = time.time() - times["indices"] += t1_indices - t0_indices - # Make test set - df_test = df_base.iloc[idx_samples] - if output_csv: - df_test.to_csv(output_csv) - t1_all = time.time() - times["all"] += t1_all - t0_all - return df_test, times diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 278490adc..ab68825a3 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -38,7 +38,6 @@ def __init__( energies_stats: List[int] = None, denorm_proxy: bool = False, proxy=None, - oracle=None, proxy_state_format: str = "oracle", fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, @@ -68,17 +67,10 @@ def __init__( self.reward_func = reward_func self.energies_stats = energies_stats self.denorm_proxy = denorm_proxy - # Proxy and oracle + # Proxy self.proxy = proxy self.setup_proxy() - if oracle is None: - self.oracle = self.proxy - else: - self.oracle = oracle - if self.oracle is None or self.oracle.higher_is_better: - self.proxy_factor = 1.0 - else: - self.proxy_factor = -1.0 + self.proxy_factor = -1.0 self.proxy_state_format = proxy_state_format # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = skip_mask_check @@ -100,9 +92,6 @@ def __init__( self.random_policy_output = self.get_policy_output(self.random_distr_params) self.policy_output_dim = len(self.fixed_policy_output) self.policy_input_dim = len(self.state2policy()) - if proxy is not None and self.proxy == self.oracle: - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle @abstractmethod def get_action_space(self): @@ -683,91 +672,75 @@ def get_policy_output( """ return torch.ones(self.action_space_dim, dtype=self.float, device=self.device) - def state2proxy(self, state: List = None): + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a state in "GFlowNet format" for the proxy. + Prepares a batch of states in "environment format" for the proxy. By default, + the batch of states is converted into a tensor with float dtype and returned as + is. Args ---- - state : list - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. 
""" - if state is None: - state = self.state.copy() - return self.statebatch2proxy([state]) + return tfloat(states, device=self.device, float_type=self.float) - def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: + def state2proxy( + self, state: Union[List, TensorType["state_dim"]] = None + ) -> TensorType["state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy. + Prepares a state in "GFlowNet format" for the proxy. By default, states2proxy + is called, which by default will return the state as is. Args ---- state : list A state """ - return np.array(states) - - def statetorch2proxy( - self, states: TensorType["batch_size", "state_dim"] - ) -> TensorType["batch_size", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. - """ - return states + state = self._get_state(state) + return torch.squeeze(self.states2proxy([state]), dim=0) - def state2oracle(self, state: List = None): + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: """ - Prepares a state in "GFlowNet format" for the oracle. + Prepares a batch of states in "environment format" for the policy model: By + default, the batch of states is converted into a tensor with float dtype and + returned as is. Args ---- - state : list - A state - """ - if state is None: - state = self.state.copy() - return state - - def statebatch2oracle(self, states: List[List]): - """ - Prepares a batch of states in "GFlowNet format" for the oracles - """ - return states + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - def statetorch2policy( - self, states: TensorType["batch_size", "state_dim"] - ) -> TensorType["batch_size", "policy_input_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the policy - """ - return states - - def state2policy(self, state=None): - """ - Converts a state into a format suitable for a machine learning model, such as a - one-hot encoding. + Returns + ------- + A tensor containing all the states in the batch. """ - if state is None: - state = self.state - return tfloat(state, float_type=self.float, device=self.device) + return tfloat(states, device=self.device, float_type=self.float) - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch_size", "policy_input_dim"]: + def state2policy( + self, state: Union[List, TensorType["state_dim"]] = None + ) -> TensorType["policy_input_dim"]: """ - Converts a batch of states into a format suitable for a machine learning model, - such as a one-hot encoding. Returns a numpy array. - """ - return self.statetorch2policy( - tfloat(states, float_type=self.float, device=self.device) - ) + Prepares a state in "GFlowNet format" for the policy model. By default, + states2policy is called, which by default will return the state as is. - def policy2state(self, state_policy: List) -> List: - """ - Converts the model (e.g. one-hot encoding) version of a state given as - argument into a state. 
+ Args + ---- + state : list + A state """ - return state_policy + state = self._get_state(state) + return torch.squeeze(self.states2policy([state]), dim=0) def state2readable(self, state=None): """ @@ -797,15 +770,18 @@ def reward(self, state=None, done=None): done = self._get_done(done) if done is False: return tfloat(0.0, float_type=self.float, device=self.device) - return self.proxy2reward(self.proxy(self.state2proxy(state))[0]) + return self.proxy2reward( + self.proxy(torch.unsqueeze(self.state2proxy(state), dim=0))[0] + ) + # TODO: cleanup def reward_batch(self, states: List[List], done=None): """ Computes the rewards of a batch of states, given a list of states and 'dones' """ if done is None: done = np.ones(len(states), dtype=bool) - states_proxy = self.statebatch2proxy(states) + states_proxy = self.states2proxy(states) if isinstance(states_proxy, torch.Tensor): states_proxy = states_proxy[list(done), :] elif isinstance(states_proxy, list): @@ -815,27 +791,11 @@ def reward_batch(self, states: List[List], done=None): rewards[list(done)] = self.proxy2reward(self.proxy(states_proxy)).tolist() return rewards - def reward_torchbatch( - self, - states: TensorType["batch_size", "state_dim"], - done: TensorType["batch_size"] = None, - ): - """ - Computes the rewards of a batch of states in "GFlownet format" - """ - if done is None: - done = torch.ones(states.shape[0], dtype=torch.bool, device=self.device) - states_proxy = self.statetorch2proxy(states[done, :]) - reward = torch.zeros(done.shape[0], dtype=self.float, device=self.device) - if states[done, :].shape[0] > 0: - reward[done] = self.proxy2reward(self.proxy(states_proxy)) - return reward - def proxy2reward(self, proxy_vals): """ - Prepares the output of an oracle for GFlowNet: the inputs proxy_vals is - expected to be a negative value (energy), unless self.denorm_proxy is True. If - the latter, the proxy values are first de-normalized according to the mean and + Prepares the output of a proxy for GFlowNet: the inputs proxy_vals is expected + to be a negative value (energy), unless self.denorm_proxy is True. If the + latter, the proxy values are first de-normalized according to the mean and standard deviation in self.energies_stats. The output of the function is a strictly positive reward - provided self.reward_norm and self.reward_beta are positive - and larger than self.min_reward. @@ -878,7 +838,7 @@ def proxy2reward(self, proxy_vals): def reward2proxy(self, reward): """ Converts a "GFlowNet reward" into a (negative) energy or values as returned by - an oracle. + a proxy. 
""" if self.reward_func == "power": return self.proxy_factor * torch.exp( @@ -1359,7 +1319,7 @@ def top_k_metrics_and_plots( return metrics, figs, fig_names def plot_reward_distribution( - self, states=None, scores=None, ax=None, title=None, oracle=None, **kwargs + self, states=None, scores=None, ax=None, title=None, proxy=None, **kwargs ): if ax is None: fig, ax = plt.subplots() @@ -1368,15 +1328,15 @@ def plot_reward_distribution( standalone = False if title == None: title = "Scores of Sampled States" - if oracle is None: - oracle = self.oracle + if proxy is None: + proxy = self.proxy if scores is None: if isinstance(states[0], torch.Tensor): states = torch.vstack(states).to(self.device, self.float) if isinstance(states, torch.Tensor) == False: states = torch.tensor(states, device=self.device, dtype=self.float) - oracle_states = self.statetorch2oracle(states) - scores = oracle(oracle_states) + states_proxy = self.states2proxy(states) + scores = self.proxy(states_proxy) if isinstance(scores, TensorType): scores = scores.cpu().detach().numpy() ax.hist(scores) diff --git a/gflownet/envs/crystals/ccrystal.py b/gflownet/envs/crystals/ccrystal.py index 0336d8a43..49cfdda8f 100644 --- a/gflownet/envs/crystals/ccrystal.py +++ b/gflownet/envs/crystals/ccrystal.py @@ -857,100 +857,58 @@ def get_logprobs( ) return logprobs - def state2policy(self, state: Optional[List[int]] = None) -> Tensor: - """ - Prepares one state in "GFlowNet format" for the policy. Simply - a concatenation of all crystal components. - """ - state = self._get_state(state) - return self.statetorch2policy( - torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) - )[0] - - def statebatch2policy( - self, states: List[List] + def states2policy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_policy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the policy. Simply + Prepares a batch of states in "environment format" for the policy model: simply a concatenation of all crystal components. - """ - return self.statetorch2policy( - tfloat(states, device=self.device, float_type=self.float) - ) - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_policy_dim"]: - """ - Prepares a tensor batch of states in "GFlowNet format" for the policy. Simply - a concatenation of all crystal components. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ + states = tfloat(states, device=self.device, float_type=self.float) return torch.cat( [ - subenv.statetorch2policy(self._get_states_of_subenv(states, stage)) + subenv.states2policy(self._get_states_of_subenv(states, stage)) for stage, subenv in self.subenvs.items() ], dim=1, ) - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: - """ - Prepares one state in "GFlowNet format" for the oracle. Simply - a concatenation of all crystal components. - """ - state = self._get_state(state) - return self.statetorch2oracle( - torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) - ) - - def statebatch2oracle( - self, states: List[List] + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_oracle_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracle. 
Simply - a concatenation of all crystal components. - """ - return self.statetorch2oracle( - tfloat(states, device=self.device, float_type=self.float) - ) + Prepares a batch of states in "environment format" for a proxy: simply a + concatenation of all crystal components. - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares one state in "GFlowNet format" for the oracle. Simply - a concatenation of all crystal components. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ + states = tfloat(states, device=self.device, float_type=self.float) return torch.cat( [ - subenv.statetorch2oracle(self._get_states_of_subenv(states, stage)) + subenv.states2proxy(self._get_states_of_subenv(states, stage)) for stage, subenv in self.subenvs.items() ], dim=1, ) - def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: - """ - Returns state2oracle(state). - """ - return self.state2oracle(state) - - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Returns statebatch2oracle(states). - """ - return self.statebatch2oracle(states) - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Returns statetorch2oracle(states). - """ - return self.statetorch2oracle(states) - def set_state(self, state: List, done: Optional[bool] = False): super().set_state(state, done) diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 55b8bb186..f109faadb 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -134,10 +134,6 @@ def __init__( self.source = [0 for _ in self.elements] # End-of-sequence action self.eos = (-1, -1) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle super().__init__(**kwargs) def get_action_space(self): @@ -406,67 +402,33 @@ def get_element_mask(min_atoms, max_atoms): return mask - def state2oracle(self, state: List = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a state in "GFlowNet format" for the oracle. The output is a tensor of - length N_ELEMENTS_ORACLE + 1, where the positions of self.elements are filled with - the number of atoms of each element in the state. + Prepares a batch of states in "environment format" for the proxy: The output is + a tensor of dtype long with N_ELEMENTS_ORACLE + 1 columns, where the positions + of self.elements are filled with the number of atoms of each element in the + state. Args ---- - state : list - A state - - Returns - ---- - oracle_state : Tensor - Tensor containing counts of individual elements - """ - if state is None: - state = self.state - return self.statetorch2oracle( - torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) - )[0] - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. 
The output is - a tensor with N_ELEMENTS_ORACLE + 1 columns, where the positions of - self.elements are filled with the number of atoms of each element in the state. - - Args - ---- - states : Tensor - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. Returns - ---- - oracle_states : Tensor + ------- + A tensor containing all the states in the batch. """ - states_float = states.to(self.float) - - states_oracle = torch.zeros( + states = tlong(states, device=self.device) + states_proxy = torch.zeros( (states.shape[0], N_ELEMENTS_ORACLE + 1), device=self.device, - dtype=self.float, + dtype=torch.long, ) - states_oracle[:, tlong(self.elements, device=self.device)] = states_float - return states_oracle - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles. In this case, - it simply converts the states into a torch tensor, with dtype torch.long. - - Args - ---- - state : list - """ - return self.statetorch2oracle(tlong(states, device=self.device)) + states_proxy[:, tlong(self.elements, device=self.device)] = states + return states_proxy def state2readable(self, state=None): """ diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index 0e914ce9a..0acf0a8fa 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -10,6 +10,7 @@ from gflownet.envs.crystals.composition import Composition from gflownet.envs.crystals.lattice_parameters import LatticeParameters from gflownet.envs.crystals.spacegroup import SpaceGroup +from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import TRICLINIC @@ -128,11 +129,6 @@ def __init__( self.lattice_parameters.eos, Stage.LATTICE_PARAMETERS ) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle - super().__init__(**kwargs) def _set_lattice_parameters(self): @@ -247,7 +243,7 @@ def _get_composition_state(self, state: Optional[List[int]] = None) -> List[int] def _get_composition_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[:, self.composition_state_start : self.composition_state_end] def _get_space_group_state(self, state: Optional[List[int]] = None) -> List[int]: @@ -258,7 +254,7 @@ def _get_space_group_state(self, state: Optional[List[int]] = None) -> List[int] def _get_space_group_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[:, self.space_group_state_start : self.space_group_state_end] def _get_lattice_parameters_state( @@ -273,7 +269,7 @@ def _get_lattice_parameters_state( def _get_lattice_parameters_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[ :, self.lattice_parameters_state_start : self.lattice_parameters_state_end ] @@ -466,58 +462,38 @@ def get_parents( return parents, actions - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a list of 
states in "GFlowNet format" for the oracle. Simply - a concatenation of all crystal components. + Prepares a batch of states in "environment format" for the proxy: simply a + concatenation of all crystal components. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ - if state is None: - state = self.state.copy() - - composition_oracle_state = self.composition.state2oracle( - state=self._get_composition_state(state) - ).to(self.device) - space_group_oracle_state = ( - self.space_group.state2oracle(state=self._get_space_group_state(state)) - .unsqueeze(-1) # StateGroup oracle state is a single number - .to(self.device) - ) - lattice_parameters_oracle_state = self.lattice_parameters.state2oracle( - state=self._get_lattice_parameters_state(state) - ).to(self.device) - - return torch.cat( - [ - composition_oracle_state, - space_group_oracle_state, - lattice_parameters_oracle_state, - ] - ) - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - return self.statetorch2oracle( - torch.tensor(states, device=self.device, dtype=torch.long) - ) - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - composition_oracle_states = self.composition.statetorch2oracle( + states = tlong(states, device=self.device) + composition_proxy_states = self.composition.states2proxy( self._get_composition_tensor_states(states) ).to(self.device) - space_group_oracle_states = self.space_group.statetorch2oracle( + space_group_proxy_states = self.space_group.states2proxy( self._get_space_group_tensor_states(states) ).to(self.device) - lattice_parameters_oracle_states = self.lattice_parameters.statetorch2oracle( + lattice_parameters_proxy_states = self.lattice_parameters.states2proxy( self._get_lattice_parameters_tensor_states(states) ).to(self.device) return torch.cat( [ - composition_oracle_states, - space_group_oracle_states, - lattice_parameters_oracle_states, + composition_proxy_states, + space_group_proxy_states, + lattice_parameters_proxy_states, ], dim=1, ) diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index 957a7e229..4901c6404 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -1,7 +1,7 @@ """ Classes to represent crystal environments """ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -9,6 +9,7 @@ from torchtyping import TensorType from gflownet.envs.grid import Grid +from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ( CUBIC, HEXAGONAL, @@ -336,48 +337,28 @@ def get_mask_invalid_actions_forward( return mask - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a list of states in "GFlowNet format" for the oracle. + Prepares a batch of states in "environment format" for the proxy: the + concatenation of the lengths and angles. Args ---- - state : list - A state. - - Returns - ---- - oracle_state : Tensor - Tensor containing lengths and angles converted from the Grid format. 
- """ - if state is None: - state = self.state.copy() - - return Tensor( - [self.cell2length[s] for s in state[:3]] - + [self.cell2angle[s] for s in state[3:]] - ) - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is the lengths and angles. - - Args - ---- - states : Tensor - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. Returns - ---- - oracle_states : Tensor + ------- + A tensor containing all the states in the batch. """ + states = tlong(states, device=self.device) return torch.cat( [ - self.lengths_tensor[states[:, :3].long()], - self.angles_tensor[states[:, 3:].long()], + self.lengths_tensor[states[:, :3]], + self.angles_tensor[states[:, 3:]], ], dim=1, ) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 0faa4ffc3..011bc1e3e 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -14,6 +14,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tlong from gflownet.utils.crystals.pyxtal_cache import space_group_check_compatible CRYSTAL_LATTICE_SYSTEMS = None @@ -130,10 +131,6 @@ def __init__( # Source state: index 0 (empty) for all three properties (crystal-lattice # system index, point symmetry index, space group) self.source = [0 for _ in range(3)] - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle # Base class init super().__init__(**kwargs) @@ -247,65 +244,25 @@ def get_mask_invalid_actions_forward( ] return mask - def state2oracle(self, state: List = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a list of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. + Prepares a batch of states in "environment format" for the proxy: the proxy + format is simply the space group. Args ---- - state : list - A state - - Returns - ---- - oracle_state : Tensor - """ - if state is None: - state = self.state - if state[self.sg_idx] == 0: - raise ValueError( - "The space group must have been set in order to call the oracle" - ) - return torch.tensor(state[self.sg_idx], device=self.device, dtype=torch.long) - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. - - Args - ---- - state : list - A state - - Returns - ---- - oracle_state : Tensor - """ - return self.statetorch2oracle( - torch.tensor(states, device=self.device, dtype=torch.long) - ) - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. - - Args - ---- - state : list - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. Returns - ---- - oracle_state : Tensor + ------- + A tensor containing all the states in the batch. 
""" - return torch.unsqueeze(states[:, self.sg_idx], dim=1).to(torch.long) + states = tlong(states, device=self.device) + return torch.unsqueeze(states[:, self.sg_idx], dim=1) def state2readable(self, state=None): """ diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index ca929ea6d..ee5c52622 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -4,7 +4,7 @@ import itertools import warnings from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -102,6 +102,9 @@ def __init__( self.epsilon = epsilon # Small constant to restrict the interval of (test) sets self.kappa = kappa + # Conversions: only conversions to policy are implemented and the conversion to + # proxy format is the same + self.states2proxy = self.states2policy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -130,105 +133,25 @@ def get_mask_invalid_actions_forward( def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None): pass - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] = None - ) -> TensorType["batch", "oracle_input_dim"]: - """ - Clips the states into [0, 1] and maps them to [-1.0, 1.0] - - Args - ---- - state : list - State - """ - return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Clips the states into [0, 1] and maps them to [-1.0, 1.0] - - Args - ---- - state : list - State - """ - return self.statetorch2oracle( - tfloat(states, device=self.device, float_type=self.float) - ) - - def state2oracle(self, state: List = None) -> List: - """ - Clips the state into [0, 1] and maps it to [-1.0, 1.0] - """ - if state is None: - state = self.state.copy() - return [2.0 * min(max(0.0, s), 1.0) - 1.0 for s in state] - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] = None - ) -> TensorType["batch", "oracle_input_dim"]: - """ - Returns statetorch2oracle(states), that is states mapped to [-1.0, 1.0]. - - Args - ---- - state : list - State - """ - return self.statetorch2oracle(states) - - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Returns statebatch2oracle(states), that is states mapped to [-1.0, 1.0]. - - Args - ---- - state : list - State - """ - return self.statebatch2oracle(states) - - def state2proxy(self, state: List = None) -> List: - """ - Returns state2oracle(state), that is the state mapped to [-1.0, 1.0]. - """ - return self.state2oracle(state) - - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] = None + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch", "policy_input_dim"]: """ - Returns statetorch2proxy(states), that is states mapped to [-1.0, 1.0]. - - Args - ---- - state : list - State - """ - return self.statetorch2proxy(states) - - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Returns statebatch2proxy(states), that is states mapped to [-1.0, 1.0]. 
+ Prepares a batch of states in "environment format" for the policy model: clips + the states into [0, 1] and maps them to [-1.0, 1.0] Args ---- - state : list - State - """ - return self.statebatch2proxy(states) + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - def state2policy(self, state: List = None) -> List: - """ - Returns state2proxy(state), that is the state mapped to [-1.0, 1.0]. + Returns + ------- + A tensor containing all the states in the batch. """ - return self.state2proxy(state) + states = tfloat(states, device=self.device, float_type=self.float) + return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 def state2readable(self, state: List) -> str: """ @@ -1489,7 +1412,7 @@ def sample_from_reward( samples_final = [] max_reward = self.proxy2reward(self.proxy.min) while len(samples_final) < n_samples: - samples_uniform = self.statebatch2proxy( + samples_uniform = self.states2proxy( self.get_uniform_terminating_states(n_samples) ) rewards = self.proxy2reward(self.proxy(samples_uniform)) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 46e3639ef..8f6638eaa 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -2,7 +2,7 @@ Classes to represent a hyper-grid environments """ import itertools -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -12,6 +12,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tfloat, tlong class Grid(GFlowNetEnv): @@ -80,10 +81,7 @@ def __init__( # Proxy format # TODO: assess if really needed if self.proxy_state_format == "ohe": - self.statebatch2proxy = self.statebatch2policy - elif self.proxy_state_format == "oracle": - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle + self.states2proxy = self.states2policy def get_action_space(self): """ @@ -127,127 +125,68 @@ def get_mask_invalid_actions_forward( mask[idx] = True return mask - def state2oracle(self, state: List = None) -> List: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a state in "GFlowNet format" for the oracles: a list of length - n_dim with values in the range [cell_min, cell_max] for each state. - - See: state2policy() - - Args - ---- - state : list - State - """ - if state is None: - state = self.state.copy() - return ( - ( - np.array(self.state2policy(state)).reshape((self.n_dim, self.length)) - * self.cells[None, :] - ) - .sum(axis=1) - .tolist() - ) - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: each state is + Prepares a batch of states in "environment format" for the proxy: each state is a vector of length n_dim with values in the range [cell_min, cell_max]. - See: statetorch2oracle() + See: states2policy() Args ---- - state : list - State - """ - return self.statetorch2oracle( - torch.tensor(states, device=self.device, dtype=self.float) - ) + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. 
- def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: each state is - a vector of length n_dim with values in the range [cell_min, cell_max]. - - See: statetorch2policy() + Returns + ------- + A tensor containing all the states in the batch. """ + states = tfloat(states, device=self.device, float_type=self.float) return ( - self.statetorch2policy(states).reshape( - (len(states), self.n_dim, self.length) + self.states2policy(states).reshape( + (states.shape[0], self.n_dim, self.length) ) * torch.tensor(self.cells[None, :]).to(states.device, self.float) ).sum(axis=2) - def state2policy(self, state: List = None) -> List: + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: """ - Transforms the state given as argument (or self.state if None) into a - one-hot encoding. The output is a list of len length * n_dim, + Prepares a batch of states in "environment format" for the policy model: states + are one-hot encoded. + + The output is a 2D tensor, with the second dimension of size length * n_dim, where each n-th successive block of length elements is a one-hot encoding of the position in the n-th dimension. - Example: - - State, state: [0, 3, 1] (n_dim = 3) - - state2policy(state): [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] (length = 4) - | 0 | 3 | 1 | - """ - if state is None: - state = self.state.copy() - state_policy = np.zeros(self.length * self.n_dim, dtype=np.float32) - state_policy[(np.arange(len(state)) * self.length + state)] = 1 - return state_policy.tolist() - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into a one-hot encoding. The output is a numpy - array of shape [n_states, length * n_dim]. - - See state2policy(). - """ - cols = np.array(states) + np.arange(self.n_dim) * self.length - rows = np.repeat(np.arange(len(states)), self.n_dim) - state_policy = np.zeros( - (len(states), self.length * self.n_dim), dtype=np.float32 - ) - state_policy[rows, cols.flatten()] = 1.0 - return state_policy + Example (n_dim = 3, length = 4): + - state: [0, 3, 1] + - policy format: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] + | 0 | 3 | 1 | - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_output_dim"]: - """ - Transforms a batch of states into a one-hot encoding. The output is a numpy - array of shape [n_states, length * n_dim]. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See state2policy(). + Returns + ------- + A tensor containing all the states in the batch. 
""" - device = states.device - cols = (states + torch.arange(self.n_dim).to(device) * self.length).to(int) - rows = torch.repeat_interleave( - torch.arange(states.shape[0]).to(device), self.n_dim + states = tlong(states, device=self.device) + n_states = states.shape[0] + cols = states + torch.arange(self.n_dim) * self.length + rows = torch.repeat_interleave(torch.arange(n_states), self.n_dim) + states_policy = torch.zeros( + (n_states, self.length * self.n_dim), dtype=self.float, device=self.device ) - state_policy = torch.zeros( - (states.shape[0], self.length * self.n_dim), dtype=states.dtype - ).to(device) - state_policy[rows, cols.flatten()] = 1.0 - return state_policy - - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a state given as argument - into a state (list of the position at each dimension). - - Example: - - state_policy: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] (length = 4, n_dim = 3) - | 0 | 3 | 1 | - - policy2state(state_policy): [0, 3, 1] - """ - return np.where(np.reshape(state_policy, (self.n_dim, self.length)))[1].tolist() + states_policy[rows, cols.flatten()] = 1.0 + return states_policy def readable2state(self, readable, alphabet={}): """ diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 011a74c51..0f7146be4 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -4,7 +4,7 @@ import itertools import re from copy import deepcopy -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -73,10 +73,6 @@ def __init__( self.source = self.source_angles + [0] # End-of-sequence action: (n_dim, 0) self.eos = (self.n_dim, 0) - # TODO: assess if really needed - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy - self.statetorch2oracle = self.statetorch2proxy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -185,84 +181,63 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non ] + [mask[-1]] return mask - def statebatch2proxy( - self, states: List[List] + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where - each state is a row of length n_dim with an angle in radians. The n_actions + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions item is removed. - """ - return torch.tensor(states, device=self.device)[:, :-1] - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. - """ - return states[:, :-1] - def state2policy(self, state: List = None) -> List: - """ - Returns the policy encoding of the state. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See: statebatch2policy() + Returns + ------- + A tensor containing all the states in the batch. 
""" - if state is None: - state = self.state.copy() - return self.statebatch2policy([state]).tolist()[0] + return tfloat(states, device=self.device, float_type=self.float)[:, :-1] - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch", "policy_input_dim"]: """ - Prepares a batch of states in torch "GFlowNet format" for the policy. - - If policy_encoding_dim_per_angle >= 2, then the state (angles) is encoded using + Prepares a batch of states in "environment format" for the policy model: if + policy_encoding_dim_per_angle >= 2, then the state (angles) is encoded using trigonometric components. - """ - if ( - self.policy_encoding_dim_per_angle is not None - and self.policy_encoding_dim_per_angle >= 2 - ): - step = states[:, -1] - code_half_size = self.policy_encoding_dim_per_angle // 2 - int_coeff = ( - torch.arange(1, code_half_size + 1) - .repeat(states.shape[-1] - 1) - .to(states) - ) - encoding = ( - torch.repeat_interleave(states[:, :-1], repeats=code_half_size, dim=1) - * int_coeff - ) - states = torch.cat( - [torch.cos(encoding), torch.sin(encoding), torch.unsqueeze(step, 1)], - dim=1, - ) - return states - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch_size", "policy_input_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the policy. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See: statetorch2policy() + Returns + ------- + A tensor containing all the states in the batch. """ states = tfloat(states, float_type=self.float, device=self.device) - return self.statetorch2policy(states) - - def policy2state(self, state_policy: List) -> List: - """ - Returns the input as is. 
- """ - if self.policy_encoding_dim_per_angle is not None: - raise NotImplementedError( - "Convertion from encoded policy_state to state is not impemented" - ) - return state_policy + if ( + self.policy_encoding_dim_per_angle is None + or self.policy_encoding_dim_per_angle < 2 + ): + return states + step = states[:, -1] + code_half_size = self.policy_encoding_dim_per_angle // 2 + int_coeff = ( + torch.arange(1, code_half_size + 1).repeat(states.shape[-1] - 1).to(states) + ) + encoding = ( + torch.repeat_interleave(states[:, :-1], repeats=code_half_size, dim=1) + * int_coeff + ) + return torch.cat( + [torch.cos(encoding), torch.sin(encoding), torch.unsqueeze(step, 1)], + dim=1, + ) def state2readable(self, state: List) -> str: """ @@ -566,7 +541,9 @@ def sample_from_reward( ), axis=1, ) - rewards = self.reward_torchbatch(samples) + rewards = tfloat( + self.reward_batch(samples), device=self.device, float_type=self.float + ) mask = ( torch.rand(n_samples, dtype=self.float, device=self.device) * (max_reward + epsilon) @@ -606,7 +583,7 @@ def plot_reward_samples( [samples_mesh, torch.ones(samples_mesh.shape[0], 1)], 1 ).to(self.device) rewards = torch2np( - self.proxy2reward(self.proxy(self.statetorch2proxy(states_mesh))) + self.proxy2reward(self.proxy(self.states2proxy(states_mesh))) ) # Init figure fig, ax = plt.subplots() diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index a80d447ff..6c4e8bcb7 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -4,7 +4,7 @@ import itertools import re import warnings -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -99,10 +99,6 @@ def __init__( ) # End-of-sequence action: all -1 self.eos = (-1, -1, -1) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle # Precompute all possible rotations of each piece and the corresponding binary # mask @@ -251,87 +247,53 @@ def get_mask_invalid_actions_forward( mask[-1] = True return mask - def state2oracle( - self, state: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Prepares a state in "GFlowNet format" for the oracles: simply converts non-zero - (non-empty) cells into 1s. - - Args - ---- - state : tensor - """ - if state is None: - state = self.state.clone().detach() - state_oracle = state.clone().detach() - state_oracle[state_oracle != 0] = 1 - return state_oracle - - def statebatch2oracle( - self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_oracle_dim"]: + def states2proxy( + self, + states: Union[ + List[TensorType["height", "width"]], TensorType["height", "width", "batch"] + ], + ) -> TensorType["height", "width", "batch"]: """ - Prepares a batch of states in "GFlowNet format" for the oracles: simply + Prepares a batch of states in "environment format" for a proxy: : simply converts non-zero (non-empty) cells into 1s. Args ---- - state : list - """ - states = torch.stack(states) - states[states != 0] = 1 - return states + states : list of 2D tensors or 3D tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - def statetorch2oracle( - self, states: TensorType["height", "width", "batch"] - ) -> TensorType["height", "width", "batch"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: : simply - converts non-zero (non-empty) cells into 1s. 
+ Returns + ------- + A tensor containing all the states in the batch. """ + states = tint(states, device=self.device, int_type=self.int) states[states != 0] = 1 return states - def state2policy( - self, state: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Prepares a state in "GFlowNet format" for the policy model. - - See: state2oracle() - """ - return self.state2oracle(state).flatten() - - def statebatch2policy( - self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the policy model. - - See statebatch2oracle(). - """ - return self.statebatch2oracle(states).flatten(start_dim=1) - - def statetorch2policy( - self, states: TensorType["height", "width", "batch"] + def states2policy( + self, + states: Union[ + List[TensorType["height", "width"]], TensorType["height", "width", "batch"] + ], ) -> TensorType["height", "width", "batch"]: """ - Prepares a batch of states in "GFlowNet format" for the policy model. + Prepares a batch of states in "environment format" for the policy model. - See statetorch2oracle(). - """ - return self.statetorch2oracle(states).flatten(start_dim=1) + See states2proxy(). - def policy2state( - self, policy: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Returns None to signal that the conversion is not reversible. + Args + ---- + states : list of 2D tensors or 3D tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See: state2oracle() + Returns + ------- + A tensor containing all the states in the batch. """ - return None + states = tint(states, device=self.device, int_type=self.int) + return self.states2proxy(states).flatten(start_dim=1).to(self.float) def state2readable(self, state: Optional[TensorType["height", "width"]] = None): """ diff --git a/gflownet/envs/torus.py b/gflownet/envs/torus.py index 8c0ce712d..54b1183a3 100644 --- a/gflownet/envs/torus.py +++ b/gflownet/envs/torus.py @@ -2,7 +2,7 @@ Classes to represent hyper-torus environments """ import itertools -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -11,6 +11,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tfloat, tlong class Torus(GFlowNetEnv): @@ -64,9 +65,6 @@ def __init__( self.eos = tuple([self.max_increment + 1 for _ in range(self.n_dim)]) # Angle increments in radians self.angle_rad = 2 * np.pi / self.n_angles - # TODO: assess if really needed - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy # Base class init super().__init__(**kwargs) @@ -112,105 +110,67 @@ def get_mask_invalid_actions_forward( mask[-1] = True return mask - def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy: an array where - each state is a row of length n_dim with an angle in radians. The n_actions + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions item is removed. 
- """ - return torch.tensor(states, device=self.device)[:, :-1] * self.angle_rad - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ - return states[:, :-1] * self.angle_rad + return ( + tfloat(states, device=self.device, float_type=self.float)[:, :-1] + * self.angle_rad + ) # TODO: circular encoding as in htorus - def state2policy(self, state=None) -> List: + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: """ - Transforms the angles part of the state given as argument (or self.state if - None) into a one-hot encoding. The output is a list of len n_angles * n_dim + - 1, where each n-th successive block of length elements is a one-hot encoding of - the position in the n-th dimension. + Prepares a batch of states in "environment format" for the policy model: the + policy format is a one-hot encoding of the states. + + Each row is a vector of length n_angles * n_dim + 1, where each n-th successive + block of length elements is a one-hot encoding of the position in the n-th + dimension. Example, n_dim = 2, n_angles = 4: - - State, state: [1, 3, 4] + - state: [1, 3, 4] | a | n | (a = angles, n = n_actions) - - state2policy(state): [0, 1, 0, 0, 0, 0, 0, 1, 4] - | 1 | 3 | 4 | - """ - if state is None: - state = self.state.copy() - # TODO: do we need float32? - # TODO: do we need one-hot? - state_policy = np.zeros(self.n_angles * self.n_dim + 1, dtype=np.float32) - # Angles - state_policy[: self.n_dim * self.n_angles][ - (np.arange(self.n_dim) * self.n_angles + state[: self.n_dim]) - ] = 1 - # Number of actions - state_policy[-1] = state[-1] - return state_policy - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into the policy model format. The output is a numpy - array of shape [n_states, n_angles * n_dim + 1]. - - See state2policy(). - """ - states = np.array(states) - cols = states[:, :-1] + np.arange(self.n_dim) * self.n_angles - rows = np.repeat(np.arange(states.shape[0]), self.n_dim) - state_policy = np.zeros( - (len(states), self.n_angles * self.n_dim + 1), dtype=np.float32 - ) - state_policy[rows, cols.flatten()] = 1.0 - state_policy[:, -1] = states[:, -1] - return state_policy - - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_output_dim"]: - """ - Transforms a batch of torch states into the policy model format. The output is - a tensor of shape [n_states, n_angles * n_dim + 1]. + - policy format: [0, 1, 0, 0, 0, 0, 0, 1, 4] + | 1 | 3 | 4 | + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See state2policy(). + Returns + ------- + A tensor containing all the states in the batch. 
""" - device = states.device - cols = ( - states[:, :-1] + torch.arange(self.n_dim).to(device) * self.n_angles - ).to(int) + states = tlong(states, device=self.device) + cols = states[:, :-1] + torch.arange(self.n_dim).to(self.device) * self.n_angles rows = torch.repeat_interleave( - torch.arange(states.shape[0]).to(device), self.n_dim + torch.arange(states.shape[0]).to(self.device), self.n_dim ) - state_policy = torch.zeros( + states_policy = torch.zeros( (states.shape[0], self.n_angles * self.n_dim + 1) ).to(states) - state_policy[rows, cols.flatten()] = 1.0 - state_policy[:, -1] = states[:, -1] - return state_policy - - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a state given as argument - into a state (list of the position at each dimension). - - Example, n_dim = 2, n_angles = 4: - - state_policy: [0, 1, 0, 0, 0, 0, 0, 1, 4] - | 0 | 3 | 4 | - - policy2state(state_policy): [1, 3, 4] - | a | n | (a = angles, n = n_actions) - """ - mat_angles_policy = np.reshape( - state_policy[: self.n_dim * self.n_angles], (self.n_dim, self.n_angles) - ) - angles = np.where(mat_angles_policy)[1].tolist() - return angles + [int(state_policy[-1])] + states_policy[rows, cols.flatten()] = 1.0 + states_policy[:, -1] = states[:, -1] + return states_policy.to(self.float) def state2readable(self, state: Optional[List] = None) -> str: """ diff --git a/gflownet/envs/tree.py b/gflownet/envs/tree.py index bc92cda13..13684f4a3 100644 --- a/gflownet/envs/tree.py +++ b/gflownet/envs/tree.py @@ -285,13 +285,11 @@ def __init__( # Conversions policy_format = policy_format.lower() if policy_format == "mlp": - self.state2policy = self.state2policy_mlp - self.statetorch2policy = self.statetorch2policy_mlp + self.states2policy = self.states2policy_mlp elif policy_format != "gnn": raise ValueError( f"Unrecognized policy_format = {policy_format}, expected either 'mlp' or 'gnn'." ) - self.statetorch2oracle = self.statetorch2policy super().__init__( fixed_distr_params=fixed_distr_params, @@ -830,24 +828,19 @@ def get_logprobs( is_backward, ) - def state2policy_mlp( - self, state: Optional[TensorType["state_dim"]] = None - ) -> TensorType["policy_input_dim"]: - """ - Prepares a state in "GFlowNet format" for the policy model. - """ - if state is None: - state = self.state.clone().detach() - return self.statetorch2policy_mlp(state.unsqueeze(0))[0] - - def statetorch2policy_mlp( - self, states: TensorType["batch_size", "state_dim"] + def states2policy_mlp( + self, + states: Union[ + List[TensorType["state_dim"]], TensorType["batch_size", "state_dim"] + ], ) -> TensorType["batch_size", "policy_input_dim"]: """ Prepares a batch of states in torch "GFlowNet format" for an MLP policy model. It replaces the NaNs by -2s, removes the activity attribute, and explicitly appends the attribute vector of the active node (if present). """ + if isinstance(states, list): + states = torch.stack(states) rows, cols = torch.where(states[:, :-1, Attribute.ACTIVE] == Status.ACTIVE) active_features = torch.full((states.shape[0], 1, 4), -2.0) active_features[rows] = states[rows, cols, : Attribute.ACTIVE].unsqueeze(1) @@ -855,28 +848,6 @@ def statetorch2policy_mlp( states = torch.cat([states[:, :, : Attribute.ACTIVE], active_features], dim=1) return states.flatten(start_dim=1) - def policy2state( - self, policy: Optional[TensorType["policy_input_dim"]] = None - ) -> None: - """ - Returns None to signal that the conversion is not reversible. 
- """ - return None - - def statebatch2proxy( - self, states: List[TensorType["state_dim"]] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: simply - stacks the list of tensors and calls self.statetorch2proxy. - - Args - ---- - state : list - """ - states = torch.stack(states) - return self.statetorch2proxy(states) - def _attributes_to_readable(self, attributes: List) -> str: # Node type if attributes[Attribute.TYPE] == NodeType.CONDITION: diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 06777a851..fdbea6e76 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -325,7 +325,7 @@ def sample_actions( # Check for at least one non-random action if idx_norandom.sum() > 0: states_policy = tfloat( - self.env.statebatch2policy( + self.env.states2policy( [s for s, do in zip(states, idx_norandom) if do] ), device=self.device, @@ -1036,8 +1036,8 @@ def test(self, **plot_kwargs): assert batch.is_valid() x_sampled = batch.get_terminating_states() # TODO make it work with conditional env - x_sampled = torch2np(self.env.statebatch2proxy(x_sampled)) - x_tt = torch2np(self.env.statebatch2proxy(x_tt)) + x_sampled = torch2np(self.env.states2proxy(x_sampled)) + x_tt = torch2np(self.env.states2proxy(x_tt)) kde_pred = self.env.fit_kde( x_sampled, kernel=self.logger.test.kde.kernel, @@ -1051,7 +1051,7 @@ def test(self, **plot_kwargs): x_from_reward = self.env.sample_from_reward( n_samples=self.logger.test.n ) - x_from_reward = torch2np(self.env.statetorch2proxy(x_from_reward)) + x_from_reward = torch2np(self.env.states2proxy(x_from_reward)) # Fit KDE with samples from reward kde_true = self.env.fit_kde( x_from_reward, @@ -1332,7 +1332,7 @@ def logq(traj_list, actions_list, model, env): with torch.no_grad(): logits_traj = model( tfloat( - env.statebatch2policy(traj), + env.states2policy(traj), device=self.device, float_type=self.float, ) diff --git a/gflownet/proxy/aptamers.py b/gflownet/proxy/aptamers.py deleted file mode 100644 index 338c72347..000000000 --- a/gflownet/proxy/aptamers.py +++ /dev/null @@ -1,35 +0,0 @@ -import numpy as np -import numpy.typing as npt - -from gflownet.proxy.base import Proxy - - -class Aptamers(Proxy): - """ - DNA Aptamer oracles - """ - - def __init__(self, oracle_id, norm): - super().__init__() - self.type = oracle_id - self.norm = norm - - def setup(self, env=None): - self.max_seq_length = env.max_seq_length - - def __call__(self, states: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: - """ - args: - states : ndarray - """ - - def _length(x): - if self.norm: - return -1.0 * np.sum(x, axis=1) / self.max_seq_length - else: - return -1.0 * np.sum(x, axis=1) - - if self.type == "length": - return _length(states) - else: - raise NotImplementedError("self.type must be length") diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a35f01ddf..9ab3dc24d 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -408,12 +408,7 @@ def states2policy( self.get_states_of_trajectory(traj_idx, states, traj_indices) ) return states_policy - # TODO: do we need tfloat or is done in env.statebatch2policy? 
- return tfloat( - self.env.statebatch2policy(states), - device=self.device, - float_type=self.float, - ) + return self.env.states2policy(states) def states2proxy( self, @@ -461,7 +456,7 @@ def states2proxy( if traj_idx not in traj_indices: continue states_proxy.append( - self.envs[traj_idx].statebatch2proxy( + self.envs[traj_idx].states2proxy( self.get_states_of_trajectory(traj_idx, states, traj_indices) ) ) @@ -471,7 +466,7 @@ def states2proxy( index[perm_index] = index.clone() states_proxy = concat_items(states_proxy, index) return states_proxy - return self.env.statebatch2proxy(states) + return self.env.states2proxy(states) def get_actions(self) -> TensorType["n_states, action_dim"]: """ @@ -678,13 +673,7 @@ def _compute_parents_all(self): self.parents_all.extend(parents) self.parents_actions_all.extend(parents_a) self.parents_all_indices.extend([idx] * len(parents)) - self.parents_all_policy.append( - tfloat( - self.envs[traj_idx].statebatch2policy(parents), - device=self.device, - float_type=self.float, - ) - ) + self.parents_all_policy.append(self.envs[traj_idx].states2policy(parents)) # Convert to tensors self.parents_actions_all = tfloat( self.parents_actions_all, diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index b5d3d3e42..17965bd93 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -250,7 +250,7 @@ def make_data_set(self, config): samples = self.env.get_random_terminating_states(config.n) else: return None, None - energies = self.env.oracle(self.env.statebatch2oracle(samples)).tolist() + energies = self.env.proxy(self.env.states2proxy(samples)).tolist() df = pd.DataFrame( { "samples": [self.env.state2readable(s) for s in samples], diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index e7e7f8afd..cfe81f40d 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -202,36 +202,36 @@ def batch_with_rest(start, stop, step, tensor=False): def tfloat(x, device, float_type): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(float_type).to(device) + return torch.stack(x).to(device=device, dtype=float_type) if torch.is_tensor(x): - return x.type(float_type).to(device) + return x.to(device=device, dtype=float_type) else: return torch.tensor(x, dtype=float_type, device=device) def tlong(x, device): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(torch.long).to(device) + return torch.stack(x).to(device=device, dtype=torch.long) if torch.is_tensor(x): - return x.type(torch.long).to(device) + return x.to(device=device, dtype=torch.long) else: return torch.tensor(x, dtype=torch.long, device=device) def tint(x, device, int_type): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(int_type).to(device) + return torch.stack(x).to(device=device, dtype=int_type) if torch.is_tensor(x): - return x.type(int_type).to(device) + return x.to(device=device, dtype=int_type) else: return torch.tensor(x, dtype=int_type, device=device) def tbool(x, device): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(torch.bool).to(device) + return torch.stack(x).to(device=device, dtype=torch.bool) if torch.is_tensor(x): - return x.type(torch.bool).to(device) + return x.to(device=device, dtype=torch.bool) else: return torch.tensor(x, dtype=torch.bool, device=device) diff --git a/main.py b/main.py index 127e59001..62a5d064a 100644 --- a/main.py +++ b/main.py @@ -77,7 +77,7 @@ def main(config): if config.n_samples > 0 and config.n_samples 
<= 1e5: batch, times = gflownet.sample_batch(n_forward=config.n_samples, train=False) x_sampled = batch.get_terminating_states(proxy=True) - energies = env.oracle(x_sampled) + energies = env.proxy(x_sampled) x_sampled = batch.get_terminating_states() df = pd.DataFrame( { diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 862c8bdcd..fd8c296bf 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -213,16 +213,6 @@ def test__sample_backwards_reaches_source(env, n=100): assert n_actions <= env.max_traj_length -@pytest.mark.repeat(100) -def test__state2policy__is_reversible(env): - env = env.reset() - while not env.done: - state_recovered = env.policy2state(env.state2policy()) - if state_recovered is not None: - assert env.equal(env.state, state_recovered) - env.step_random() - - @pytest.mark.repeat(100) def test__state2readable__is_reversible(env): env = env.reset() diff --git a/tests/gflownet/envs/test_ccrystal.py b/tests/gflownet/envs/test_ccrystal.py index e88a348ae..4d0cdf00b 100644 --- a/tests/gflownet/envs/test_ccrystal.py +++ b/tests/gflownet/envs/test_ccrystal.py @@ -152,22 +152,22 @@ def test__pad_depad_action(env): ], ], ) -def test__statetorch2policy__is_concatenation_of_subenv_states(env, states): +def test__states2policy__is_concatenation_of_subenv_states(env, states): # Get policy states from the batch of states converted into each subenv states_dict = {stage: [] for stage in env.subenvs} for state in states: for stage in env.subenvs: states_dict[stage].append(env._get_state_of_subenv(state, stage)) states_policy_dict = { - stage: subenv.statebatch2policy(states_dict[stage]) + stage: subenv.states2policy(states_dict[stage]) for stage, subenv in env.subenvs.items() } states_policy_expected = torch.cat( [el for el in states_policy_dict.values()], dim=1 ) - # Get policy states from env.statetorch2policy + # Get policy states from env.states2policy states_torch = tfloat(states, float_type=env.float, device=env.device) - states_policy = env.statetorch2policy(states_torch) + states_policy = env.states2policy(states_torch) assert torch.all(torch.eq(states_policy, states_policy_expected)) @@ -191,20 +191,20 @@ def test__statetorch2policy__is_concatenation_of_subenv_states(env, states): ], ], ) -def test__statetorch2proxy__is_concatenation_of_subenv_states(env, states): +def test__states2proxy__is_concatenation_of_subenv_states(env, states): # Get proxy states from the batch of states converted into each subenv states_dict = {stage: [] for stage in env.subenvs} for state in states: for stage in env.subenvs: states_dict[stage].append(env._get_state_of_subenv(state, stage)) states_proxy_dict = { - stage: subenv.statebatch2proxy(states_dict[stage]) + stage: subenv.states2proxy(states_dict[stage]) for stage, subenv in env.subenvs.items() } states_proxy_expected = torch.cat([el for el in states_proxy_dict.values()], dim=1) - # Get proxy states from env.statetorch2proxy + # Get proxy states from env.states2proxy states_torch = tfloat(states, float_type=env.float, device=env.device) - states_proxy = env.statetorch2proxy(states_torch) + states_proxy = env.states2proxy(states_torch) assert torch.all(torch.eq(states_proxy, states_proxy_expected)) @@ -243,7 +243,7 @@ def test__state2readable__is_concatenation_of_subenv_states(env, states): f"SpaceGroup = {readables[1]}; " f"LatticeParameters = {readables[2]}" ) - # Get policy states from env.statetorch2policy + # Get policy states from env.states2policy states_readable = 
[env.state2readable(state) for state in states] for readable, readable_expected in zip(states_readable, states_readable_expected): assert readable == readable_expected diff --git a/tests/gflownet/envs/test_ccube.py b/tests/gflownet/envs/test_ccube.py index 25181ed4e..462eaf2f9 100644 --- a/tests/gflownet/envs/test_ccube.py +++ b/tests/gflownet/envs/test_ccube.py @@ -1097,10 +1097,8 @@ def test__state2policy_returns_expected(env, state, expected): ], ) @pytest.mark.skip(reason="skip while developping other tests") -def test__statetorch2policy_returns_expected(env, states, expected): - assert torch.equal( - env.statetorch2policy(torch.tensor(states)), torch.tensor(expected) - ) +def test__states2policy_returns_expected(env, states, expected): + assert torch.equal(env.states2policy(torch.tensor(states)), torch.tensor(expected)) @pytest.mark.parametrize( diff --git a/tests/gflownet/envs/test_composition.py b/tests/gflownet/envs/test_composition.py index 31fcd470e..eb6add01e 100644 --- a/tests/gflownet/envs/test_composition.py +++ b/tests/gflownet/envs/test_composition.py @@ -4,6 +4,7 @@ import torch from gflownet.envs.crystals.composition import Composition +from gflownet.utils.common import tlong @pytest.fixture @@ -74,8 +75,8 @@ def test__environment__initializes_properly(elements): ), ], ) -def test__state2oracle__returns_expected_tensor(env, state, exp_tensor): - assert torch.equal(env.state2oracle(state), torch.Tensor(exp_tensor)) +def test__state2proxy__returns_expected_tensor(env, state, exp_tensor): + assert torch.equal(env.state2proxy(state), tlong(exp_tensor, device=env.device)) def test__state2readable(env): diff --git a/tests/gflownet/envs/test_crystal.py b/tests/gflownet/envs/test_crystal.py index 5d4c9cbed..75dae0167 100644 --- a/tests/gflownet/envs/test_crystal.py +++ b/tests/gflownet/envs/test_crystal.py @@ -110,8 +110,8 @@ def test__pad_depad_action(env): ], ], ) -def test__state2oracle__returns_expected_value(env, state, expected): - assert torch.allclose(env.state2oracle(state), expected, atol=1e-4) +def test__state2proxy__returns_expected_value(env, state, expected): + assert torch.allclose(env.state2proxy(state), expected, atol=1e-4) @pytest.mark.parametrize( @@ -216,8 +216,8 @@ def test__state2proxy__returns_expected_value(env, state, expected): ], ], ) -def test__statebatch2proxy__returns_expected_value(env, batch, expected): - assert torch.allclose(env.statebatch2proxy(batch), expected, atol=1e-4) +def test__states2proxy__returns_expected_value(env, batch, expected): + assert torch.allclose(env.states2proxy(batch), expected, atol=1e-4) @pytest.mark.parametrize("action", [(1, 1, -2, -2, -2, -2), (3, 4, -2, -2, -2, -2)]) diff --git a/tests/gflownet/envs/test_grid.py b/tests/gflownet/envs/test_grid.py index 2bf3ea4bf..afa375649 100644 --- a/tests/gflownet/envs/test_grid.py +++ b/tests/gflownet/envs/test_grid.py @@ -3,6 +3,7 @@ import torch from gflownet.envs.grid import Grid +from gflownet.utils.common import tfloat @pytest.fixture @@ -45,7 +46,7 @@ def config_path(): @pytest.mark.parametrize( - "state, state2oracle", + "state, state2proxy", [ ( [0, 0, 0], @@ -65,12 +66,15 @@ def config_path(): ), ], ) -def test__state2oracle__returns_expected(env, state, state2oracle): - assert state2oracle == env.state2oracle(state) +def test__state2proxy__returns_expected(env, state, state2proxy): + assert torch.equal( + tfloat(state2proxy, device=env.device, float_type=env.float), + env.state2proxy(state), + ) @pytest.mark.parametrize( - "states, statebatch2oracle", + "states, 
states2proxy", [ ( [[0, 0, 0], [4, 4, 4], [1, 2, 3], [4, 0, 1]], @@ -78,8 +82,8 @@ def test__state2oracle__returns_expected(env, state, state2oracle): ), ], ) -def test__statebatch2oracle__returns_expected(env, states, statebatch2oracle): - assert torch.equal(torch.Tensor(statebatch2oracle), env.statebatch2oracle(states)) +def test__states2proxy__returns_expected(env, states, states2proxy): + assert torch.equal(torch.Tensor(states2proxy), env.states2proxy(states)) @pytest.mark.parametrize( diff --git a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index 16aea2814..281d30783 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -282,8 +282,8 @@ def test__step__changes_state_as_expected(env, lattice_system, actions, exp_stat ), ], ) -def test__state2oracle__returns_expected_tensor(env, lattice_system, state, exp_tensor): - assert torch.equal(env.state2oracle(state), torch.Tensor(exp_tensor)) +def test__state2proxy__returns_expected_tensor(env, lattice_system, state, exp_tensor): + assert torch.equal(env.state2proxy(state), torch.Tensor(exp_tensor)) @pytest.mark.parametrize("lattice_system", [TRICLINIC]) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 338dfd061..776bba324 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -123,14 +123,7 @@ def test__get_states__single_env_returns_expected(env, batch, request): assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) @pytest.mark.repeat(N_REPETITIONS) @@ -155,14 +148,7 @@ def test__get_parents__single_env_returns_expected(env, batch, request): assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) @pytest.mark.repeat(N_REPETITIONS) @@ -197,14 +183,7 @@ def test__get_parents_all__single_env_returns_expected(env, batch, request): float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) @pytest.mark.repeat(N_REPETITIONS) @@ -365,14 +344,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -399,14 +371,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - 
float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -423,14 +388,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -447,14 +405,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -551,14 +502,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -585,14 +529,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -609,14 +546,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -633,14 +563,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -794,14 +717,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -828,14 +744,7 @@ def 
test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -852,14 +761,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -876,14 +778,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -1043,14 +938,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -1077,14 +965,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -1101,14 +982,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -1125,14 +999,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS)
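The refactor above converges every environment on a single pair of batch conversion entry points, states2proxy and states2policy, which accept either a list of states or an already-stacked tensor. A minimal, self-contained sketch of that calling pattern, using the Tetris binarization from this diff (the function name and example boards are illustrative, not part of the patch):

from typing import List, Union

import torch

def tetris_states2proxy_sketch(
    states: Union[List[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
    # Accept a list of boards or a stacked tensor, then binarize: any non-empty
    # cell (piece index != 0) becomes 1, as in Tetris.states2proxy above.
    if isinstance(states, list):
        states = torch.stack(states)
    states = states.clone()
    states[states != 0] = 1
    return states

boards = [torch.tensor([[0, 4], [4, 4]]), torch.tensor([[7, 0], [7, 7]])]
print(tetris_states2proxy_sketch(boards))  # both boards reduced to 0/1 cells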
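The one-hot policy encoding of the discrete Torus (states2policy above) can be checked against the example in its docstring. In the sketch below, the helper name and the explicit n_dim / n_angles arguments stand in for the environment attributes; the indexing logic mirrors the diff:

import torch

def torus_states2policy_sketch(states, n_dim, n_angles):
    # One column offset per dimension, one row index per (state, dimension) pair,
    # then scatter 1s and append the action counter, as in Torus.states2policy.
    states = torch.as_tensor(states, dtype=torch.long)
    cols = states[:, :-1] + torch.arange(n_dim) * n_angles
    rows = torch.repeat_interleave(torch.arange(states.shape[0]), n_dim)
    states_policy = torch.zeros((states.shape[0], n_angles * n_dim + 1))
    states_policy[rows, cols.flatten()] = 1.0
    states_policy[:, -1] = states[:, -1].to(states_policy.dtype)
    return states_policy

print(torus_states2policy_sketch([[1, 3, 4]], n_dim=2, n_angles=4))
# tensor([[0., 1., 0., 0., 0., 0., 0., 1., 4.]])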
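The trigonometric policy encoding applied when policy_encoding_dim_per_angle >= 2 (see the states2policy hunk near the top of this section) expands each angle into cos/sin components at integer multiples of the angle and appends the step counter unchanged. A standalone sketch with illustrative names and test values:

import torch

def trig_encode_sketch(states, encoding_dim_per_angle):
    # Each angle is multiplied by the coefficients 1..k, then cos and sin of the
    # scaled angles are concatenated and the step counter is appended.
    states = torch.as_tensor(states, dtype=torch.float32)
    step = states[:, -1]
    k = encoding_dim_per_angle // 2
    int_coeff = torch.arange(1, k + 1).repeat(states.shape[-1] - 1).to(states)
    encoding = torch.repeat_interleave(states[:, :-1], repeats=k, dim=1) * int_coeff
    return torch.cat([torch.cos(encoding), torch.sin(encoding), step.unsqueeze(1)], dim=1)

# Two angles plus a step counter, encoded with 4 dimensions per angle (k = 2):
x = trig_encode_sketch([[0.0, 3.1416, 5.0]], encoding_dim_per_angle=4)
print(x.shape)  # torch.Size([1, 9]): 2 angles * 2 coefficients * (cos, sin) + 1 step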
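The typing helpers in gflownet/utils/common.py now perform device and dtype conversion in a single .to(device=..., dtype=...) call instead of chained .type().to() calls. A self-contained sketch of the tfloat variant and the three kinds of input it handles (tlong, tint and tbool follow the same shape; the sketch name is illustrative):

import torch

def tfloat_sketch(x, device, float_type):
    # List of tensors -> stack, tensor -> convert, anything else -> build anew.
    if isinstance(x, list) and torch.is_tensor(x[0]):
        return torch.stack(x).to(device=device, dtype=float_type)
    if torch.is_tensor(x):
        return x.to(device=device, dtype=float_type)
    return torch.tensor(x, dtype=float_type, device=device)

device = torch.device("cpu")
print(tfloat_sketch([[0, 1], [2, 3]], device, torch.float32))                    # nested lists
print(tfloat_sketch([torch.tensor([0, 1]), torch.tensor([2, 3])], device, torch.float32))  # list of tensors
print(tfloat_sketch(torch.zeros(2, 2, dtype=torch.long), device, torch.float32)) # tensor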