Skip to content

Commit

Permalink
Fix event shape typehint
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Oct 13, 2023
1 parent 1293606 commit 1241200
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion normalizing_flows/bijections/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def invert(bijection: Bijection) -> Bijection:

class BijectiveComposition(Bijection):
def __init__(self,
event_shape: torch.Size, layers: List[Bijection],
event_shape: Union[torch.Size, Tuple[int, ...]],
layers: List[Bijection],
context_shape: Union[torch.Size, Tuple[int, ...]] = None):
super().__init__(event_shape=event_shape, context_shape=context_shape)
self.layers = nn.ModuleList(layers)
Expand Down

0 comments on commit 1241200

Please sign in to comment.