-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add uniform grid sampler * Add tests for regular grid sampler --------- Co-authored-by: Samuel Burbulla <[email protected]>
- Loading branch information
1 parent
5bd53b3
commit 5bde5cd
Showing
6 changed files
with
357 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
""" | ||
`continuity.discrete.regular_grid` | ||
Samplers sampling on a regular grid from n-dimensional boxes. | ||
""" | ||
|
||
import torch | ||
|
||
from .box_sampler import BoxSampler | ||
|
||
|
||
class RegularGridSampler(BoxSampler): | ||
"""Regular Grid sampler class. | ||
A class for generating regularly spaced samples within an n-dimensional box | ||
defined by its minimum and maximum corner points. This sampler creates a | ||
regular grid evenly spaced in each dimension, trying to create sub-boxes | ||
that are as close to cubical as possible. | ||
If cubical sub-boxes are not possible (e.g., drawing 8 samples from a unit | ||
square as 8 is not a power of 2), the `prefer_more_samples` flag determines | ||
of we either over or undersample the domain: If the product of the number of | ||
samples in each dimension does not equal the requested number of samples, | ||
the most under/over-sampled dimension will gain/lose one sample. | ||
Args: | ||
x_min: The minimum corner point of the n-dimensional box, specifying the start of each dimension. | ||
x_max: The maximum corner point of the n-dimensional box, specifying the end of each dimension. | ||
prefer_more_samples: Flag indicating whether to prefer a sample count slightly above (True) or below (False) the | ||
desired total if an exact match isn't possible due to the properties of the regular grid. Defaults to True. | ||
Example: | ||
``` | ||
min_corner = torch.tensor([0, 0, 0]) # Define the minimum corner of the box | ||
max_corner = torch.tensor([1, 1, 1]) # Define the maximum corner of the box | ||
sampler = RegularGridSampler(min_corner, max_corner, prefer_more_samples=True) | ||
samples = sampler(100) | ||
print(samples.shape) | ||
``` | ||
Output: | ||
``` | ||
torch.Size([125, 3]) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, x_min: torch.Tensor, x_max: torch.Tensor, prefer_more_samples: bool = True | ||
): | ||
super().__init__(x_min, x_max) | ||
self.prefer_more_samples = prefer_more_samples | ||
if torch.allclose(self.x_delta, torch.zeros(self.x_delta.shape)): | ||
# all samples are drawn from the same point | ||
self.x_aspect = torch.zeros(self.x_delta.shape) | ||
self.x_aspect[0] = 1.0 | ||
else: | ||
abs_x_delta = torch.abs(self.x_delta) | ||
self.x_aspect = abs_x_delta / torch.sum(abs_x_delta) | ||
|
||
def __call__(self, n_samples: int) -> torch.Tensor: | ||
"""Generate a uniformly spaced grid of samples within an n-dimensional box. | ||
Args: | ||
n_samples: The number of samples to generate. | ||
Returns: | ||
Tensor containing the samples of shape (~n_samples, ndim) | ||
""" | ||
samples_per_dim = self.__calculate_samples_per_dim(n_samples) | ||
samples_per_dim = self.__adjust_samples_to_fit(n_samples, samples_per_dim) | ||
|
||
# Generate grid | ||
grids = [ | ||
torch.linspace(start, end, n_samples_dim) | ||
for start, end, n_samples_dim in zip( | ||
self.x_min, self.x_max, samples_per_dim | ||
) | ||
] | ||
mesh = torch.meshgrid(*grids, indexing="ij") | ||
|
||
return torch.stack(mesh, dim=-1).reshape(-1, self.ndim) | ||
|
||
def __calculate_samples_per_dim(self, n_samples: int) -> torch.Tensor: | ||
"""Calculate the (floating point) number of samples in each dimension to | ||
obtain an evenly spaced grid. This method also ensures that there is at | ||
least one sample in each dimension. The implemented method is best | ||
understood by the following example. | ||
Example: | ||
For `x_min = [0, 0, 1]`, `xmax = [1, 2, 1]` and `n_samples = 200`, this method computes: | ||
``` | ||
x_aspect = [1/3, 2/3, 0] | ||
mask = [1, 1, 0] | ||
scale_fac = 2/9 | ||
relevant_ndim = 2 | ||
samples_per_dim = (200 / (2/9))^(1 / 2) = sqrt(900) = 30 | ||
samples_per_dim = x_aspect*30 = [10, 20, 0] | ||
samples_per_dim = max(samples_per_dim, 1) = [10, 20, 1] | ||
``` | ||
Output: | ||
``` | ||
tensor([10, 20, 1]) | ||
``` | ||
Args: | ||
n_samples: Desired total number of samples. | ||
Returns: | ||
Approximate number of samples for each dimension as a float vector. | ||
""" | ||
mask = ~torch.isclose(self.x_aspect, torch.zeros(self.x_aspect.shape)) | ||
scale_fac = torch.prod(self.x_aspect[mask]) | ||
relevant_ndim = torch.sum(mask) | ||
samples_per_dim = torch.pow(n_samples / scale_fac, 1 / relevant_ndim) | ||
samples_per_dim = self.x_aspect * samples_per_dim | ||
samples_per_dim = torch.max( | ||
samples_per_dim, torch.ones(samples_per_dim.shape) | ||
) # ensure every dimension is sampled | ||
return samples_per_dim | ||
|
||
def __adjust_samples_to_fit( | ||
self, n_samples: int, samples_per_dim: torch.Tensor | ||
) -> torch.Tensor: | ||
"""Round and adjust the `samples_per_dim` to fit the `n_samples` requirement. | ||
The result of `__calculate_samples_per_dim` is a floating point | ||
representation, which is rounded by this method to the next integer value. | ||
If the product of the rounded samples equals the required number of samples, | ||
we return this number. Otherwise, the most under-sampled dimension or | ||
the most over-sampled dimension will gain or lose one sample, according | ||
to the `prefer_more_samples` flag. | ||
Args: | ||
n_samples: Desired total number of samples. | ||
samples_per_dim: Initial distribution of samples across dimensions. | ||
Returns: | ||
Adjusted number of samples for each dimension as an integer vector. | ||
""" | ||
samples_per_dim_int = torch.round(samples_per_dim).to(dtype=torch.int) | ||
current_total = torch.prod(samples_per_dim_int) | ||
|
||
if current_total == n_samples: | ||
# no need to adjust anymore | ||
return samples_per_dim_int | ||
|
||
sample_diff = samples_per_dim - samples_per_dim_int | ||
|
||
if current_total > n_samples and not self.prefer_more_samples: | ||
# decrease samples in most over-sampled dimension | ||
dim = torch.argmin(sample_diff) | ||
samples_per_dim_int[dim] -= 1 | ||
|
||
elif current_total < n_samples and self.prefer_more_samples: | ||
# increase samples in most under-sampled dimension | ||
dim = torch.argmax(sample_diff) | ||
samples_per_dim_int[dim] += 1 | ||
|
||
return samples_per_dim_int |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pytest | ||
import torch | ||
|
||
from continuity.discrete.box_sampler import BoxSampler | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def generic_box_sampler(): | ||
class GenericBoxSampler(BoxSampler): | ||
def __call__(self, n_samples: int) -> torch.Tensor: | ||
return torch.zeros((n_samples, self.ndim)) | ||
|
||
return GenericBoxSampler | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def unit_box_sampler(generic_box_sampler): | ||
return generic_box_sampler( | ||
torch.zeros( | ||
5, | ||
), | ||
torch.ones( | ||
5, | ||
), | ||
) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def random_box_sampler(generic_box_sampler): | ||
return generic_box_sampler( | ||
-2.0 | ||
* torch.rand( | ||
5, | ||
), | ||
2 | ||
* torch.rand( | ||
5, | ||
), | ||
) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def sampler_list(unit_box_sampler, random_box_sampler): | ||
return [unit_box_sampler, random_box_sampler] | ||
|
||
|
||
def test_can_initialize(sampler_list, generic_box_sampler): | ||
for sampler in sampler_list: | ||
assert isinstance(sampler, generic_box_sampler) | ||
|
||
|
||
def test_delta_correct(sampler_list): | ||
for sampler in sampler_list: | ||
assert torch.allclose(sampler.x_max - sampler.x_min, sampler.x_delta) | ||
|
||
|
||
def test_dim_correct(sampler_list): | ||
for sampler in sampler_list: | ||
assert sampler.ndim == 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import pytest | ||
import torch | ||
from typing import List | ||
|
||
from continuity.discrete import RegularGridSampler | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def regular_grid_sampler() -> RegularGridSampler: | ||
return RegularGridSampler(x_min=torch.zeros((7,)), x_max=torch.ones((7,))) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def regular_grid_sampler_negative() -> RegularGridSampler: | ||
return RegularGridSampler(x_min=torch.zeros(2), x_max=torch.tensor(([1.0, -1.0]))) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def regular_grid_random_sampler() -> RegularGridSampler: | ||
x_min = -2.0 * torch.rand((5,)) | ||
x_max = 2.0 * torch.rand((5,)) | ||
return RegularGridSampler(x_min=x_min, x_max=x_max) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def regular_grid_sampler_under() -> RegularGridSampler: | ||
return RegularGridSampler( | ||
x_min=torch.zeros((7,)), x_max=torch.ones((7,)), prefer_more_samples=False | ||
) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def regular_grid_random_sampler_under() -> RegularGridSampler: | ||
x_min = -2.0 * torch.rand((5,)) | ||
x_max = 2.0 * torch.rand((5,)) | ||
return RegularGridSampler(x_min=x_min, x_max=x_max, prefer_more_samples=False) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def sampler_list_over( | ||
regular_grid_sampler, regular_grid_sampler_negative, regular_grid_random_sampler | ||
) -> List[RegularGridSampler]: | ||
return [ | ||
regular_grid_sampler, | ||
regular_grid_sampler_negative, | ||
regular_grid_random_sampler, | ||
] | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def sampler_list_under( | ||
regular_grid_sampler_under, regular_grid_random_sampler_under | ||
) -> List[RegularGridSampler]: | ||
return [regular_grid_sampler_under, regular_grid_random_sampler_under] | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def sampler_list(sampler_list_over, sampler_list_under) -> List[RegularGridSampler]: | ||
return sampler_list_over + sampler_list_under | ||
|
||
|
||
def test_can_initialize(sampler_list): | ||
for sampler in sampler_list: | ||
assert isinstance(sampler, RegularGridSampler) | ||
|
||
|
||
def test_sample_within_bounds(sampler_list): | ||
n_samples = 2**12 | ||
for sampler in sampler_list: | ||
samples = sampler(n_samples) | ||
samples_min, _ = samples.min(dim=0) | ||
samples_max, _ = samples.max(dim=0) | ||
box_min = torch.min(sampler.x_min, sampler.x_max) | ||
box_max = torch.max(sampler.x_min, sampler.x_max) | ||
assert torch.greater_equal(samples_min, box_min).all() | ||
assert torch.less_equal(samples_max, box_max).all() | ||
|
||
|
||
def test_perfect_samples(regular_grid_sampler, regular_grid_sampler_under): | ||
for sampler in [regular_grid_sampler, regular_grid_sampler_under]: | ||
n_samples = 10**sampler.ndim | ||
samples = sampler(n_samples) | ||
|
||
assert samples.size(0) == n_samples | ||
|
||
|
||
def test_samples_under(sampler_list_under): | ||
for sampler in sampler_list_under: | ||
n_samples = 10**sampler.ndim + 1 | ||
samples = sampler(n_samples) | ||
|
||
assert samples.size(0) < n_samples | ||
|
||
|
||
def test_samples_over(sampler_list_over): | ||
for sampler in sampler_list_over: | ||
n_samples = 10**sampler.ndim + 1 | ||
samples = sampler(n_samples) | ||
|
||
assert samples.size(0) > n_samples | ||
|
||
|
||
def test_dist_zero_single(): | ||
"""delta x in a single dimension is zero.""" | ||
n_samples = 121 | ||
sampler = RegularGridSampler(torch.zeros(3), torch.tensor([1.0, 1.0, 0.0])) | ||
samples = sampler(n_samples) | ||
|
||
assert samples.size(0) == n_samples | ||
|
||
|
||
def test_dist_zero_double(): | ||
"""delta x in multiple dimensions is zero.""" | ||
n_samples = 100 | ||
sampler = RegularGridSampler(torch.zeros(3), torch.tensor([0.0, 1.0, 0.0])) | ||
samples = sampler(n_samples) | ||
|
||
assert samples.size(0) == n_samples | ||
|
||
|
||
def test_dist_zero_all(): | ||
"""samples are drawn from a single point""" | ||
n_samples = 100 | ||
sampler = RegularGridSampler(torch.zeros(3), torch.zeros(3)) | ||
samples = sampler(n_samples) | ||
|
||
assert samples.size(0) == n_samples | ||
|
||
|
||
def test_dist_neg(): | ||
n_samples = 100 | ||
sampler = RegularGridSampler(torch.zeros(3), torch.tensor([0.0, -1.0, 0.0])) | ||
samples = sampler(n_samples) | ||
|
||
assert samples.size(0) == n_samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters