diff --git a/CHANGELOG.md b/CHANGELOG.md index 27e92096..3db26f5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ - Move all content of `__init__.py` files to sub-modules. - Add `Trainer` class to replace `operator.fit` method. - Implement `BelNet`. -- Add `Sampler`, `BoxSampler`, and `UniformBoxSampler` classes. +- Add `Sampler`, `BoxSampler`, `UniformBoxSampler`, and `RegularGridSampler` classes. - Moved `DataLoader` into the `fit` method of the `Trainer`. Therefore, `Trainer.fit` expects an `OperatorDataset` now. diff --git a/src/continuity/discrete/__init__.py b/src/continuity/discrete/__init__.py index bea677c3..7161e7cb 100644 --- a/src/continuity/discrete/__init__.py +++ b/src/continuity/discrete/__init__.py @@ -5,7 +5,6 @@ """ from .uniform import UniformBoxSampler +from .regular_grid import RegularGridSampler -__all__ = [ - "UniformBoxSampler", -] +__all__ = ["UniformBoxSampler", "RegularGridSampler"] diff --git a/src/continuity/discrete/regular_grid.py b/src/continuity/discrete/regular_grid.py new file mode 100644 index 00000000..41377baf --- /dev/null +++ b/src/continuity/discrete/regular_grid.py @@ -0,0 +1,160 @@ +""" +`continuity.discrete.regular_grid` + +Samplers sampling on a regular grid from n-dimensional boxes. +""" + +import torch + +from .box_sampler import BoxSampler + + +class RegularGridSampler(BoxSampler): + """Regular Grid sampler class. + + A class for generating regularly spaced samples within an n-dimensional box + defined by its minimum and maximum corner points. This sampler creates a + regular grid evenly spaced in each dimension, trying to create sub-boxes + that are as close to cubical as possible. + + If cubical sub-boxes are not possible (e.g., drawing 8 samples from a unit + square as 8 is not a power of 2), the `prefer_more_samples` flag determines + of we either over or undersample the domain: If the product of the number of + samples in each dimension does not equal the requested number of samples, + the most under/over-sampled dimension will gain/lose one sample. + + Args: + x_min: The minimum corner point of the n-dimensional box, specifying the start of each dimension. + x_max: The maximum corner point of the n-dimensional box, specifying the end of each dimension. + prefer_more_samples: Flag indicating whether to prefer a sample count slightly above (True) or below (False) the + desired total if an exact match isn't possible due to the properties of the regular grid. Defaults to True. + + + Example: + ``` + min_corner = torch.tensor([0, 0, 0]) # Define the minimum corner of the box + max_corner = torch.tensor([1, 1, 1]) # Define the maximum corner of the box + sampler = RegularGridSampler(min_corner, max_corner, prefer_more_samples=True) + samples = sampler(100) + print(samples.shape) + ``` + Output: + ``` + torch.Size([125, 3]) + ``` + """ + + def __init__( + self, x_min: torch.Tensor, x_max: torch.Tensor, prefer_more_samples: bool = True + ): + super().__init__(x_min, x_max) + self.prefer_more_samples = prefer_more_samples + if torch.allclose(self.x_delta, torch.zeros(self.x_delta.shape)): + # all samples are drawn from the same point + self.x_aspect = torch.zeros(self.x_delta.shape) + self.x_aspect[0] = 1.0 + else: + abs_x_delta = torch.abs(self.x_delta) + self.x_aspect = abs_x_delta / torch.sum(abs_x_delta) + + def __call__(self, n_samples: int) -> torch.Tensor: + """Generate a uniformly spaced grid of samples within an n-dimensional box. + + Args: + n_samples: The number of samples to generate. + + Returns: + Tensor containing the samples of shape (~n_samples, ndim) + """ + samples_per_dim = self.__calculate_samples_per_dim(n_samples) + samples_per_dim = self.__adjust_samples_to_fit(n_samples, samples_per_dim) + + # Generate grid + grids = [ + torch.linspace(start, end, n_samples_dim) + for start, end, n_samples_dim in zip( + self.x_min, self.x_max, samples_per_dim + ) + ] + mesh = torch.meshgrid(*grids, indexing="ij") + + return torch.stack(mesh, dim=-1).reshape(-1, self.ndim) + + def __calculate_samples_per_dim(self, n_samples: int) -> torch.Tensor: + """Calculate the (floating point) number of samples in each dimension to + obtain an evenly spaced grid. This method also ensures that there is at + least one sample in each dimension. The implemented method is best + understood by the following example. + + Example: + For `x_min = [0, 0, 1]`, `xmax = [1, 2, 1]` and `n_samples = 200`, this method computes: + + ``` + x_aspect = [1/3, 2/3, 0] + mask = [1, 1, 0] + scale_fac = 2/9 + relevant_ndim = 2 + samples_per_dim = (200 / (2/9))^(1 / 2) = sqrt(900) = 30 + samples_per_dim = x_aspect*30 = [10, 20, 0] + samples_per_dim = max(samples_per_dim, 1) = [10, 20, 1] + ``` + Output: + ``` + tensor([10, 20, 1]) + ``` + + Args: + n_samples: Desired total number of samples. + + Returns: + Approximate number of samples for each dimension as a float vector. + """ + mask = ~torch.isclose(self.x_aspect, torch.zeros(self.x_aspect.shape)) + scale_fac = torch.prod(self.x_aspect[mask]) + relevant_ndim = torch.sum(mask) + samples_per_dim = torch.pow(n_samples / scale_fac, 1 / relevant_ndim) + samples_per_dim = self.x_aspect * samples_per_dim + samples_per_dim = torch.max( + samples_per_dim, torch.ones(samples_per_dim.shape) + ) # ensure every dimension is sampled + return samples_per_dim + + def __adjust_samples_to_fit( + self, n_samples: int, samples_per_dim: torch.Tensor + ) -> torch.Tensor: + """Round and adjust the `samples_per_dim` to fit the `n_samples` requirement. + + The result of `__calculate_samples_per_dim` is a floating point + representation, which is rounded by this method to the next integer value. + If the product of the rounded samples equals the required number of samples, + we return this number. Otherwise, the most under-sampled dimension or + the most over-sampled dimension will gain or lose one sample, according + to the `prefer_more_samples` flag. + + Args: + n_samples: Desired total number of samples. + samples_per_dim: Initial distribution of samples across dimensions. + + Returns: + Adjusted number of samples for each dimension as an integer vector. + """ + samples_per_dim_int = torch.round(samples_per_dim).to(dtype=torch.int) + current_total = torch.prod(samples_per_dim_int) + + if current_total == n_samples: + # no need to adjust anymore + return samples_per_dim_int + + sample_diff = samples_per_dim - samples_per_dim_int + + if current_total > n_samples and not self.prefer_more_samples: + # decrease samples in most over-sampled dimension + dim = torch.argmin(sample_diff) + samples_per_dim_int[dim] -= 1 + + elif current_total < n_samples and self.prefer_more_samples: + # increase samples in most under-sampled dimension + dim = torch.argmax(sample_diff) + samples_per_dim_int[dim] += 1 + + return samples_per_dim_int diff --git a/tests/discrete/test_box.py b/tests/discrete/test_box.py new file mode 100644 index 00000000..0175cfba --- /dev/null +++ b/tests/discrete/test_box.py @@ -0,0 +1,59 @@ +import pytest +import torch + +from continuity.discrete.box_sampler import BoxSampler + + +@pytest.fixture(scope="module") +def generic_box_sampler(): + class GenericBoxSampler(BoxSampler): + def __call__(self, n_samples: int) -> torch.Tensor: + return torch.zeros((n_samples, self.ndim)) + + return GenericBoxSampler + + +@pytest.fixture(scope="module") +def unit_box_sampler(generic_box_sampler): + return generic_box_sampler( + torch.zeros( + 5, + ), + torch.ones( + 5, + ), + ) + + +@pytest.fixture(scope="module") +def random_box_sampler(generic_box_sampler): + return generic_box_sampler( + -2.0 + * torch.rand( + 5, + ), + 2 + * torch.rand( + 5, + ), + ) + + +@pytest.fixture(scope="module") +def sampler_list(unit_box_sampler, random_box_sampler): + return [unit_box_sampler, random_box_sampler] + + +def test_can_initialize(sampler_list, generic_box_sampler): + for sampler in sampler_list: + assert isinstance(sampler, generic_box_sampler) + + +def test_delta_correct(sampler_list): + for sampler in sampler_list: + assert torch.allclose(sampler.x_max - sampler.x_min, sampler.x_delta) + + +def test_dim_correct(sampler_list): + for sampler in sampler_list: + assert sampler.ndim == 5 diff --git a/tests/discrete/test_regular_grid.py b/tests/discrete/test_regular_grid.py new file mode 100644 index 00000000..487c3410 --- /dev/null +++ b/tests/discrete/test_regular_grid.py @@ -0,0 +1,135 @@ +import pytest +import torch +from typing import List + +from continuity.discrete import RegularGridSampler + + +@pytest.fixture(scope="module") +def regular_grid_sampler() -> RegularGridSampler: + return RegularGridSampler(x_min=torch.zeros((7,)), x_max=torch.ones((7,))) + + +@pytest.fixture(scope="module") +def regular_grid_sampler_negative() -> RegularGridSampler: + return RegularGridSampler(x_min=torch.zeros(2), x_max=torch.tensor(([1.0, -1.0]))) + + +@pytest.fixture(scope="module") +def regular_grid_random_sampler() -> RegularGridSampler: + x_min = -2.0 * torch.rand((5,)) + x_max = 2.0 * torch.rand((5,)) + return RegularGridSampler(x_min=x_min, x_max=x_max) + + +@pytest.fixture(scope="module") +def regular_grid_sampler_under() -> RegularGridSampler: + return RegularGridSampler( + x_min=torch.zeros((7,)), x_max=torch.ones((7,)), prefer_more_samples=False + ) + + +@pytest.fixture(scope="module") +def regular_grid_random_sampler_under() -> RegularGridSampler: + x_min = -2.0 * torch.rand((5,)) + x_max = 2.0 * torch.rand((5,)) + return RegularGridSampler(x_min=x_min, x_max=x_max, prefer_more_samples=False) + + +@pytest.fixture(scope="module") +def sampler_list_over( + regular_grid_sampler, regular_grid_sampler_negative, regular_grid_random_sampler +) -> List[RegularGridSampler]: + return [ + regular_grid_sampler, + regular_grid_sampler_negative, + regular_grid_random_sampler, + ] + + +@pytest.fixture(scope="module") +def sampler_list_under( + regular_grid_sampler_under, regular_grid_random_sampler_under +) -> List[RegularGridSampler]: + return [regular_grid_sampler_under, regular_grid_random_sampler_under] + + +@pytest.fixture(scope="module") +def sampler_list(sampler_list_over, sampler_list_under) -> List[RegularGridSampler]: + return sampler_list_over + sampler_list_under + + +def test_can_initialize(sampler_list): + for sampler in sampler_list: + assert isinstance(sampler, RegularGridSampler) + + +def test_sample_within_bounds(sampler_list): + n_samples = 2**12 + for sampler in sampler_list: + samples = sampler(n_samples) + samples_min, _ = samples.min(dim=0) + samples_max, _ = samples.max(dim=0) + box_min = torch.min(sampler.x_min, sampler.x_max) + box_max = torch.max(sampler.x_min, sampler.x_max) + assert torch.greater_equal(samples_min, box_min).all() + assert torch.less_equal(samples_max, box_max).all() + + +def test_perfect_samples(regular_grid_sampler, regular_grid_sampler_under): + for sampler in [regular_grid_sampler, regular_grid_sampler_under]: + n_samples = 10**sampler.ndim + samples = sampler(n_samples) + + assert samples.size(0) == n_samples + + +def test_samples_under(sampler_list_under): + for sampler in sampler_list_under: + n_samples = 10**sampler.ndim + 1 + samples = sampler(n_samples) + + assert samples.size(0) < n_samples + + +def test_samples_over(sampler_list_over): + for sampler in sampler_list_over: + n_samples = 10**sampler.ndim + 1 + samples = sampler(n_samples) + + assert samples.size(0) > n_samples + + +def test_dist_zero_single(): + """delta x in a single dimension is zero.""" + n_samples = 121 + sampler = RegularGridSampler(torch.zeros(3), torch.tensor([1.0, 1.0, 0.0])) + samples = sampler(n_samples) + + assert samples.size(0) == n_samples + + +def test_dist_zero_double(): + """delta x in multiple dimensions is zero.""" + n_samples = 100 + sampler = RegularGridSampler(torch.zeros(3), torch.tensor([0.0, 1.0, 0.0])) + samples = sampler(n_samples) + + assert samples.size(0) == n_samples + + +def test_dist_zero_all(): + """samples are drawn from a single point""" + n_samples = 100 + sampler = RegularGridSampler(torch.zeros(3), torch.zeros(3)) + samples = sampler(n_samples) + + assert samples.size(0) == n_samples + + +def test_dist_neg(): + n_samples = 100 + sampler = RegularGridSampler(torch.zeros(3), torch.tensor([0.0, -1.0, 0.0])) + samples = sampler(n_samples) + + assert samples.size(0) == n_samples diff --git a/tests/discrete/test_uniform.py b/tests/discrete/test_uniform.py index 04eeea9c..020cc9b2 100644 --- a/tests/discrete/test_uniform.py +++ b/tests/discrete/test_uniform.py @@ -26,11 +26,6 @@ def test_can_initialize(unit_box_sampler): assert isinstance(unit_box_sampler, UniformBoxSampler) -def test_delta_correct(sampler_list): - for sampler in sampler_list: - assert torch.allclose(sampler.x_max - sampler.x_min, sampler.x_delta) - - def test_sample_within_bounds(sampler_list): n_samples = 2**12 for sampler in sampler_list: