Skip to content

Commit

Permalink
Merge pull request #270 from alexhernandezgarcia/conversions-simplify
Browse files Browse the repository at this point in the history
Simplify state conversions - combination
  • Loading branch information
alexhernandezgarcia authored Dec 22, 2023
2 parents 40fabb7 + 3ecea37 commit bf65b57
Show file tree
Hide file tree
Showing 29 changed files with 451 additions and 1,548 deletions.
16 changes: 0 additions & 16 deletions config/env/aptamers.yaml

This file was deleted.

42 changes: 26 additions & 16 deletions gflownet/envs/alaninedipeptide.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import List, Tuple
from typing import List, Tuple, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -40,25 +40,34 @@ def sync_conformer_with_state(self, state: List = None):
self.conformer.set_torsion_angle(ta, state[idx])
return self.conformer

def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray:
# TODO: are the conversions to oracle relevant?
def states2proxy(
self, states: Union[List[List], TensorType["batch", "state_dim"]]
) -> npt.NDArray:
"""
Prepares a batch of states in torch "GFlowNet format" for the oracle.
"""
device = states.device
if device == torch.device("cpu"):
np_states = states.numpy()
else:
np_states = states.cpu().numpy()
return np_states[:, :-1]

def statebatch2proxy(self, states: List[List]) -> npt.NDArray:
"""
Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where
each state is a row of length n_dim with an angle in radians. The n_actions
Prepares a batch of states in "environment format" for the proxy: each state is
a vector of length n_dim where each value is an angle in radians. The n_actions
item is removed.
Important: this method returns a numpy array, unlike in most other
environments.
Args
----
states : list or tensor
A batch of states in environment format, either as a list of states or as a
single tensor.
Returns
-------
A numpy array containing all the states in the batch.
"""
return np.array(states)[:, :-1]
if torch.is_tensor(states[0]):
return states.cpu().numpy()[:, :-1]
else:
return np.array(states)[:, :-1]

# TODO: need to keep?
def statetorch2oracle(
self, states: TensorType["batch", "state_dim"]
) -> List[Tuple[npt.NDArray, npt.NDArray]]:
Expand All @@ -73,6 +82,7 @@ def statetorch2oracle(
result = self.statebatch2oracle(np_states)
return result

# TODO: need to keep?
def statebatch2oracle(
self, states: List[List]
) -> List[Tuple[npt.NDArray, npt.NDArray]]:
Expand Down
Loading

0 comments on commit bf65b57

Please sign in to comment.