Skip to content

Commit

Permalink
Merge pull request #35 from aai-institute/join/feature-dataset
Browse files Browse the repository at this point in the history
Join: feature dataset
  • Loading branch information
JakobEliasWagner authored Feb 14, 2024
2 parents 7acac5d + 9df9edf commit f7f371e
Show file tree
Hide file tree
Showing 23 changed files with 864 additions and 609 deletions.
196 changes: 121 additions & 75 deletions notebooks/basics.ipynb

Large diffs are not rendered by default.

156 changes: 102 additions & 54 deletions notebooks/physicsinformed.ipynb

Large diffs are not rendered by default.

96 changes: 61 additions & 35 deletions notebooks/selfsupervised.ipynb

Large diffs are not rendered by default.

121 changes: 74 additions & 47 deletions notebooks/superresolution.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions src/continuity/benchmarks/sine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Sine benchmark."""

from continuity.benchmarks import Benchmark
from continuity.data import split
from continuity.data.datasets import Sine
from continuity.data import Sine, split
from continuity.operators.losses import Loss, MSELoss
from torch.utils.data import Dataset

Expand Down
25 changes: 18 additions & 7 deletions src/continuity/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
import os
import torch

from .dataset import OperatorDataset, SelfSupervisedOperatorDataset
from .shape import DatasetShapes
from .sine import Sine
from .flame import Flame, FlameDataLoader

__all__ = [
"OperatorDataset",
"SelfSupervisedOperatorDataset",
"DatasetShapes",
"Sine",
"Flame",
"FlameDataLoader",
"device",
"split",
]


def get_device() -> torch.device:
"""Get torch device.
Expand All @@ -34,11 +50,6 @@ def get_device() -> torch.device:
device = get_device()


def tensor(x):
    """Default conversion for tensors.

    Args:
        x: Data to convert: anything `torch.tensor` accepts (numbers, nested
            sequences, arrays) or an existing `torch.Tensor`.

    Returns:
        A new `torch.float32` tensor holding a copy of the data in `x`.
    """
    if isinstance(x, torch.Tensor):
        # `torch.tensor(existing_tensor)` emits a "copy construct" UserWarning;
        # detach + clone is the copy PyTorch recommends and yields the same
        # result (a detached copy), which is then cast to float32.
        return x.detach().clone().to(torch.float32)
    return torch.tensor(x, dtype=torch.float32)


def split(dataset, split=0.5, seed=None):
"""
Split data set into two parts.
Expand Down Expand Up @@ -70,7 +81,7 @@ def dataset_loss(dataset, operator, loss_fn):
loss = 0.0

for x, u, y, v in dataset:
batch_size = x.shape[0]
loss += loss_fn(operator, x, u, y, v) / batch_size
x, u, y, v = x.unsqueeze(0), u.unsqueeze(0), y.unsqueeze(0), v.unsqueeze(0)
loss += loss_fn(operator, x, u, y, v)

return loss
161 changes: 161 additions & 0 deletions src/continuity/data/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""
`continuity.data`
Data sets in Continuity.
Every data set is a list of `(x, u, y, v)` tuples.
"""

import torch
import torch.utils.data as td
from typing import Tuple

from .shape import DatasetShapes, TensorShape


class OperatorDataset(td.Dataset):
    """Dataset of discretized function pairs for operator training.

    An operator maps an input function onto an output function. Each
    observation stores four tensors: the sensor positions at which the input
    function is sampled (x), the input function values at those positions (u),
    the evaluation positions in the output space (y), and the values of the
    operator output at those positions (v). Not every loss function or
    operator needs access to all four.

    Args:
        x: Tensor of shape (#observations, #sensors, x-dim) with sensor positions.
        u: Tensor of shape (#observations, #sensors, u-dim) with evaluations of
            the input functions at sensor positions.
        y: Tensor of shape (#observations, #evaluations, y-dim) with evaluation
            positions.
        v: Tensor of shape (#observations, #evaluations, v-dim) with ground
            truth operator mappings.
        x_transform: Optional callable applied to `x[idx]` on item access.
        u_transform: Optional callable applied to `u[idx]` on item access.
        y_transform: Optional callable applied to `y[idx]` on item access.
        v_transform: Optional callable applied to `v[idx]` on item access.

    Attributes:
        shapes (dataclass): Shape of all tensors.
        transform (dict): Transformations for each tensor, keyed by
            "x", "u", "y", "v"; only the transformations that were given.
    """

    def __init__(
        self,
        x: torch.Tensor,
        u: torch.Tensor,
        y: torch.Tensor,
        v: torch.Tensor,
        x_transform=None,
        u_transform=None,
        y_transform=None,
        v_transform=None,
    ):
        tensors = (x, u, y, v)

        assert all(t.ndim == 3 for t in tensors), "Wrong number of dimensions."
        assert (
            len({t.size(0) for t in tensors}) == 1
        ), "Inconsistent number of observations."
        assert x.size(1) == u.size(1), "Inconsistent number of sensors."
        assert y.size(1) == v.size(1), "Inconsistent number of evaluations."

        super().__init__()

        self.x, self.u, self.y, self.v = tensors

        # Per-observation shapes; used to initialize architectures.
        per_tensor_shapes = {
            name: TensorShape(*t.shape[1:]) for name, t in zip("xuyv", tensors)
        }
        self.shapes = DatasetShapes(
            num_observations=int(x.size(0)), **per_tensor_shapes
        )

        # Keep only the transformations that were actually supplied.
        candidates = {
            "x": x_transform,
            "u": u_transform,
            "y": y_transform,
            "v": v_transform,
        }
        self.transform = {
            name: fn for name, fn in candidates.items() if fn is not None
        }

    def __len__(self) -> int:
        """Return the number of samples.

        Returns:
            Number of samples in the entire set.
        """
        return self.shapes.num_observations

    def __getitem__(
        self, idx
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Retrieve the sample at the given index and apply transformations.

        Args:
            idx: The index of the sample to retrieve.

        Returns:
            The tuple `(x, u, y, v)` for the given index, each entry passed
            through its transformation when one was registered.
        """
        raw = (
            ("x", self.x[idx]),
            ("u", self.u[idx]),
            ("y", self.y[idx]),
            ("v", self.v[idx]),
        )
        return tuple(
            self.transform[name](value) if name in self.transform else value
            for name, value in raw
        )


class SelfSupervisedOperatorDataset(OperatorDataset):
    """
    A `SelfSupervisedOperatorDataset` is a data set that contains data for self-supervised learning.

    Every data point is created by taking one sensor as a label, so each
    observation contributes one sample per sensor.
    Every sample is a tuple `(x, u, y, v)`, where `x` contains the sensor
    positions, `u` the sensor values, and `y = x_i` and `v = u_i` are
    the label's coordinate and its value for all `i`.

    Args:
        x: Sensor positions of shape (num_observations, num_sensors, coordinate_dim)
        u: Sensor values of shape (num_observations, num_sensors, num_channels)

    Attributes:
        num_observations: Number of observations in the input `u`.
        num_sensors: Number of sensors per observation.
        coordinate_dim: Size of the last dimension of `x`.
        num_channels: Size of the last dimension of `u`.
    """

    def __init__(self, x: torch.Tensor, u: torch.Tensor):
        # Fail early with a clear message instead of an indexing error later.
        assert x.ndim == u.ndim == 3, "Wrong number of dimensions."
        assert (
            x.shape[:2] == u.shape[:2]
        ), "Inconsistent number of observations or sensors."

        self.num_observations = u.shape[0]
        self.num_sensors = u.shape[1]
        self.coordinate_dim = x.shape[-1]
        self.num_channels = u.shape[-1]

        # Inputs: every observation is repeated once per sensor, so sample
        # `i * num_sensors + j` carries the full observation `i`.
        xs = x.repeat_interleave(self.num_sensors, dim=0)
        us = u.repeat_interleave(self.num_sensors, dim=0)

        # Labels: sample `i * num_sensors + j` carries sensor `j` of
        # observation `i`, as a single evaluation point. `clone` keeps the
        # labels independent of the caller's tensors (flatten may return a view).
        ys = x.flatten(0, 1).unsqueeze(1).clone()
        vs = u.flatten(0, 1).unsqueeze(1).clone()

        super().__init__(xs, us, ys, vs)
143 changes: 0 additions & 143 deletions src/continuity/data/datasets.py

This file was deleted.

Loading

0 comments on commit f7f371e

Please sign in to comment.