Feature: Regular Grid Sampler (#64)
* Add regular grid sampler
* Add tests for regular grid sampler

---------

Co-authored-by: Samuel Burbulla <[email protected]>
JakobEliasWagner and samuelburbulla authored Feb 28, 2024
1 parent 5bd53b3 commit 5bde5cd
Showing 6 changed files with 357 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -5,7 +5,7 @@
- Move all content of `__init__.py` files to sub-modules.
- Add `Trainer` class to replace `operator.fit` method.
- Implement `BelNet`.
-- Add `Sampler`, `BoxSampler`, and `UniformBoxSampler` classes.
+- Add `Sampler`, `BoxSampler`, `UniformBoxSampler`, and `RegularGridSampler` classes.
- Moved `DataLoader` into the `fit` method of the `Trainer`.
Therefore, `Trainer.fit` expects an `OperatorDataset` now.

5 changes: 2 additions & 3 deletions src/continuity/discrete/__init__.py
@@ -5,7 +5,6 @@
"""

from .uniform import UniformBoxSampler
+from .regular_grid import RegularGridSampler

-__all__ = [
-    "UniformBoxSampler",
-]
+__all__ = ["UniformBoxSampler", "RegularGridSampler"]
160 changes: 160 additions & 0 deletions src/continuity/discrete/regular_grid.py
@@ -0,0 +1,160 @@
"""
`continuity.discrete.regular_grid`
Samplers that sample points on a regular grid within n-dimensional boxes.
"""

import torch

from .box_sampler import BoxSampler


class RegularGridSampler(BoxSampler):
"""Regular Grid sampler class.
A class for generating regularly spaced samples within an n-dimensional box
defined by its minimum and maximum corner points. This sampler creates a
regular grid evenly spaced in each dimension, trying to create sub-boxes
that are as close to cubical as possible.
If cubical sub-boxes are not possible (e.g., drawing 8 samples from a unit
square, as 8 is not a perfect square), the `prefer_more_samples` flag determines
whether we over- or under-sample the domain: if the product of the number of
samples in each dimension does not equal the requested number of samples,
the most under-/over-sampled dimension will gain/lose one sample.
Args:
x_min: The minimum corner point of the n-dimensional box, specifying the start of each dimension.
x_max: The maximum corner point of the n-dimensional box, specifying the end of each dimension.
prefer_more_samples: Flag indicating whether to prefer a sample count slightly above (True) or below (False) the
desired total if an exact match isn't possible due to the properties of the regular grid. Defaults to True.
Example:
```
min_corner = torch.tensor([0, 0, 0]) # Define the minimum corner of the box
max_corner = torch.tensor([1, 1, 1]) # Define the maximum corner of the box
sampler = RegularGridSampler(min_corner, max_corner, prefer_more_samples=True)
samples = sampler(100)
print(samples.shape)
```
Output:
```
torch.Size([125, 3])
```
"""

def __init__(
self, x_min: torch.Tensor, x_max: torch.Tensor, prefer_more_samples: bool = True
):
super().__init__(x_min, x_max)
self.prefer_more_samples = prefer_more_samples
if torch.allclose(self.x_delta, torch.zeros(self.x_delta.shape)):
# all samples are drawn from the same point
self.x_aspect = torch.zeros(self.x_delta.shape)
self.x_aspect[0] = 1.0
else:
abs_x_delta = torch.abs(self.x_delta)
self.x_aspect = abs_x_delta / torch.sum(abs_x_delta)

def __call__(self, n_samples: int) -> torch.Tensor:
"""Generate a uniformly spaced grid of samples within an n-dimensional box.
Args:
n_samples: The number of samples to generate.
Returns:
Tensor containing the samples of shape (~n_samples, ndim)
"""
samples_per_dim = self.__calculate_samples_per_dim(n_samples)
samples_per_dim = self.__adjust_samples_to_fit(n_samples, samples_per_dim)

# Generate grid
grids = [
torch.linspace(start, end, n_samples_dim)
for start, end, n_samples_dim in zip(
self.x_min, self.x_max, samples_per_dim
)
]
mesh = torch.meshgrid(*grids, indexing="ij")

return torch.stack(mesh, dim=-1).reshape(-1, self.ndim)

def __calculate_samples_per_dim(self, n_samples: int) -> torch.Tensor:
"""Calculate the (floating point) number of samples in each dimension to
obtain an evenly spaced grid. This method also ensures that there is at
least one sample in each dimension. The implemented method is best
understood by the following example.
Example:
For `x_min = [0, 0, 1]`, `x_max = [1, 2, 1]`, and `n_samples = 200`, this method computes:
```
x_aspect = [1/3, 2/3, 0]
mask = [1, 1, 0]
scale_fac = 2/9
relevant_ndim = 2
samples_per_dim = (200 / (2/9))^(1 / 2) = sqrt(900) = 30
samples_per_dim = x_aspect*30 = [10, 20, 0]
samples_per_dim = max(samples_per_dim, 1) = [10, 20, 1]
```
Output:
```
tensor([10, 20, 1])
```
Args:
n_samples: Desired total number of samples.
Returns:
Approximate number of samples for each dimension as a float vector.
"""
mask = ~torch.isclose(self.x_aspect, torch.zeros(self.x_aspect.shape))
scale_fac = torch.prod(self.x_aspect[mask])
relevant_ndim = torch.sum(mask)
samples_per_dim = torch.pow(n_samples / scale_fac, 1 / relevant_ndim)
samples_per_dim = self.x_aspect * samples_per_dim
samples_per_dim = torch.max(
samples_per_dim, torch.ones(samples_per_dim.shape)
) # ensure every dimension is sampled
return samples_per_dim

def __adjust_samples_to_fit(
self, n_samples: int, samples_per_dim: torch.Tensor
) -> torch.Tensor:
"""Round and adjust the `samples_per_dim` to fit the `n_samples` requirement.
The result of `__calculate_samples_per_dim` is a floating point
representation, which is rounded by this method to the next integer value.
If the product of the rounded samples equals the required number of samples,
these counts are returned unchanged. Otherwise, the most under-sampled dimension or
the most over-sampled dimension will gain or lose one sample, according
to the `prefer_more_samples` flag.
Args:
n_samples: Desired total number of samples.
samples_per_dim: Initial distribution of samples across dimensions.
Returns:
Adjusted number of samples for each dimension as an integer vector.
"""
samples_per_dim_int = torch.round(samples_per_dim).to(dtype=torch.int)
current_total = torch.prod(samples_per_dim_int)

if current_total == n_samples:
# no need to adjust anymore
return samples_per_dim_int

sample_diff = samples_per_dim - samples_per_dim_int

if current_total > n_samples and not self.prefer_more_samples:
# decrease samples in most over-sampled dimension
dim = torch.argmin(sample_diff)
samples_per_dim_int[dim] -= 1

elif current_total < n_samples and self.prefer_more_samples:
# increase samples in most under-sampled dimension
dim = torch.argmax(sample_diff)
samples_per_dim_int[dim] += 1

return samples_per_dim_int
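The following sketch illustrates the over-/under-sampling trade-off controlled by `prefer_more_samples`, reusing the unit-square example from the class docstring; the expected counts are an illustration derived from the rounding logic above, not committed test output:

```python
import torch

from continuity.discrete import RegularGridSampler

x_min, x_max = torch.zeros(2), torch.ones(2)

# 8 samples cannot form a square grid on a unit square (8 is not a perfect square).
over = RegularGridSampler(x_min, x_max, prefer_more_samples=True)
under = RegularGridSampler(x_min, x_max, prefer_more_samples=False)

# The ideal per-dimension count is sqrt(8) ~ 2.83 samples, which rounds to a 3 x 3 grid.
print(over(8).shape)   # torch.Size([9, 2])  -> slightly more than requested

# With prefer_more_samples=False, the most over-sampled dimension loses one
# sample, yielding a 2 x 3 grid.
print(under(8).shape)  # torch.Size([6, 2])  -> slightly fewer than requested
```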
59 changes: 59 additions & 0 deletions tests/discrete/test_box.py
@@ -0,0 +1,59 @@
import pytest
import torch

from continuity.discrete.box_sampler import BoxSampler


@pytest.fixture(scope="module")
def generic_box_sampler():
class GenericBoxSampler(BoxSampler):
def __call__(self, n_samples: int) -> torch.Tensor:
return torch.zeros((n_samples, self.ndim))

return GenericBoxSampler


@pytest.fixture(scope="module")
def unit_box_sampler(generic_box_sampler):
return generic_box_sampler(
torch.zeros(
5,
),
torch.ones(
5,
),
)


@pytest.fixture(scope="module")
def random_box_sampler(generic_box_sampler):
return generic_box_sampler(
-2.0
* torch.rand(
5,
),
2
* torch.rand(
5,
),
)


@pytest.fixture(scope="module")
def sampler_list(unit_box_sampler, random_box_sampler):
return [unit_box_sampler, random_box_sampler]


def test_can_initialize(sampler_list, generic_box_sampler):
for sampler in sampler_list:
assert isinstance(sampler, generic_box_sampler)


def test_delta_correct(sampler_list):
for sampler in sampler_list:
assert torch.allclose(sampler.x_max - sampler.x_min, sampler.x_delta)


def test_dim_correct(sampler_list):
for sampler in sampler_list:
assert sampler.ndim == 5
135 changes: 135 additions & 0 deletions tests/discrete/test_regular_grid.py
@@ -0,0 +1,135 @@
import pytest
import torch
from typing import List

from continuity.discrete import RegularGridSampler


@pytest.fixture(scope="module")
def regular_grid_sampler() -> RegularGridSampler:
return RegularGridSampler(x_min=torch.zeros((7,)), x_max=torch.ones((7,)))


@pytest.fixture(scope="module")
def regular_grid_sampler_negative() -> RegularGridSampler:
return RegularGridSampler(x_min=torch.zeros(2), x_max=torch.tensor(([1.0, -1.0])))


@pytest.fixture(scope="module")
def regular_grid_random_sampler() -> RegularGridSampler:
x_min = -2.0 * torch.rand((5,))
x_max = 2.0 * torch.rand((5,))
return RegularGridSampler(x_min=x_min, x_max=x_max)


@pytest.fixture(scope="module")
def regular_grid_sampler_under() -> RegularGridSampler:
return RegularGridSampler(
x_min=torch.zeros((7,)), x_max=torch.ones((7,)), prefer_more_samples=False
)


@pytest.fixture(scope="module")
def regular_grid_random_sampler_under() -> RegularGridSampler:
x_min = -2.0 * torch.rand((5,))
x_max = 2.0 * torch.rand((5,))
return RegularGridSampler(x_min=x_min, x_max=x_max, prefer_more_samples=False)


@pytest.fixture(scope="module")
def sampler_list_over(
regular_grid_sampler, regular_grid_sampler_negative, regular_grid_random_sampler
) -> List[RegularGridSampler]:
return [
regular_grid_sampler,
regular_grid_sampler_negative,
regular_grid_random_sampler,
]


@pytest.fixture(scope="module")
def sampler_list_under(
regular_grid_sampler_under, regular_grid_random_sampler_under
) -> List[RegularGridSampler]:
return [regular_grid_sampler_under, regular_grid_random_sampler_under]


@pytest.fixture(scope="module")
def sampler_list(sampler_list_over, sampler_list_under) -> List[RegularGridSampler]:
return sampler_list_over + sampler_list_under


def test_can_initialize(sampler_list):
for sampler in sampler_list:
assert isinstance(sampler, RegularGridSampler)


def test_sample_within_bounds(sampler_list):
n_samples = 2**12
for sampler in sampler_list:
samples = sampler(n_samples)
samples_min, _ = samples.min(dim=0)
samples_max, _ = samples.max(dim=0)
box_min = torch.min(sampler.x_min, sampler.x_max)
box_max = torch.max(sampler.x_min, sampler.x_max)
assert torch.greater_equal(samples_min, box_min).all()
assert torch.less_equal(samples_max, box_max).all()


def test_perfect_samples(regular_grid_sampler, regular_grid_sampler_under):
for sampler in [regular_grid_sampler, regular_grid_sampler_under]:
n_samples = 10**sampler.ndim
samples = sampler(n_samples)

assert samples.size(0) == n_samples


def test_samples_under(sampler_list_under):
for sampler in sampler_list_under:
n_samples = 10**sampler.ndim + 1
samples = sampler(n_samples)

assert samples.size(0) < n_samples


def test_samples_over(sampler_list_over):
for sampler in sampler_list_over:
n_samples = 10**sampler.ndim + 1
samples = sampler(n_samples)

assert samples.size(0) > n_samples


def test_dist_zero_single():
"""delta x in a single dimension is zero."""
n_samples = 121
sampler = RegularGridSampler(torch.zeros(3), torch.tensor([1.0, 1.0, 0.0]))
samples = sampler(n_samples)

assert samples.size(0) == n_samples


def test_dist_zero_double():
"""delta x in multiple dimensions is zero."""
n_samples = 100
sampler = RegularGridSampler(torch.zeros(3), torch.tensor([0.0, 1.0, 0.0]))
samples = sampler(n_samples)

assert samples.size(0) == n_samples


def test_dist_zero_all():
"""samples are drawn from a single point"""
n_samples = 100
sampler = RegularGridSampler(torch.zeros(3), torch.zeros(3))
samples = sampler(n_samples)

assert samples.size(0) == n_samples


def test_dist_neg():
n_samples = 100
sampler = RegularGridSampler(torch.zeros(3), torch.tensor([0.0, -1.0, 0.0]))
samples = sampler(n_samples)

assert samples.size(0) == n_samples
5 changes: 0 additions & 5 deletions tests/discrete/test_uniform.py
@@ -26,11 +26,6 @@ def test_can_initialize(unit_box_sampler):
assert isinstance(unit_box_sampler, UniformBoxSampler)


-def test_delta_correct(sampler_list):
-    for sampler in sampler_list:
-        assert torch.allclose(sampler.x_max - sampler.x_min, sampler.x_delta)


def test_sample_within_bounds(sampler_list):
n_samples = 2**12
for sampler in sampler_list:
