From 7575f24bc1815697c4950c967aae2542bb4d4c93 Mon Sep 17 00:00:00 2001 From: Jakob Date: Tue, 12 Mar 2024 14:37:43 +0100 Subject: [PATCH] fix input type --- src/continuity/discrete/regular_grid.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/continuity/discrete/regular_grid.py b/src/continuity/discrete/regular_grid.py index 41377baf..01d534cc 100644 --- a/src/continuity/discrete/regular_grid.py +++ b/src/continuity/discrete/regular_grid.py @@ -5,6 +5,7 @@ """ import torch +from typing import Union from .box_sampler import BoxSampler @@ -45,7 +46,10 @@ class RegularGridSampler(BoxSampler): """ def __init__( - self, x_min: torch.Tensor, x_max: torch.Tensor, prefer_more_samples: bool = True + self, + x_min: Union[torch.Tensor, list], + x_max: Union[torch.Tensor, list], + prefer_more_samples: bool = True, ): super().__init__(x_min, x_max) self.prefer_more_samples = prefer_more_samples