-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
conditional gfn #188
conditional gfn #188
Changes from 9 commits
6e8dc4d
39fb5ee
2bc2263
99afaf3
e6d25a0
2c72bf9
580c455
a74872f
056d935
96b725c
5cd32a7
877c4a0
279a313
65135c1
b4c418c
738b062
4434e5f
851e03e
5152295
1d64b55
348ee82
9120afe
f59f4de
c5ef7ea
d67dfd5
d56a798
db8844c
e03c03a
6b47e06
988faf0
f2bbce3
eb13a2d
fd3d9dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,8 +73,14 @@ def __init__( | |
self._output_dim_is_checked = False | ||
self.is_backward = is_backward | ||
|
||
def forward(self, states: States) -> TT["batch_shape", "output_dim", float]: | ||
out = self.module(self.preprocessor(states)) | ||
def forward( | ||
self, input: States | torch.Tensor | ||
) -> TT["batch_shape", "output_dim", float]: | ||
if isinstance(input, States): | ||
input = self.preprocessor(input) | ||
|
||
out = self.module(input) | ||
|
||
if not self._output_dim_is_checked: | ||
self.check_output_dim(out) | ||
self._output_dim_is_checked = True | ||
|
@@ -193,6 +199,56 @@ def to_probability_distribution( | |
|
||
return UnsqueezedCategorical(probs=probs) | ||
|
||
# LogEdgeFlows are greedy, as are more P_B. | ||
# LogEdgeFlows are greedy, as are most P_B. | ||
else: | ||
return UnsqueezedCategorical(logits=logits) | ||
|
||
|
||
class ConditionalDiscretePolicyEstimator(DiscretePolicyEstimator): | ||
r"""Container for forward and backward policy estimators for discrete environments. | ||
|
||
$s \mapsto (P_F(s' \mid s, c))_{s' \in Children(s)}$. | ||
|
||
or | ||
|
||
$s \mapsto (P_B(s' \mid s, c))_{s' \in Parents(s)}$. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. might be worth mentioning that this is a s very specific conditioning use-case, where the condition is encoded separately, and embeddings are concatenated. I don't think we can do a generic one, but this should be enough as an example ! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What other conditioning approaches would be worth including? Cross attention? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general I would think the conditioning should be embedded / encoded separately --- or would the conditioning just need to be concatenated to the state before input? I could add support for that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think there is an exhaustive list of ways we can process the condition. What you have is great as an example. I suggest you just add a comment or doc that the user might want to write their own module |
||
|
||
Attributes: | ||
temperature: scalar to divide the logits by before softmax. | ||
sf_bias: scalar to subtract from the exit action logit before dividing by | ||
temperature. | ||
epsilon: with probability epsilon, a random action is chosen. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
state_module: nn.Module, | ||
conditioning_module: nn.Module, | ||
final_module: nn.Module, | ||
n_actions: int, | ||
preprocessor: Preprocessor | None, | ||
is_backward: bool = False, | ||
): | ||
"""Initializes a estimator for P_F for discrete environments. | ||
|
||
Args: | ||
n_actions: Total number of actions in the Discrete Environment. | ||
is_backward: if False, then this is a forward policy, else backward policy. | ||
""" | ||
super().__init__(state_module, n_actions, preprocessor, is_backward) | ||
self.n_actions = n_actions | ||
self.conditioning_module = conditioning_module | ||
self.final_module = final_module | ||
|
||
def forward( | ||
self, states: States, conditioning: torch.tensor | ||
) -> TT["batch_shape", "output_dim", float]: | ||
state_out = self.module(self.preprocessor(states)) | ||
conditioning_out = self.conditioning_module(conditioning) | ||
out = self.final_module(torch.cat((state_out, conditioning_out), -1)) | ||
|
||
if not self._output_dim_is_checked: | ||
self.check_output_dim(out) | ||
self._output_dim_is_checked = True | ||
|
||
return out |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,16 +21,14 @@ class Sampler: | |
estimator: the submitted PolicyEstimator. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
estimator: GFNModule, | ||
) -> None: | ||
def __init__(self, estimator: GFNModule) -> None: | ||
self.estimator = estimator | ||
|
||
def sample_actions( | ||
self, | ||
env: Env, | ||
states: States, | ||
conditioning: torch.Tensor = None, | ||
save_estimator_outputs: bool = False, | ||
save_logprobs: bool = True, | ||
**policy_kwargs: Optional[dict], | ||
|
@@ -45,6 +43,7 @@ def sample_actions( | |
estimator: A GFNModule to pass to the probability distribution calculator. | ||
env: The environment to sample actions from. | ||
states: A batch of states. | ||
conditioning: An optional tensor of conditioning information. | ||
save_estimator_outputs: If True, the estimator outputs will be returned. | ||
save_logprobs: If True, calculates and saves the log probabilities of sampled | ||
actions. | ||
|
@@ -68,7 +67,28 @@ def sample_actions( | |
the sampled actions under the probability distribution of the given | ||
states. | ||
""" | ||
estimator_output = self.estimator(states) | ||
# TODO: Should estimators instead ignore None for the conditioning vector? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wouldn't it be cleaner with fewer if else blocks ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes there's a bit of cruft with all the if-else blocks, but as it stands an estimator can either accept one or two arguments and I think it's good if it fails noisily... what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok ! makes sense. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added these |
||
if conditioning is not None: | ||
try: | ||
estimator_output = self.estimator(states, conditioning) | ||
except TypeError as e: | ||
print( | ||
"conditioning was passed but `estimator` is {}".format( | ||
type(self.estimator) | ||
) | ||
) | ||
raise e | ||
else: | ||
try: | ||
estimator_output = self.estimator(states) | ||
except TypeError as e: | ||
print( | ||
"conditioning was not passed but `estimator` is {}".format( | ||
type(self.estimator) | ||
) | ||
) | ||
raise e | ||
|
||
dist = self.estimator.to_probability_distribution( | ||
states, estimator_output, **policy_kwargs | ||
) | ||
|
@@ -94,6 +114,7 @@ def sample_trajectories( | |
self, | ||
env: Env, | ||
states: Optional[States] = None, | ||
conditioning: Optional[torch.Tensor] = None, | ||
n_trajectories: Optional[int] = None, | ||
save_estimator_outputs: bool = False, | ||
save_logprobs: bool = True, | ||
|
@@ -105,6 +126,7 @@ def sample_trajectories( | |
env: The environment to sample trajectories from. | ||
states: If given, trajectories would start from such states. Otherwise, | ||
trajectories are sampled from $s_o$ and n_trajectories must be provided. | ||
conditioning: An optional tensor of conditioning information. | ||
n_trajectories: If given, a batch of n_trajectories will be sampled all | ||
starting from the environment's s_0. | ||
save_estimator_outputs: If True, the estimator outputs will be returned. This | ||
|
@@ -136,6 +158,9 @@ def sample_trajectories( | |
), "States should be a linear batch of states" | ||
n_trajectories = states.batch_shape[0] | ||
|
||
if conditioning is not None: | ||
assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] | ||
|
||
device = states.tensor.device | ||
|
||
dones = ( | ||
|
@@ -166,9 +191,15 @@ def sample_trajectories( | |
# during sampling. This is useful if, for example, you want to evaluate off | ||
# policy actions later without repeating calculations to obtain the env | ||
# distribution parameters. | ||
if conditioning is not None: | ||
masked_conditioning = conditioning[~dones] | ||
else: | ||
masked_conditioning = None | ||
|
||
valid_actions, actions_log_probs, estimator_outputs = self.sample_actions( | ||
env, | ||
states[~dones], | ||
masked_conditioning, | ||
save_estimator_outputs=True if save_estimator_outputs else False, | ||
save_logprobs=save_logprobs, | ||
**policy_kwargs, | ||
|
@@ -201,6 +232,7 @@ def sample_trajectories( | |
# Increment the step, determine which trajectories are finisihed, and eval | ||
# rewards. | ||
step += 1 | ||
|
||
# new_dones means those trajectories that just finished. Because we | ||
# pad the sink state to every short trajectory, we need to make sure | ||
# to filter out the already done ones. | ||
|
@@ -236,6 +268,7 @@ def sample_trajectories( | |
trajectories = Trajectories( | ||
env=env, | ||
states=trajectories_states, | ||
conditioning=conditioning, | ||
actions=trajectories_actions, | ||
when_is_done=trajectories_dones, | ||
is_backward=self.estimator.is_backward, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pleasantly surprised no change is needed for the LogPartitionVarianceLoss. Right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need the conditioning information here, and I agree it's nice that the code naturally reflected that. Please correct me if I misunderstand this loss.