diff --git a/normalizing_flows/bijections/finite/linear.py b/normalizing_flows/bijections/finite/linear.py index 8b98c67..a009ba2 100644 --- a/normalizing_flows/bijections/finite/linear.py +++ b/normalizing_flows/bijections/finite/linear.py @@ -14,6 +14,19 @@ from normalizing_flows.utils import get_batch_shape, flatten_event, unflatten_event, flatten_batch, unflatten_batch +class Identity(Bijection): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(x, self.event_shape) + return x, torch.zeros(size=batch_shape, device=x.device) + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_shape = get_batch_shape(z, self.event_shape) + return z, torch.zeros(size=batch_shape, device=z.device) + + class LinearBijection(Bijection): def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], matrix: InvertibleMatrix): super().__init__(event_shape)