diff --git a/normalizing_flows/flows.py b/normalizing_flows/flows.py index e43a568..cea0e02 100644 --- a/normalizing_flows/flows.py +++ b/normalizing_flows/flows.py @@ -341,23 +341,30 @@ def log_prob(self, x: torch.Tensor, context: torch.Tensor = None): """ return self.forward_with_log_prob(x, context)[1] - def sample(self, n: int, context: torch.Tensor = None, no_grad: bool = False, return_log_prob: bool = False): + def sample(self, + sample_shape: Union[int, torch.Size, Tuple[int, ...]], + context: torch.Tensor = None, + no_grad: bool = False, + return_log_prob: bool = False): """ Sample from the normalizing flow. If context given, sample n tensors for each context tensor. Otherwise, sample n tensors. - :param n: number of tensors to sample. + :param sample_shape: shape of tensors to sample. :param context: context tensor with shape c. :param no_grad: if True, do not track gradients in the inverse pass. :return: samples with shape (n, *event_shape) if no context given or (n, *c, *event_shape) if context given. """ + if isinstance(sample_shape, int): + sample_shape = (sample_shape,) if context is not None: - sample_shape = torch.Size((n, len(context))) + sample_shape = (*sample_shape, len(context)) z = self.base_sample(sample_shape=sample_shape) - context = context[None].repeat(*[n, *([1] * len(context.shape))]) # Make context shape match z shape + context = context[None].repeat( + *[sample_shape, *([1] * len(context.shape))]) # Make context shape match z shape assert z.shape[:2] == context.shape[:2] else: sample_shape = torch.Size((n,))