From 38d29e1b37400a67b013ded56b14370a24a2edda Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Sat, 19 Oct 2024 16:47:55 +0200 Subject: [PATCH] Fix typehint syntaxt --- .../bijections/finite/autoregressive/conditioning/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py index e34d947..008a2cc 100644 --- a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py @@ -273,7 +273,7 @@ def __init__(self, *args, **kwargs): class FeedForward(TensorConditionerTransform): def __init__(self, input_event_shape: Union[torch.Size, Tuple[int, ...]], - parameter_shape: torch.Union[torch.Size, Tuple[int, ...]], + parameter_shape: Union[torch.Size, Tuple[int, ...]], context_shape: Union[torch.Size, Tuple[int, ...]] = None, n_hidden: int = None, n_layers: int = 2,