Skip to content

Commit

Permalink
Feature: Quantile Scaler (#112)
Browse files Browse the repository at this point in the history
* Add quantile scaler.

Co-authored-by: Samuel Burbulla <[email protected]>

---------

Co-authored-by: Samuel Burbulla <[email protected]>
  • Loading branch information
JakobEliasWagner and samuelburbulla authored Apr 10, 2024
1 parent a82dfd4 commit fcd6656
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- Add `FourierLayer` and `FourierNeuralOperator` with example.
- Add `benchmarks` infrastructure.
- An `Operator` now takes a `device` argument.
- Add `QuantileScaler` class.

## 0.0.0 (2024-02-22)

Expand Down
7 changes: 2 additions & 5 deletions src/continuity/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from .transform import Transform
from .compose import Compose
from .scaling import Normalize
from .quantile_scaler import QuantileScaler

# Public API of the transforms package (single definition; the earlier
# list without "QuantileScaler" was immediately overwritten and thus dead).
__all__ = [
    "Transform",
    "Compose",
    "Normalize",
    "QuantileScaler",
]
178 changes: 178 additions & 0 deletions src/continuity/transforms/quantile_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""
`continuity.transforms.quantile_scaler`
Quantile Scaler class.
"""

import torch
from continuity.transforms import Transform
from typing import Union, Tuple


class QuantileScaler(Transform):
    """Quantile Scaler Class.

    A transform for scaling input data to a specified target distribution using quantiles. This is
    particularly useful for normalizing data in a way that is more robust to outliers than standard
    z-score normalization.

    The transformation maps the quantiles of the input data to the quantiles of the target distribution,
    effectively performing a non-linear (piecewise linear) scaling that preserves the relative
    distribution of the data. Quantiles are computed independently for every entry of the last
    dimension of `src`.

    Args:
        src: tensor from which the source distribution is drawn. Its last dimension is treated as the
            feature dimension; all other dimensions are flattened into samples.
        n_quantile_intervals: Controls the resolution of the quantile grid. `n_quantile_intervals + 2`
            quantile points are sampled, so the data is categorized into `n_quantile_intervals + 1` bins.
        target_mean: Mean of the target Gaussian distribution. Can be a number (all dimensions use the same
            mean), or tensor (allows for different means along different dimensions).
        target_std: Std of the target Gaussian distribution. Can be a number (all dimensions use the same
            std), or tensor (allows for different stds along different dimensions).
        eps: Small value to bound the target distribution to a finite interval.
    """

    def __init__(
        self,
        src: torch.Tensor,
        n_quantile_intervals: int = 1000,
        target_mean: Union[float, torch.Tensor] = 0.0,
        target_std: Union[float, torch.Tensor] = 1.0,
        eps: float = 1e-3,
    ):
        assert 0 <= eps <= 0.5, "eps must lie in [0, 0.5]"
        assert n_quantile_intervals > 0, "n_quantile_intervals must be positive"

        # Plain Python numbers (ints as well as floats) become one-element tensors.
        if isinstance(target_mean, (int, float)):
            target_mean = target_mean * torch.ones(1)
        if isinstance(target_std, (int, float)):
            target_std = target_std * torch.ones(1)
        self.target_mean = target_mean
        self.target_std = target_std

        self.n_quantile_intervals = n_quantile_intervals
        # n_quantile_intervals + 2 quantile points bound n_quantile_intervals + 1 intervals,
        # so valid interval indices run from 0 to n_quantile_intervals (inclusive).
        self.n_q_points = n_quantile_intervals + 2

        self.n_dim = src.size(-1)

        # source "distribution": per-feature quantile points, shape (n_q_points, n_dim)
        self.quantile_fractions = torch.linspace(0, 1, self.n_q_points)
        self.quantile_points = torch.quantile(
            src.reshape(-1, self.n_dim),  # reshape also handles non-contiguous input
            self.quantile_fractions,
            dim=0,
            interpolation="linear",
        )
        self.deltas = self.quantile_points[1:] - self.quantile_points[:-1]

        # target distribution
        self.target_distribution = torch.distributions.normal.Normal(
            target_mean, target_std
        )
        self.target_quantile_fractions = torch.linspace(
            0 + eps, 1 - eps, self.n_q_points
        )  # bounded domain
        # Evaluate the inverse CDF on a column of fractions so that it broadcasts against
        # both shared (one-element) and per-dimension (n_dim-element) mean/std tensors.
        target_quantile_points = self.target_distribution.icdf(
            self.target_quantile_fractions.unsqueeze(-1)
        )
        self.target_quantile_points = target_quantile_points.expand(
            self.n_q_points, self.n_dim
        ).clone()
        self.target_deltas = (
            self.target_quantile_points[1:] - self.target_quantile_points[:-1]
        )

        super().__init__()

    def _get_scaling_indices(
        self, src: torch.Tensor, quantile_tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Locate, for every entry of src, the quantile interval it belongs to.

        Args:
            src: Input tensor whose last dimension has size `n_dim`.
            quantile_tensor: Tensor of shape (n_points, n_dim) containing the (non-decreasing) quantile
                points of a distribution for every feature.

        Returns:
            Tuple `(interval_indices, feature_indices)` of flat index tensors with `src.nelement()` entries
            each, ordered like `src.reshape(-1)`. `interval_indices` holds the index of the last quantile
            point that is strictly smaller than the value (the left boundary of its interval), clamped to
            the first/last interval for values outside of the quantile range. The tuple can be used
            directly to index a `(n_points, n_dim)` or `(n_points - 1, n_dim)` tensor.
        """
        assert src.size(-1) == self.n_dim, "last dimension of src must match n_dim"

        # Number of quantile points strictly below each value. Counting (instead of taking the
        # first arg-max) selects the *last* of several identical quantile points, which avoids
        # landing in a zero-width interval when the data contains repeated values.
        n_below = (quantile_tensor < src.unsqueeze(-2)).sum(dim=-2)

        # Left boundary index; values at or below the first point fall into interval 0,
        # values above the last point fall into the last interval.
        last_interval = quantile_tensor.size(0) - 2
        indices = (n_below - 1).clamp(min=0, max=last_interval)

        # prepare for indexing
        return (
            indices.reshape(-1),
            torch.arange(self.n_dim).repeat(src.nelement() // self.n_dim),
        )

    @staticmethod
    def _rescale(
        tensor: torch.Tensor,
        indices: Tuple[torch.Tensor, torch.Tensor],
        from_points: torch.Tensor,
        from_deltas: torch.Tensor,
        to_points: torch.Tensor,
        to_deltas: torch.Tensor,
    ) -> torch.Tensor:
        """Linearly map values from their interval in one quantile grid to the same interval in another."""
        lower = from_points[indices].reshape(tensor.shape)
        width = from_deltas[indices].reshape(tensor.shape)

        # Relative position inside the interval. Zero-width intervals (repeated quantile
        # points) are mapped to the lower edge instead of dividing by zero.
        degenerate = width == 0
        relative = (tensor - lower) / width.masked_fill(degenerate, 1.0)
        relative = relative.masked_fill(degenerate, 0.0)

        to_lower = to_points[indices].reshape(tensor.shape)
        to_width = to_deltas[indices].reshape(tensor.shape)
        return relative * to_width + to_lower

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Transforms the input tensor to match the target distribution using quantile scaling.

        Args:
            tensor: The input tensor to transform. Its last dimension must have size `n_dim`.

        Returns:
            The transformed tensor, scaled to the target distribution.
        """
        indices = self._get_scaling_indices(tensor, self.quantile_points)
        return self._rescale(
            tensor,
            indices,
            self.quantile_points,
            self.deltas,
            self.target_quantile_points,
            self.target_deltas,
        )

    def undo(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reverses the transformation applied by the forward method, mapping the tensor back to its original
        distribution.

        Args:
            tensor: The tensor to reverse the transformation on. Its last dimension must have size `n_dim`.

        Returns:
            The tensor with the quantile scaling transformation reversed according to the src distribution.
        """
        indices = self._get_scaling_indices(tensor, self.target_quantile_points)
        return self._rescale(
            tensor,
            indices,
            self.target_quantile_points,
            self.target_deltas,
            self.quantile_points,
            self.deltas,
        )
106 changes: 106 additions & 0 deletions tests/transforms/test_quantile_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import pytest
import torch

from continuity.transforms import QuantileScaler


@pytest.fixture(scope="module")
def random_multiscale_tensor():
    """Random tensor of shape (97, 37, 7) with values spread over several orders of magnitude."""
    exponents = torch.rand(97, 37, 7)
    exponents *= 5  # exponents in [0, 5)
    return 10**exponents


@pytest.fixture(scope="module")
def quantile_scaler(random_multiscale_tensor):
    """QuantileScaler built from the multiscale tensor, targeting a standard normal with eps=1e-2."""
    scaler_options = {"target_mean": 0.0, "target_std": 1.0, "eps": 1e-2}
    return QuantileScaler(src=random_multiscale_tensor, **scaler_options)


class TestQuantileScaler:
    """Tests for the forward and undo transformations of QuantileScaler."""

    def test_can_initialize(self, quantile_scaler):
        assert isinstance(quantile_scaler, QuantileScaler)

    def test_forward_shape(self, quantile_scaler, random_multiscale_tensor):
        out = quantile_scaler(random_multiscale_tensor)
        assert out.shape == random_multiscale_tensor.shape

    def test_forward(self, quantile_scaler, random_multiscale_tensor):
        """In-distribution data is mapped into the eps-bounded target interval."""
        out = quantile_scaler(random_multiscale_tensor)

        dist = torch.distributions.normal.Normal(torch.zeros(1), torch.ones(1))
        # lower and upper bound of the target distribution for eps=1e-2
        limit = dist.icdf(torch.linspace(1e-2, 1 - 1e-2, 2))

        assert torch.all(torch.greater_equal(out, limit[0]))
        assert torch.all(torch.less_equal(out, limit[1]))

    def test_forward_ood(self, quantile_scaler):
        """Forward runs on out of distribution input (smoke test, values are not checked)."""
        exp = (torch.rand(97, 37, 7) - 0.5) * 2 * 10
        t = (torch.rand(97, 37, 7) - 0.5) * 2
        t = t**exp

        out = quantile_scaler(t)
        assert out.shape == t.shape

    def test_forward_zero_dim(self, quantile_scaler):
        """Forward accepts a tensor that only has the feature dimension."""
        t = torch.rand(7)

        out = quantile_scaler(t)
        assert out.shape == t.shape

    def test_forward_many_dim(self, quantile_scaler):
        """Forward accepts a tensor with many leading dimensions."""
        t = torch.rand(1, 2, 3, 4, 7)

        out = quantile_scaler(t)
        assert out.shape == t.shape

    def test_undo_shape(self, quantile_scaler, random_multiscale_tensor):
        out = quantile_scaler.undo(random_multiscale_tensor)
        assert out.shape == random_multiscale_tensor.shape

    def test_undo(self, quantile_scaler, random_multiscale_tensor):
        """Undo inverts forward on in-distribution data."""
        out = quantile_scaler(random_multiscale_tensor)
        undone = quantile_scaler.undo(out)
        assert torch.allclose(undone, random_multiscale_tensor, atol=1e-5)

    def test_undo_ood(self, quantile_scaler, random_multiscale_tensor):
        """Undo runs on out of distribution input (smoke test, values are not checked)."""
        dist = torch.distributions.normal.Normal(torch.zeros(1), torch.ones(1))
        limit = dist.icdf(torch.linspace(1e-2, 1 - 1e-2, 2))
        limit *= 10  # stretch far beyond the eps-bounded target interval
        test_tensor = torch.linspace(*limit, 700).reshape(1, 100, 7)

        out = quantile_scaler.undo(test_tensor)
        assert out.shape == test_tensor.shape

    def test_undo_zero_dim(self, quantile_scaler):
        """Undo accepts a tensor that only has the feature dimension."""
        t = torch.rand(7)

        out = quantile_scaler.undo(t)
        assert out.shape == t.shape

    def test_undo_many_dim(self, quantile_scaler):
        """Undo accepts a tensor with many leading dimensions."""
        t = torch.rand(1, 2, 3, 4, 7)

        out = quantile_scaler.undo(t)
        assert out.shape == t.shape

0 comments on commit fcd6656

Please sign in to comment.