Skip to content

Commit

Permalink
mmd median heuristic
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Feb 7, 2024
1 parent c8a2fbd commit 47b93c3
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions labproject/metrics/MMD_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def compute_rbf_mmd(x, y, bandwidth=1.0):
mmd = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
return mmd

@register_metric("mmd_rbf_median_heuristic")
def compute_rbf_mmd_median_heuristic(x, y):
# https://arxiv.org/pdf/1707.07269.pdf
median = torch.median(torch.cdist(x, y))
return compute_rbf_mmd(x, y, median)


@register_metric("mmd_rbf_auto")
def compute_rbf_mmd_auto(x, y, bandwidth=1.0):
Expand Down

0 comments on commit 47b93c3

Please sign in to comment.