Skip to content

Commit

Permalink
add energy-based mmd kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
felixp8 authored Jun 4, 2024
1 parent 1bfb253 commit 56946e8
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions labproject/metrics/MMD_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ def linear_kernel(x, y):
return x @ y.t()


def energy_kernel(x, y):
x_norm = torch.linalg.norm(x, dim=-1)
y_norm = torch.linalg.norm(y, dim=-1)
return x_norm[:, None] + y_norm[None, :] - torch.cdist(x, y)


def median_heuristic(x, y):
return torch.median(torch.cdist(x, y))

Expand Down Expand Up @@ -71,3 +77,12 @@ def compute_linear_mmd_naive(x, y):
def compute_linear_mmd(x, y):
delta = torch.mean(x, 0) - torch.mean(y, 0)
return torch.norm(delta, 2) ** 2


@register_metric("mmd_energy")
def compute_energy_mmd(x, y):
x_kernel = energy_kernel(x, x)
y_kernel = energy_kernel(y, y)
xy_kernel = energy_kernel(x, y)
mmd = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
return mmd

0 comments on commit 56946e8

Please sign in to comment.