Skip to content

Commit

Permalink
update wasserstein metrics documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
gmoss13 committed Mar 6, 2024
1 parent 7d7842b commit 5c1314a
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion labproject/metrics/sliced_wasserstein.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# STOLEN from Julius: https://github.com/mackelab/wasserstein_source/blob/main/wasser/sliced_wasserstein.py
# this implementation is from https://github.com/mackelab/sourcerer/blob/main/sourcerer/sliced_wasserstein.py
# Removed numpy dependency

import torch
Expand Down
1 change: 0 additions & 1 deletion labproject/metrics/wasserstein_kuhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


# Implementation taken from https://python.plainenglish.io/hungarian-algorithm-introduction-python-implementation-93e7c0890e15
# TODO: implement fully in pytorch for differentiability


def min_zero_row(zero_mat, mark_zero):
Expand Down
2 changes: 2 additions & 0 deletions labproject/metrics/wasserstein_sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import warnings
from labproject.metrics.utils import register_metric

# This implementation is adapted from https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py


def sinkhorn_algorithm(
x: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-3, niter: int = 1000, p: int = 2
Expand Down

0 comments on commit 5c1314a

Please sign in to comment.