Commit

Support for distributions
manuelgloeckler authored and michaeldeistler committed Feb 5, 2024
1 parent 528bb17 commit fe4b53a
Showing 2 changed files with 88 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ __pycache__/

# Distribution / packaging
.Python
data/
plots/
build/
develop-eggs/
126 changes: 87 additions & 39 deletions labproject/data.py
@@ -23,6 +23,7 @@
## Hetzner Storage Box API functions ----

DATASETS = {}
DISTRIBUTIONS = {}


def upload_file(local_path: str, remote_path: str):
@@ -130,6 +131,47 @@ def get_dataset(name: str) -> torch.Tensor:
    return DATASETS[name]


def register_distribution(name: str) -> callable:
    r"""This decorator wraps a function that should return a distribution object,
    e.g. an instance of ``torch.distributions.Distribution`` or any object that
    implements ``sample`` and ``log_prob``, and registers it under the given name.

    Args:
        name (str): Name under which to register the distribution generator

    Returns:
        callable: Decorator that registers the wrapped distribution generator function

    Example:
        >>> @register_distribution("normal")
        >>> def normal_distribution():
        >>>     return torch.distributions.Normal(0, 1)
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Call the original function to construct the distribution object
            distribution = func(*args, **kwargs)
            return distribution

        DISTRIBUTIONS[name] = wrapper
        return wrapper

    return decorator


def get_distribution(name: str) -> callable:
    r"""Get a registered distribution generator by name.

    Args:
        name (str): Name of the distribution

    Returns:
        callable: Distribution generator function; call it to obtain the distribution object
    """
    assert name in DISTRIBUTIONS, f"Distribution {name} not found, please register it first"
    return DISTRIBUTIONS[name]
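
A minimal usage sketch of the registry API (the ``uniform`` name and generator below are hypothetical, not part of this commit): note that ``get_distribution`` returns the registered generator function, which must be called to obtain the actual distribution object.

    >>> @register_distribution("uniform")
    >>> def uniform_distribution(low=0.0, high=1.0):
    >>>     return torch.distributions.Uniform(low, high)
    >>>
    >>> dist = get_distribution("uniform")(low=-1.0, high=1.0)
    >>> samples = dist.sample((100,))  # tensor of shape (100,)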


def load_cifar10(
    n: int, save_path="data", train=True, batch_size=100, shuffle=False, num_workers=1, device="cpu"
) -> torch.Tensor:
@@ -186,45 +228,51 @@ def random_dataset(n=1000, d=10):
    return torch.randn(n, d)


@register_dataset("toy_2d")
def toy_mog_2D(n=1000, d=2):
    """Generate samples from a 2D mixture of 4 Gaussians that look funky.

    Args:
        n (int): number of samples to generate
        d (int): dimensionality of the samples, always 2. Changing it does nothing.

    Returns:
        tensor: samples of shape (num_samples, 2)
    """
    means = torch.tensor(
        [
            [0.0, 0.5],
            [-3.0, -0.5],
            [0.0, -1.0],
            [-4.0, -3.0],
        ]
    )
    covariances = torch.tensor(
        [
            [[1.0, 0.8], [0.8, 1.0]],
            [[1.0, -0.5], [-0.5, 1.0]],
            [[1.0, 0.0], [0.0, 1.0]],
            [[0.5, 0.0], [0.0, 0.5]],
        ]
    )
    weights = torch.tensor([0.2, 0.3, 0.3, 0.2])

    # Create a list of 2D Gaussian distributions
    gaussians = [
        MultivariateNormal(mean, covariance) for mean, covariance in zip(means, covariances)
    ]

    # Sample from the mixture
    categorical = Categorical(weights)
    sample_indices = categorical.sample([n])
    samples = torch.stack([gaussians[i].sample() for i in sample_indices])
    return samples
@register_distribution("normal")
def normal_distribution():
    return torch.distributions.Normal(0, 1)


@register_distribution("toy_2d")
def toy_2d_distribution():
    class Toy2D:
        def __init__(self):
            self.means = torch.tensor(
                [
                    [0.0, 0.5],
                    [-3.0, -0.5],
                    [0.0, -1.0],
                    [-4.0, -3.0],
                ]
            )
            self.covariances = torch.tensor(
                [
                    [[1.0, 0.8], [0.8, 1.0]],
                    [[1.0, -0.5], [-0.5, 1.0]],
                    [[1.0, 0.0], [0.0, 1.0]],
                    [[0.5, 0.0], [0.0, 0.5]],
                ]
            )
            self.weights = torch.tensor([0.2, 0.3, 0.3, 0.2])

            # Create a list of 2D Gaussian distributions
            self.gaussians = [
                MultivariateNormal(mean, covariance)
                for mean, covariance in zip(self.means, self.covariances)
            ]

        def sample(self, sample_shape):
            # Sample mixture components, then draw one sample per chosen component
            categorical = Categorical(self.weights)
            sample_indices = categorical.sample(sample_shape)
            return torch.stack([self.gaussians[i].sample() for i in sample_indices])

        def log_prob(self, input):
            # Mixture density: log p(x) = log sum_k w_k N(x; mu_k, Sigma_k)
            probs = torch.stack([g.log_prob(input).exp() for g in self.gaussians])
            probs = probs.T * self.weights
            return torch.sum(probs, dim=1).log()

    return Toy2D()
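
As a quick check of the interface, a hedged example of using the registered toy mixture (variable names are illustrative only): sampling returns a tensor of shape ``(n, 2)``, and ``log_prob`` evaluates the mixture density row-wise.

    >>> toy = get_distribution("toy_2d")()
    >>> x = toy.sample((500,))  # shape (500, 2)
    >>> lp = toy.log_prob(x)    # shape (500,)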


@register_dataset("cifar10_train")
