diff --git a/labproject/metrics/MMD_torch.py b/labproject/metrics/MMD_torch.py index 0f60c2b..0d79eb7 100644 --- a/labproject/metrics/MMD_torch.py +++ b/labproject/metrics/MMD_torch.py @@ -19,6 +19,12 @@ def linear_kernel(x, y): return x @ y.t() +def energy_kernel(x, y): + x_norm = torch.linalg.norm(x, dim=-1) + y_norm = torch.linalg.norm(y, dim=-1) + return x_norm[:, None] + y_norm[None, :] - torch.cdist(x, y) + + def median_heuristic(x, y): return torch.median(torch.cdist(x, y)) @@ -71,3 +77,12 @@ def compute_linear_mmd_naive(x, y): def compute_linear_mmd(x, y): delta = torch.mean(x, 0) - torch.mean(y, 0) return torch.norm(delta, 2) ** 2 + + +@register_metric("mmd_energy") +def compute_energy_mmd(x, y): + x_kernel = energy_kernel(x, x) + y_kernel = energy_kernel(y, y) + xy_kernel = energy_kernel(x, y) + mmd = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel) + return mmd \ No newline at end of file