Skip to content

Commit

Permalink
Merge pull request #126 from saleml/bye-env
Browse files Browse the repository at this point in the history
No more envs in estimators and samplers + Tests for scripts
  • Loading branch information
saleml authored Aug 14, 2023
2 parents 47f8d3b + cba9290 commit c8712dc
Show file tree
Hide file tree
Showing 19 changed files with 765 additions and 521 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ repos:
name: pytest-check
entry: pytest
language: python
pass_filenames: false
files: testing/
types: [python]
always_run: true
69 changes: 42 additions & 27 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "torchgfn"
packages = [{include = "gfn", from = "src"}]
version = "1.0.0"
version = "1.0.1"
description = "A torch implementation of GFlowNets"
authors = ["Salem Lahou <[email protected]>", "Joseph Viviano <[email protected]>", "Victor Schmidt <[email protected]>"]
license = "MIT"
Expand Down
17 changes: 12 additions & 5 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torchtyping import TensorType as TT

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
Expand All @@ -18,30 +19,36 @@ class GFlowNet(nn.Module):
"""

@abstractmethod
def sample_trajectories(self, n_samples: int) -> Trajectories:
def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories:
"""Sample a specific number of complete trajectories.
Args:
env: the environment to sample trajectories from.
n_samples: number of trajectories to be sampled.
Returns:
Trajectories: sampled trajectories object.
"""

def sample_terminating_states(self, n_samples: int) -> States:
def sample_terminating_states(self, env: Env, n_samples: int) -> States:
"""Rolls out the parametrization's policy and returns the terminating states.
Args:
env: the environment to sample terminating states from.
n_samples: number of terminating states to be sampled.
Returns:
States: sampled terminating states object.
"""
trajectories = self.sample_trajectories(n_samples)
trajectories = self.sample_trajectories(env, n_samples)
return trajectories.last_states

@abstractmethod
def to_training_samples(self, trajectories: Trajectories):
"""Converts trajectories to training samples. The type depends on the GFlowNet."""

@abstractmethod
def loss(self, env: Env, training_objects):
"""Computes the loss given the training objects."""


class PFBasedGFlowNet(GFlowNet):
r"""Base class for gflownets that explicitly uses $P_F$.
Expand All @@ -57,9 +64,9 @@ def __init__(self, pf: GFNModule, pb: GFNModule, on_policy: bool = False):
self.pb = pb
self.on_policy = on_policy

def sample_trajectories(self, n_samples: int = 1000) -> Trajectories:
def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories:
sampler = Sampler(estimator=self.pf)
trajectories = sampler.sample_trajectories(n_trajectories=n_samples)
trajectories = sampler.sample_trajectories(env, n_trajectories=n_samples)
return trajectories


Expand Down
20 changes: 11 additions & 9 deletions src/gfn/gflownet/detailed_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from torchtyping import TensorType as TT

from gfn.containers import Trajectories, Transitions
from gfn.env import Env
from gfn.gflownet.base import PFBasedGFlowNet
from gfn.modules import ScalarEstimator
from gfn.modules import GFNModule, ScalarEstimator


class DBGFlowNet(PFBasedGFlowNet):
Expand All @@ -27,17 +28,18 @@ class DBGFlowNet(PFBasedGFlowNet):

def __init__(
self,
pf: GFNModule,
pb: GFNModule,
logF: ScalarEstimator,
on_policy: bool = False,
forward_looking: bool = False,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(pf, pb, on_policy=on_policy)
self.logF = logF
self.forward_looking = forward_looking
self.env = self.logF.env # TODO We don't want to store env in here...

def get_scores(
self, transitions: Transitions
self, env: Env, transitions: Transitions
) -> Tuple[
TT["n_transitions", float],
TT["n_transitions", float],
Expand Down Expand Up @@ -72,7 +74,7 @@ def get_scores(

valid_log_F_s = self.logF(states).squeeze(-1)
if self.forward_looking:
log_rewards = self.env.log_reward(states) # RM unsqueeze(-1)
log_rewards = env.log_reward(states) # RM unsqueeze(-1)
valid_log_F_s = valid_log_F_s + log_rewards

preds = valid_log_pf_actions + valid_log_F_s
Expand Down Expand Up @@ -110,12 +112,12 @@ def get_scores(

return (valid_log_pf_actions, log_pb_actions, scores)

def loss(self, transitions: Transitions) -> TT[0, float]:
def loss(self, env: Env, transitions: Transitions) -> TT[0, float]:
"""Detailed balance loss.
The detailed balance loss is described in section
3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266)."""
_, _, scores = self.get_scores(transitions)
_, _, scores = self.get_scores(env, transitions)
loss = torch.mean(scores**2)

if torch.isnan(loss):
Expand Down Expand Up @@ -182,7 +184,7 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.flo

return scores

def loss(self, transitions: Transitions) -> TT[0, float]:
def loss(self, env: Env, transitions: Transitions) -> TT[0, float]:
"""Calculates the modified detailed balance loss."""
scores = self.get_scores(transitions)
return torch.mean(scores**2)
Expand Down
32 changes: 18 additions & 14 deletions src/gfn/gflownet/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torchtyping import TensorType as TT

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import GFlowNet
from gfn.modules import DiscretePolicyEstimator
from gfn.samplers import Sampler
Expand All @@ -28,23 +29,21 @@ class FMGFlowNet(GFlowNet):

def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0):
super().__init__()
assert not logF.greedy_eps

self.logF = logF
self.alpha = alpha
self.env = self.logF.env
if not self.env.is_discrete:

def sample_trajectories(self, env: Env, n_samples: int = 1000) -> Trajectories:
if not env.is_discrete:
raise NotImplementedError(
"Flow Matching GFlowNet only supports discrete environments for now."
)

def sample_trajectories(self, n_samples: int = 1000) -> Trajectories:
sampler = Sampler(estimator=self.logF)
trajectories = sampler.sample_trajectories(n_trajectories=n_samples)
trajectories = sampler.sample_trajectories(env, n_trajectories=n_samples)
return trajectories

def flow_matching_loss(
self, states: DiscreteStates
self, env: Env, states: DiscreteStates
) -> TT["n_trajectories", torch.float]:
"""Computes the FM for the provided states.
Expand All @@ -67,7 +66,7 @@ def flow_matching_loss(
states.forward_masks, -float("inf"), dtype=torch.float
)

for action_idx in range(self.env.n_actions - 1):
for action_idx in range(env.n_actions - 1):
valid_backward_mask = states.backward_masks[:, action_idx]
valid_forward_mask = states.forward_masks[:, action_idx]
valid_backward_states = states[valid_backward_mask]
Expand All @@ -76,9 +75,9 @@ def flow_matching_loss(
backward_actions = torch.full_like(
valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long
).unsqueeze(-1)
backward_actions = self.env.Actions(backward_actions)
backward_actions = env.Actions(backward_actions)

valid_backward_states_parents = self.env.backward_step(
valid_backward_states_parents = env.backward_step(
valid_backward_states, backward_actions
)

Expand All @@ -101,8 +100,11 @@ def flow_matching_loss(

return (log_incoming_flows - log_outgoing_flows).pow(2).mean()

def reward_matching_loss(self, terminating_states: DiscreteStates) -> TT[0, float]:
def reward_matching_loss(
self, env: Env, terminating_states: DiscreteStates
) -> TT[0, float]:
"""Calculates the reward matching loss from the terminating states."""
del env # Unused
assert terminating_states.log_rewards is not None
log_edge_flows = self.logF(terminating_states)

Expand All @@ -111,16 +113,18 @@ def reward_matching_loss(self, terminating_states: DiscreteStates) -> TT[0, floa
log_rewards = terminating_states.log_rewards
return (terminating_log_edge_flows - log_rewards).pow(2).mean()

def loss(self, states_tuple: Tuple[DiscreteStates, DiscreteStates]) -> TT[0, float]:
def loss(
self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates]
) -> TT[0, float]:
"""Given a batch of non-terminal and terminal states, compute a loss.
Unlike the GFlowNet Foundations paper, we allow more flexibility by passing a
tuple of states, the first one being the internal states of the trajectories
(i.e. non-terminal states), and the second one being the terminal states of the
trajectories."""
intermediary_states, terminating_states = states_tuple
fm_loss = self.flow_matching_loss(intermediary_states)
rm_loss = self.reward_matching_loss(terminating_states)
fm_loss = self.flow_matching_loss(env, intermediary_states)
rm_loss = self.reward_matching_loss(env, terminating_states)
return fm_loss + self.alpha * rm_loss

def to_training_samples(
Expand Down
19 changes: 11 additions & 8 deletions src/gfn/gflownet/sub_trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from torchtyping import TensorType as TT

from gfn.containers import Trajectories
from gfn.gflownet.base import PFBasedGFlowNet, TrajectoryBasedGFlowNet
from gfn.modules import ScalarEstimator
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.modules import GFNModule, ScalarEstimator


class SubTBGFlowNet(TrajectoryBasedGFlowNet):
Expand Down Expand Up @@ -43,7 +44,10 @@ class SubTBGFlowNet(TrajectoryBasedGFlowNet):

def __init__(
self,
pf: GFNModule,
pb: GFNModule,
logF: ScalarEstimator,
on_policy: bool = False,
weighting: Literal[
"DB",
"ModifiedDB",
Expand All @@ -56,9 +60,8 @@ def __init__(
lamda: float = 0.9,
log_reward_clip_min: float = -12, # roughly log(1e-5)
forward_looking: bool = False,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(pf, pb, on_policy=on_policy)
self.logF = logF
self.weighting = weighting
self.lamda = lamda
Expand Down Expand Up @@ -89,7 +92,7 @@ def cumulative_logprobs(
)

def get_scores(
self, trajectories: Trajectories
self, env: Env, trajectories: Trajectories
) -> Tuple[List[TT[0, float]], List[TT[0, float]]]:
"""Scores all submitted trajectories.
Expand Down Expand Up @@ -123,7 +126,7 @@ def get_scores(

log_F = self.logF(valid_states).squeeze(-1)
if self.forward_looking:
log_rewards = self.logF.env.log_reward(states).unsqueeze(-1)
log_rewards = env.log_reward(states).unsqueeze(-1)
log_F = log_F + log_rewards
log_state_flows[mask[:-1]] = log_F

Expand Down Expand Up @@ -188,9 +191,9 @@ def get_scores(
flattening_masks,
)

def loss(self, trajectories: Trajectories) -> TT[0, float]:
def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
# Get all scores and masks from the trajectories.
scores, flattening_masks = self.get_scores(trajectories)
scores, flattening_masks = self.get_scores(env, trajectories)
flattening_mask = torch.cat(flattening_masks)
all_scores = torch.cat(scores, 0)

Expand Down
24 changes: 18 additions & 6 deletions src/gfn/gflownet/trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from torchtyping import TensorType as TT

from gfn.containers import Trajectories
from gfn.env import Env
from gfn.gflownet.base import TrajectoryBasedGFlowNet
from gfn.modules import GFNModule


class TBGFlowNet(TrajectoryBasedGFlowNet):
Expand All @@ -28,16 +30,18 @@ class TBGFlowNet(TrajectoryBasedGFlowNet):

def __init__(
self,
pf: GFNModule,
pb: GFNModule,
on_policy: bool = False,
init_logZ: float = 0.0,
log_reward_clip_min: float = -12, # roughly log(1e-5)
**kwargs,
):
super().__init__(**kwargs)
super().__init__(pf, pb, on_policy=on_policy)

self.logZ = nn.Parameter(torch.tensor(init_logZ))
self.log_reward_clip_min = log_reward_clip_min

def loss(self, trajectories: Trajectories) -> TT[0, float]:
def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
"""Trajectory balance loss.
The trajectory balance loss is described in section 2.3 of
Expand All @@ -46,6 +50,7 @@ def loss(self, trajectories: Trajectories) -> TT[0, float]:
Raises:
ValueError: if the loss is NaN.
"""
del env # unused
_, _, scores = self.get_trajectories_scores(trajectories)
loss = (scores + self.logZ).pow(2).mean()
if torch.isnan(loss):
Expand All @@ -64,17 +69,24 @@ class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet):
ValueError: if the loss is NaN.
"""

def __init__(self, log_reward_clip_min: float = -12, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
pf: GFNModule,
pb: GFNModule,
on_policy: bool = False,
log_reward_clip_min: float = -12,
):
super().__init__(pf, pb, on_policy=on_policy)

self.log_reward_clip_min = log_reward_clip_min # -12 is roughly log(1e-5)

def loss(self, trajectories: Trajectories) -> TT[0, float]:
def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]:
"""Log Partition Variance loss.
This method is described in section 3.2 of
[ROBUST SCHEDULING WITH GFLOWNETS](https://arxiv.org/abs/2302.05446)
"""
del env # unused
_, _, scores = self.get_trajectories_scores(trajectories)
loss = (scores - scores.mean()).pow(2).mean()
if torch.isnan(loss):
Expand Down
Loading

0 comments on commit c8712dc

Please sign in to comment.