Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
Baschdl committed Mar 19, 2024
1 parent df6890f commit 832bdc7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/applications/roitman_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def filter_roitman_data(
rt = df[(df["coherence"] == coherence)]["rt"].values / 1000

decision = df[(df["animal"] == animal) & (df["coherence"] == coherence)]["decision"].values
decision = df[ (df["coherence"] == coherence)]["decision"].values
decision = df[(df["coherence"] == coherence)]["decision"].values

if n_trial == "all":
mask = np.ones(len(rt), dtype=bool)
Expand Down
4 changes: 2 additions & 2 deletions labproject/external/inception_v3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'''
"""
Code sourced from https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
'''
"""

import torch
import torch.nn as nn
Expand Down
16 changes: 10 additions & 6 deletions labproject/metrics/gaussian_squared_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@


@register_metric("wasserstein_gauss_squared")
def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor, real_mu=None, real_cov=None) -> Tensor:
def gaussian_squared_w2_distance(
real_samples: Tensor, fake_samples: Tensor, real_mu=None, real_cov=None
) -> Tensor:
r"""
Compute the squared Wasserstein distance between Gaussian approximations of real and fake samples.
Dimensionality of the samples must be the same and >=2 (for covariance calculation).
Expand Down Expand Up @@ -48,16 +50,18 @@ def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor, rea
# check input (n,d only)
assert len(real_samples.size()) == 2, "Real samples must be 2-dimensional, (n,d)"
assert len(fake_samples.size()) == 2, "Fake samples must be 2-dimensional, (n,d)"

if real_samples.shape[-1] == 1:
mu_real = real_samples.mean(dim=0)
var_real = real_samples.var(dim=0)

mu_fake = fake_samples.mean(dim=0)
var_fake = fake_samples.var(dim=0)

w2_squared_dist = (mu_real - mu_fake)**2 + (var_real + var_fake - 2 * (var_real * var_fake).sqrt())


w2_squared_dist = (mu_real - mu_fake) ** 2 + (
var_real + var_fake - 2 * (var_real * var_fake).sqrt()
)

return w2_squared_dist
else:
# calculate mean and covariance of real and fake samples
Expand Down

0 comments on commit 832bdc7

Please sign in to comment.