Implement functionality to handle max_reward
alexhernandezgarcia committed May 21, 2024
1 parent abe50cb commit a1462f1
Showing 6 changed files with 84 additions and 17 deletions.
2 changes: 1 addition & 1 deletion gflownet/gflownet.py
@@ -1538,7 +1538,7 @@ def sample_from_reward(
            format.
        """
        samples_final = []
-        max_reward = self.proxy.proxy2reward(self.proxy.min)
+        max_reward = self.get_max_reward()
        while len(samples_final) < n_samples:
            if proposal_distribution == "uniform":
                # TODO: sample only the remaining number of samples
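In `sample_from_reward`, the maximum reward serves as the upper bound for rejection sampling: candidates drawn from the proposal distribution are kept with probability proportional to reward / max_reward. The sketch below illustrates that idea only; it is not the repository's implementation. The helper `propose_uniform_states` is hypothetical, while `get_max_reward`, `proxy2reward` and the proxy call are the methods touched by this commit.

import torch

def rejection_sample_sketch(proxy, propose_uniform_states, n_samples):
    # Upper bound on the reward, obtained from the proxy optimum (this commit).
    max_reward = proxy.get_max_reward()
    samples_final = []
    while len(samples_final) < n_samples:
        # Hypothetical helper: draw candidate states uniformly from the sample space.
        candidates = propose_uniform_states(n_samples)
        rewards = proxy.proxy2reward(proxy(candidates))
        # Accept each candidate with probability reward / max_reward.
        accept = torch.rand(len(rewards)) * max_reward < rewards
        samples_final.extend([s for s, ok in zip(candidates, accept) if ok])
    return samples_final[:n_samples]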
48 changes: 47 additions & 1 deletion gflownet/proxy/base.py
@@ -177,13 +177,59 @@ def get_min_reward(self, log: bool = False) -> float:
        Returns
        -------
        float
-            The mimnimum (log) reward.
+            The minimum (log) reward.
        """
        if log:
            return self.logreward_min
        else:
            return self.reward_min

+    def get_max_reward(self, log: bool = False) -> float:
+        """
+        Returns the maximum value of the (log) reward, retrieved from self.optimum, in
+        case it is defined.
+        Parameters
+        ----------
+        log : bool
+            If True, returns the logarithm of the maximum reward. If False (default),
+            returns the natural maximum reward.
+        Returns
+        -------
+        float
+            The maximum (log) reward.
+        """
+        if log:
+            return self.proxy2logreward(self.optimum)
+        else:
+            return self.proxy2reward(self.optimum)
+
+    @property
+    def optimum(self):
+        """
+        Returns the optimum value of the proxy.
+        Not implemented by default but may be implemented for synthetic proxies or when
+        the optimum is known.
+        The optimum is used, for example, to accelerate rejection sampling, to sample
+        from the reward function.
+        """
+        if not hasattr(self, "_optimum"):
+            raise NotImplementedError(
+                "The optimum value of the proxy needs to be implemented explicitly for "
+                f"each Proxy and is not available for {self.__class__}."
+            )
+        return self._optimum
+
+    @optimum.setter
+    def optimum(self, value):
+        """
+        Sets the optimum value of the proxy.
+        """
+        self._optimum = value
+
    def _get_reward_functions(
        self,
        reward_function: Union[Callable, str],
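With `optimum` defined on the base class, a concrete proxy can expose a known optimum either by assigning `self._optimum` (or using the `optimum` setter) or by overriding the property, and `get_max_reward` converts it through `proxy2reward` / `proxy2logreward`. A minimal sketch of a custom proxy follows; it assumes the base class is importable as `gflownet.proxy.base.Proxy` as in the file above, and the class and its score function are hypothetical, not part of the repository.

import torch
from gflownet.proxy.base import Proxy

class NegSquaredNorm(Proxy):
    """Hypothetical proxy f(x) = -||x||^2, maximized at x = 0 with value 0."""

    def __call__(self, states):
        return -(states ** 2).sum(dim=-1)

    @property
    def optimum(self):
        # Known analytically: the proxy attains its maximum, 0.0, at the origin.
        if not hasattr(self, "_optimum"):
            self._optimum = torch.tensor(0.0, device=self.device, dtype=self.float)
        return self._optimum

# Usage (construction arguments omitted, they depend on the base class):
# proxy = NegSquaredNorm(...)
# proxy.get_max_reward()          # reward at the optimum
# proxy.get_max_reward(log=True)  # log-reward at the optimum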
8 changes: 4 additions & 4 deletions gflownet/proxy/corners.py
@@ -31,13 +31,13 @@ def setup(self, env=None):
        self.mulnormal_norm = 1.0 / ((2 * torch.pi) ** 2 * cov_det) ** 0.5

    @property
-    def min(self):
-        if not hasattr(self, "_min"):
+    def optimum(self):
+        if not hasattr(self, "_optimum"):
            mode = self.mu * torch.ones(
                self.n_dim, device=self.device, dtype=self.float
            )
-            self._min = self(torch.unsqueeze(mode, 0))[0]
-        return self._min
+            self._optimum = self(torch.unsqueeze(mode, 0))[0]
+        return self._optimum

    def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
        return self.mulnormal_norm * torch.exp(
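For the Corners proxy, the optimum is the value of the multivariate normal density at its mode, obtained here simply by evaluating the proxy at `mu` along every dimension. The sketch below is only a hedged illustration of why that evaluation gives the peak, using a generic isotropic Gaussian; the general `n_dim` normalization constant is an assumption, whereas the file above uses `(2 * torch.pi) ** 2`, which appears to correspond to the two-dimensional case.

import torch

# At the mode, the quadratic form in the exponent vanishes, so the density
# equals its normalization constant 1 / sqrt((2*pi)^n_dim * det(cov)).
n_dim, sigma = 2, 0.1
cov = sigma * torch.eye(n_dim)
norm_const = 1.0 / torch.sqrt((2 * torch.pi) ** n_dim * torch.det(cov))
value_at_mode = norm_const * torch.exp(torch.tensor(0.0))  # exponent is 0 at the mode
assert torch.isclose(value_at_mode, norm_const)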
10 changes: 5 additions & 5 deletions gflownet/proxy/torus.py
@@ -16,15 +16,15 @@ def setup(self, env=None):
        self.n_dim = env.n_dim

    @property
-    def min(self):
-        if not hasattr(self, "_min"):
+    def optimum(self):
+        if not hasattr(self, "_optimum"):
            if self.normalize:
-                self._min = torch.tensor(0.0, device=self.device, dtype=self.float)
+                self._optimum = torch.tensor(1.0, device=self.device, dtype=self.float)
            else:
-                self._min = torch.tensor(
+                self._optimum = torch.tensor(
                    ((self.n_dim * 2) ** 3), device=self.device, dtype=self.float
                )
-        return self._min
+        return self._optimum

    @property
    def norm(self):
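After the rename, the normalized Torus proxy reports an optimum of 1.0 (the maximum of the normalized scores) rather than the former minimum of 0.0, while the unnormalized variant keeps `(2 * n_dim) ** 3` as its extreme value. A small arithmetic check of that expression for a few dimensionalities (the values simply evaluate the formula above):

# (2 * n_dim) ** 3, the unnormalized Torus optimum used above.
for n_dim in (1, 2, 3):
    print(n_dim, (2 * n_dim) ** 3)  # 1 -> 8, 2 -> 64, 3 -> 216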
7 changes: 1 addition & 6 deletions gflownet/proxy/uniform.py
@@ -9,14 +9,9 @@
class Uniform(Proxy):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
+        self._optimum = torch.tensor(1.0, device=self.device, dtype=self.float)

    def __call__(
        self, states: Union[List, TensorType["batch", "state_dim"]]
    ) -> TensorType["batch"]:
        return torch.ones(len(states), device=self.device, dtype=self.float)
-
-    @property
-    def min(self):
-        if not hasattr(self, "_min"):
-            self._min = torch.tensor(1.0, device=self.device, dtype=self.float)
-        return self._min
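Because a uniform proxy scores every state as 1.0, its optimum is a constant and no longer needs the lazy `min` property removed above; it can be assigned eagerly in `__init__`. The sketch below contrasts the two patterns; the class and attribute names are illustrative, not part of the repository. The eager form is appropriate here because the value does not depend on `setup` or on the environment.

import torch

class EagerVsLazy:
    def __init__(self):
        # Eager: the value is known up front, as in Uniform after this commit.
        self._optimum = torch.tensor(1.0)

    @property
    def optimum_lazy(self):
        # Lazy: compute on first access and cache, as the removed property did.
        if not hasattr(self, "_optimum_lazy"):
            self._optimum_lazy = torch.tensor(1.0)
        return self._optimum_lazy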
26 changes: 26 additions & 0 deletions tests/gflownet/proxy/test_base.py
@@ -623,3 +623,29 @@ def test_reward_function_callable__behaves_as_expected(
    # Log Rewards
    logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float)
    assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp))
+
+
+@pytest.mark.parametrize(
+    "proxy, beta, optimum, reward_max",
+    [
+        ("uniform", None, 1.0, 1.0),
+        ("uniform", None, 2.0, 2.0),
+        ("proxy_power", 1, 2.0, 2.0),
+        ("proxy_power", 2, 2.0, 4.0),
+        ("proxy_exponential", 1, 1.0, np.exp(1.0)),
+        ("proxy_exponential", -1, -1.0, np.exp(1.0)),
+        ("proxy_shift", 5, 10.0, 15.0),
+        ("proxy_shift", -5, 10.0, 5.0),
+        ("proxy_product", 2, 2.0, 4.0),
+        ("proxy_product", -2, -5.0, 10.0),
+    ],
+)
+def test__uniform_proxy_initializes_without_errors(
+    proxy, beta, optimum, reward_max, request
+):
+    proxy = request.getfixturevalue(proxy)
+    reward_max = torch.tensor(reward_max, dtype=proxy.float, device=proxy.device)
+    # Forcibly set the optimum for testing purposes, even if the proxy is uniform.
+    proxy.optimum = torch.tensor(optimum)
+    assert torch.isclose(proxy.get_max_reward(log=False), reward_max)
+    assert torch.isclose(proxy.get_max_reward(log=True), torch.log(reward_max))
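The parametrization pairs each optimum with the maximum reward expected after the proxy's reward transformation: for example, a power reward with beta = 2 maps an optimum of 2.0 to 4.0, an exponential reward with beta = -1 maps -1.0 to exp(1.0), and a shift of -5 maps 10.0 to 5.0. Below is a standalone arithmetic check of those expectations; the transformation semantics (power: x ** beta, exponential: exp(beta * x), shift: x + beta, product: beta * x) are inferred from the table above rather than taken from the fixtures.

import numpy as np

# Expected optimum -> max-reward mappings from the parametrization above.
assert 2.0 ** 1 == 2.0                               # power, beta = 1
assert 2.0 ** 2 == 4.0                               # power, beta = 2
assert np.isclose(np.exp(1 * 1.0), np.exp(1.0))      # exponential, beta = 1
assert np.isclose(np.exp(-1 * -1.0), np.exp(1.0))    # exponential, beta = -1
assert 10.0 + 5 == 15.0                              # shift, beta = 5
assert 10.0 + (-5) == 5.0                            # shift, beta = -5
assert 2 * 2.0 == 4.0                                # product, beta = 2
assert -2 * -5.0 == 10.0                             # product, beta = -2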
