diff --git a/docs/notebooks/applications/roitman_utils.py b/docs/notebooks/applications/roitman_utils.py index bd9b513..863943c 100644 --- a/docs/notebooks/applications/roitman_utils.py +++ b/docs/notebooks/applications/roitman_utils.py @@ -33,7 +33,7 @@ def filter_roitman_data( rt = df[(df["coherence"] == coherence)]["rt"].values / 1000 decision = df[(df["animal"] == animal) & (df["coherence"] == coherence)]["decision"].values - decision = df[ (df["coherence"] == coherence)]["decision"].values + decision = df[(df["coherence"] == coherence)]["decision"].values if n_trial == "all": mask = np.ones(len(rt), dtype=bool) diff --git a/labproject/external/inception_v3.py b/labproject/external/inception_v3.py index cb9d4fe..29553ec 100644 --- a/labproject/external/inception_v3.py +++ b/labproject/external/inception_v3.py @@ -1,6 +1,6 @@ -''' +""" Code sourced from https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py -''' +""" import torch import torch.nn as nn diff --git a/labproject/metrics/gaussian_squared_wasserstein.py b/labproject/metrics/gaussian_squared_wasserstein.py index fe9d8b4..8bcbf07 100644 --- a/labproject/metrics/gaussian_squared_wasserstein.py +++ b/labproject/metrics/gaussian_squared_wasserstein.py @@ -5,7 +5,9 @@ @register_metric("wasserstein_gauss_squared") -def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor, real_mu=None, real_cov=None) -> Tensor: +def gaussian_squared_w2_distance( + real_samples: Tensor, fake_samples: Tensor, real_mu=None, real_cov=None +) -> Tensor: r""" Compute the squared Wasserstein distance between Gaussian approximations of real and fake samples. Dimensionality of the samples must be the same and >=2 (for covariance calculation). @@ -48,16 +50,18 @@ def gaussian_squared_w2_distance(real_samples: Tensor, fake_samples: Tensor, rea # check input (n,d only) assert len(real_samples.size()) == 2, "Real samples must be 2-dimensional, (n,d)" assert len(fake_samples.size()) == 2, "Fake samples must be 2-dimensional, (n,d)" - + if real_samples.shape[-1] == 1: mu_real = real_samples.mean(dim=0) var_real = real_samples.var(dim=0) - + mu_fake = fake_samples.mean(dim=0) var_fake = fake_samples.var(dim=0) - - w2_squared_dist = (mu_real - mu_fake)**2 + (var_real + var_fake - 2 * (var_real * var_fake).sqrt()) - + + w2_squared_dist = (mu_real - mu_fake) ** 2 + ( + var_real + var_fake - 2 * (var_real * var_fake).sqrt() + ) + return w2_squared_dist else: # calculate mean and covariance of real and fake samples