Feature: Masked Dataset #151

Open · wants to merge 2 commits into main
src/continuiti/data/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -5,11 +5,12 @@
 Every data set is a list of `(x, u, y, v)` tuples.
 """
 
-from .dataset import OperatorDataset
+from .dataset import OperatorDataset, MaskedOperatorDataset
 from .utility import split, dataset_loss
 
 __all__ = [
     "OperatorDataset",
+    "MaskedOperatorDataset",
     "split",
     "dataset_loss",
 ]
src/continuiti/data/dataset.py (288 changes: 234 additions & 54 deletions)

@@ -7,37 +7,47 @@

 import torch
 import torch.utils.data as td
-from typing import Optional, Tuple
-from abc import ABC, abstractmethod
+from torch.nn.utils.rnn import pad_sequence
+from typing import Optional, Tuple, List, Union
+from abc import ABC
 from continuiti.transforms import Transform
 from continuiti.operators.shape import OperatorShapes, TensorShape
 
 
 class OperatorDatasetBase(td.Dataset, ABC):
     """Abstract base class of a dataset for operator training."""
 
-    shapes: OperatorShapes
+    def __init__(self, shapes: OperatorShapes, n_observations: int) -> None:
Collaborator (suggested change):
-    def __init__(self, shapes: OperatorShapes, n_observations: int) -> None:
+    def __init__(self, shapes: OperatorShapes, n_observations: int):
+        super().__init__()
+        self.shapes = shapes
+        self.n_observations = n_observations
 
-    @abstractmethod
-    def __len__(self) -> int:
-        """Return the number of samples.
+    def _apply_transformations(
+        self, src: List[Tuple[torch.Tensor, Optional[Transform]]]
+    ) -> List[torch.Tensor]:
+        """Applies class transformations to four tensors.
 
+        Args:
+            src:
Collaborator (suggested change):
-            src:
+            src: List of tuples containing a tensor and a transformation to apply to it.


         Returns:
-            number of samples in the entire set.
+            Input src with class transformations applied.
         """
+        out = []
+        for src_tensor, transformation in src:
+            if transformation is None:
+                out.append(src_tensor)
+                continue
+            out.append(transformation(src_tensor))
Collaborator (suggested change, on lines +40 to +41):
-                continue
-            out.append(transformation(src_tensor))
+            else:
+                out.append(transformation(src_tensor))

+        return out
 
-    @abstractmethod
-    def __getitem__(
-        self, idx
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Retrieves the input-output pair at the specified index and applies transformations.
-
-        Parameters:
-            - idx: The index of the sample to retrieve.
+    def __len__(self) -> int:
+        """Return the number of observations in the dataset.
 
         Returns:
-            A tuple containing the three input tensors and the output tensor for the given index.
+            Number of observations in the entire dataset.
         """
+        return self.n_observations


 class OperatorDataset(OperatorDatasetBase):

@@ -57,7 +67,6 @@ class OperatorDataset(OperatorDatasetBase):

     Attributes:
         shapes: Shape of all tensors.
-        transform: Transformations for each tensor.
     """

     def __init__(

@@ -85,39 +94,25 @@ def __init__(
         assert x_size == u_size, "Inconsistent number of sensors."
         assert y_size == v_size, "Inconsistent number of evaluations."
 
-        super().__init__()
-
         self.x = x
         self.u = u
         self.y = y
         self.v = v
 
         # used to initialize architectures
-        self.shapes = OperatorShapes(
+        shapes = OperatorShapes(
             x=TensorShape(dim=x_dim, size=x_size),
             u=TensorShape(dim=u_dim, size=u_size),
             y=TensorShape(dim=y_dim, size=y_size),
             v=TensorShape(dim=v_dim, size=v_size),
         )
 
-        self.transform = {
-            dim: tf
-            for dim, tf in [
-                ("x", x_transform),
-                ("u", u_transform),
-                ("y", y_transform),
-                ("v", v_transform),
-            ]
-            if tf is not None
-        }
+        super().__init__(shapes, len(x))
 
-    def __len__(self) -> int:
-        """Return the number of samples.
-
-        Returns:
-            Number of samples in the entire set.
-        """
-        return self.x.size(0)
+        self.x_transform = x_transform
+        self.u_transform = u_transform
+        self.y_transform = y_transform
+        self.v_transform = v_transform

     def __getitem__(
         self,

@@ -131,29 +126,214 @@ def __getitem__(
         Returns:
             A tuple containing the three input tensors and the output tensor for the given index.
         """
-        return self._apply_transformations(
-            self.x[idx], self.u[idx], self.y[idx], self.v[idx]
+        tensors = self._apply_transformations(
+            [
+                (self.x[idx], self.x_transform),
+                (self.u[idx], self.u_transform),
+                (self.y[idx], self.y_transform),
+                (self.v[idx], self.v_transform),
+            ]
         )
 
-    def _apply_transformations(
-        self, x: torch.Tensor, u: torch.Tensor, y: torch.Tensor, v: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Applies class transformations to four tensors.
+        return tensors[0], tensors[1], tensors[2], tensors[3]
Collaborator (suggested change):
-        return tensors[0], tensors[1], tensors[2], tensors[3]
+        return tuple(tensors)



+class MaskedOperatorDataset(OperatorDatasetBase):
+    """A dataset for operator training containing masks in addition to tensors describing the mapping.
+
+    Data, especially described on unstructured grids, can vary in the number of evaluations or sensors. Even
+    measurements of phenomena do not always contain the same number of sensors and or evaluations. This dataset is able
Collaborator (suggested change):
-    measurements of phenomena do not always contain the same number of sensors and or evaluations. This dataset is able
+    measurements of phenomena do not always contain the same number of sensors and/or evaluations. This dataset is able

+    to handle datasets that have differing number of sensors or evaluations. For this masks, both for the input and
Collaborator (suggested change):
-    to handle datasets that have differing number of sensors or evaluations. For this masks, both for the input and
+    to handle datasets that have differing number of sensors or evaluations. For this, masks, both for the input and

+    output space, describe which values are relevant to the dataset and which are irrelevant padding values, that
Collaborator (suggested change):
-    output space, describe which values are relevant to the dataset and which are irrelevant padding values, that
+    output space, describe which values are relevant to the dataset and which are irrelevant padding values that

+    should be ignored during training and evaluation. Padding tensors is important to efficiently handle data in
+    batches.
+
+    Args:
+        x: Tensor of shape (num_observations, x_dim, num_sensors) with sensor positions or list containing
+            num_observations tensors of shape (x_dim, ...).
+        u: Tensor of shape (num_observations, u_dim, num_sensors) with evaluations of the input functions at sensor
+            positions or list containing num_observations tensors of shape (u_dim, ...).
+        y: Tensor of shape (num_observations, y_dim, num_evaluations) with evaluation positions or list containing
+            num_observations tensors of shape (y_dim, ...).
+        v: Tensor of shape (num_observations, v_dim, num_evaluations) with ground truth operator mappings or list
+            containing num_observations tensors of shape (v_dim, ...).
+        ipt_mask:Boolean tensor of shape (num_observations, num_sensors) with True indicating that a value pair of the
Collaborator (suggested change):
-        ipt_mask:Boolean tensor of shape (num_observations, num_sensors) with True indicating that a value pair of the
+        input_mask: Boolean tensor of shape (num_observations, num_sensors) with True indicating that a value pair of the
+            input space should be taken into consideration during training.
Collaborator (suggested change):
-            input space should be taken into consideration during training.
+            input space should be taken into consideration during training. Automatically constructed if `x` is a list.

+        opt_mask: Boolean tensor of shape (num_observations, num_evaluations) with True indicating that a value pair of
Collaborator (suggested change):
-        opt_mask: Boolean tensor of shape (num_observations, num_evaluations) with True indicating that a value pair of
+        output_mask: Boolean tensor of shape (num_observations, num_evaluations) with True indicating that a value pair of

+            the output space should be taken into consideration during training.
Collaborator (suggested change):
-            the output space should be taken into consideration during training.
+            the output space should be taken into consideration during training. Automatically constructed if `y` is a list.

+        x_transform: Transformation applied to x.
+        u_transform: Transformation applied to u.
+        y_transform: Transformation applied to y.
+        v_transform: Transformation applied to v.
+
+    """

+    def __init__(
+        self,
+        x: Union[torch.Tensor, List[torch.Tensor]],
+        u: Union[torch.Tensor, List[torch.Tensor]],
+        y: Union[torch.Tensor, List[torch.Tensor]],
+        v: Union[torch.Tensor, List[torch.Tensor]],
+        ipt_mask: Optional[torch.Tensor] = None,
+        opt_mask: Optional[torch.Tensor] = None,
Collaborator (on lines +177 to +178): I'd prefer more verbose names like input_mask and output_mask.

Collaborator: Similar for input_is_list and output_is_list.

+        x_transform: Optional[Transform] = None,
+        u_transform: Optional[Transform] = None,
+        y_transform: Optional[Transform] = None,
+        v_transform: Optional[Transform] = None,
+    ) -> None:
+        assert (
+            len(x) == len(u) == len(y) == len(v)
+        ), f"All tensors need to have the same number of observations, but found {len(x)}, {len(u)}, {len(y)}, {len(v)}."
+        ipt_is_list = isinstance(x, list)
+        assert self._is_valid_space(x, u, ipt_mask)
+        opt_is_list = isinstance(y, list)
+        assert self._is_valid_space(y, v, opt_mask)
+
+        if ipt_is_list:
+            x, u, ipt_mask = self._pad_list_space(x, u)
+
+        if opt_is_list:
+            y, v, opt_mask = self._pad_list_space(y, v)
+
+        self.x = x
+        self.u = u
+        self.y = y
+        self.v = v
+
+        self.ipt_mask = ipt_mask
+        self.opt_mask = opt_mask
+
+        self.x_transform = x_transform
+        self.u_transform = u_transform
+        self.y_transform = y_transform
+        self.v_transform = v_transform
+
+        super().__init__(
+            shapes=OperatorShapes(
+                x=TensorShape(
+                    dim=x[0].size(1), size=torch.Size([])
+                ),  # size agnostic dataset
Collaborator (suggested change, on lines +212 to +215):
-            shapes=OperatorShapes(
-                x=TensorShape(
-                    dim=x[0].size(1), size=torch.Size([])
-                ),  # size agnostic dataset
+            # size agnostic dataset
+            shapes=OperatorShapes(
+                x=TensorShape(dim=x[0].size(1), size=torch.Size([])),
+                u=TensorShape(dim=u[0].size(1), size=torch.Size([])),
+                y=TensorShape(dim=y[0].size(1), size=torch.Size([])),
+                v=TensorShape(dim=v[0].size(1), size=torch.Size([])),
+            ),
+            n_observations=len(x),
+        )

+    def _is_valid_space(
+        self,
+        member: Union[torch.Tensor, List[torch.Tensor]],
+        values: Union[torch.Tensor, List[torch.Tensor]],
+        mask: Optional[torch.Tensor],
+    ) -> bool:
+        """Asseses whether a space is in alignment with its respective requirements.
Collaborator (suggested change):
-        """Asseses whether a space is in alignment with its respective requirements.
+        """Assesses whether a space is in alignment with its respective requirements.


+        Depending on whether a space is described by a list of tensors or a tensor certain argument need to align.
Collaborator (suggested change):
-        Depending on whether a space is described by a list of tensors or a tensor certain argument need to align.
+        Depending on whether a space is described by a list of tensors or a tensor certain arguments need to align.

+        All observations need to have the same dimensions in their respective samples. The domain and the function
+        values need to be described by the same number of observations.
 
         Args:
-            x: Tensor of shape (num_observations, x_dim, num_sensors...) with sensor positions.
-            u: Tensor of shape (num_observations, u_dim, num_sensors...) with evaluations of the input functions at sensor positions.
-            y: Tensor of shape (num_observations, y_dim, num_evaluations...) with evaluation positions.
-            v: Tensor of shape (num_observations, v_dim, num_evaluations...) with ground truth operator mappings.
+            member: Tensor of locations.
+            values: Function values evaluated in the domain locations.
+            mask: Boolean mask where a True value indicates that a specific sample should be taken into consideration.
 
         Returns:
-            Input samples with class transformations applied.
+            A boolean value True when the space description is valid.
         """
-        sample = {"x": x, "u": u, "y": y, "v": v}
+        assert type(member) is type(
+            values
+        ), f"All types of tensors in one space need to match. But found {type(member)} and {type(values)}."
Collaborator (suggested change, on lines +243 to +245):
-        assert type(member) is type(
-            values
-        ), f"All types of tensors in one space need to match. But found {type(member)} and {type(values)}."
+        assert type(member) is type(values), (
+            f"All types of tensors in one space need to match, "
+            f"but found {type(member)} and {type(values)}.")

+        ndim: int
Collaborator (suggested change, deleting the annotation):
-        ndim: int
+        if mask is not None:
+            assert isinstance(
+                member, torch.Tensor
+            ), f"When providing a mask the member and values need to be tensors. But found {type(member)}"
Collaborator (suggested change, on lines +249 to +251):
-            assert isinstance(
-                member, torch.Tensor
-            ), f"When providing a mask the member and values need to be tensors. But found {type(member)}"
+            assert isinstance(member, torch.Tensor), (
+                f"When providing a mask the member and values need to be tensors, but found {type(member)}")

+            ndim = member.dim() - 1  # remove batch dimension
+        else:
+            assert all(
+                [di.size(0) == member[0].size(0) for di in member[1:]]
+            ), "Dimensions of all samples of the member need to match."
+            assert all(
+                [vi.size(0) == values[0].size(0) for vi in values[1:]]
+            ), "Dimensions of all function values need to match."
+            ndim = member[0].dim()
+
+        assert (
+            ndim == 2
+        ), f"{self.__class__.__name__} currently only supports exactly one dim and one size dimension."
Collaborator (on lines +262 to +264): Why is that? So I cannot have a shape of (batch_size, 3, 128, 128)?

+        return True
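To make the question above concrete, a small sketch of which per-observation shapes the ndim == 2 check admits (shapes are illustrative only):

import torch

# Samples with one dim axis and one size axis pass the check:
ok = [torch.rand(3, 128) for _ in range(4)]  # each sample has ndim == 2

# Image-like samples with two size axes are currently rejected:
rejected = [torch.rand(3, 128, 128) for _ in range(4)]  # ndim == 3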

+    def _pad_list_space(
+        self, member: List[torch.Tensor], values: List[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transforms a space described by lists of tensors (image) to a space described by padded tensors and a mask.
Collaborator: (image)?

+
+        Args:
+            member: List of tensors describing the locations of the samples.
+            values: List of tensors describing the function values of the samples.
+
+        Returns:
+            padded member tensor, padded values tensor, and matching mask.
+        """
+        assert not any(
+            [torch.any(torch.isinf(mi)) for mi in member]
+        ), "Expects domain to be truncated in finite space."
Collaborator (on lines +280 to +282): Why is this assertion necessary? Someone might come up with a good reason for using infs in the data, do we have to prevent that?

Collaborator: Ah, I see we're using inf for padding. However, does it hurt to have more infs (non-masked) in the dataset?

+        member_padded = pad_sequence(
+            [mi.transpose(0, 1) for mi in member],
+            batch_first=True,
+            padding_value=torch.inf,
+        ).transpose(1, 2)
+        values_padded = pad_sequence(
+            [vi.transpose(0, 1) for vi in values], batch_first=True, padding_value=0
Collaborator (suggested change):
-            [vi.transpose(0, 1) for vi in values], batch_first=True, padding_value=0
+            [vi.transpose(0, 1) for vi in values],
+            batch_first=True,
+            padding_value=0,
Collaborator: Why do we pad once with inf and once with 0? Seems arbitrary.
+        ).transpose(1, 2)
+
+        mask = member_padded != torch.inf
+        member_padded[
+            ~mask
+        ] = 0  # mask often applied by adding a tensor with -inf values in masked locations (e.g. in scaled dot product).
Collaborator (on lines +293 to +296): What is different here, if we just used 0 for padding in l. 287?

+        return member_padded, values_padded, mask
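For context, the padding scheme above reduces to the following standalone sketch (toy tensors, not from the diff). pad_sequence pads along the first axis, hence the transposes around it:

import torch
from torch.nn.utils.rnn import pad_sequence

member = [torch.rand(2, 3), torch.rand(2, 5)]  # ragged: (dim, num_points)

# pad_sequence pads dim 0, so move num_points to the front and back again.
padded = pad_sequence(
    [m.transpose(0, 1) for m in member],
    batch_first=True,
    padding_value=torch.inf,
).transpose(1, 2)  # -> shape (2, 2, 5)

mask = padded != torch.inf  # True at real (non-padding) entries
padded[~mask] = 0  # replace inf so downstream ops stay finite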

+    def __getitem__(
+        self,
+        idx: int,
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+    ]:
+        """Retrieves the input-output pair at the specified index and applies transformations.
+
+        Parameters:
+            idx: The index of the sample to retrieve.
+
+        Returns:
+            A tuple containing the three input tensors, the output tensor, and masks for both the input and output for
+            the given index.
+        """
+        tensors = self._apply_transformations(
+            [
+                (self.x[idx], self.x_transform),
+                (self.u[idx], self.u_transform),
+                (self.y[idx], self.y_transform),
+                (self.v[idx], self.v_transform),
+            ]
+        )

+        if self.ipt_mask is not None:
+            ipt_mask = self.ipt_mask[idx]
+        else:
+            ipt_mask = None
 
-        # transform
-        for dim, val in sample.items():
-            if dim in self.transform:
-                sample[dim] = self.transform[dim](val)
+        if self.opt_mask is not None:
+            opt_mask = self.opt_mask[idx]
+        else:
+            opt_mask = None
 
-        return sample["x"], sample["u"], sample["y"], sample["v"]
+        return tensors[0], tensors[1], tensors[2], tensors[3], ipt_mask, opt_mask
Collaborator (suggested change):
-        return tensors[0], tensors[1], tensors[2], tensors[3], ipt_mask, opt_mask
+        return *tuple(tensors), ipt_mask, opt_mask
