A Unified Framework for Implicit Sinkhorn Differentiation #125
Replies: 1 comment · 1 reply
-
Hi Adam, and a big thanks for using OTT-JAX! Thanks also for this reference! This paper is very cool: it has a lot of very nice illustrations of the benefits of implicit differentiation applied to the Sinkhorn algorithm in the balanced setting.

If you are suggesting we implement it in OTT-JAX, I have good news for you: implicit diff is already there :) Implicit differentiation is the default method for differentiating the outputs of the Sinkhorn algorithm.

If you would like to check the inner workings of this generic approach, which applies to any formal input of the Sinkhorn algorithm, you can take a look at https://github.com/ott-jax/ott/blob/main/ott/core/implicit_differentiation.py. That file relies heavily on JAX's automatic differentiation and linear-solver machinery.

Another nice pointer for "automated implicit differentiation" that goes beyond the Sinkhorn algorithm can be found in JAXOPT.
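For concreteness, here is a minimal sketch of what this looks like from the user's side. It is written against the `ott.core` API of the file linked above; the argument name `implicit_differentiation` follows that era's signature and may differ in newer releases.

```python
# Minimal sketch: differentiating the regularized OT cost w.r.t. an
# input point cloud. Assumes the ott.core API linked above; the
# `implicit_differentiation` flag name is from that era's signature.
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

def reg_ot_cost(x, y, implicit):
  # Entropy-regularized OT cost between two uniform point clouds.
  geom = pointcloud.PointCloud(x, y, epsilon=0.1)
  out = sinkhorn.sinkhorn(geom, implicit_differentiation=implicit)
  return out.reg_ot_cost

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (16, 2))
y = jax.random.normal(k2, (16, 2))

# Implicit differentiation (the default): the backward pass solves a
# linear system at the fixed point instead of unrolling iterations.
g_implicit = jax.grad(reg_ot_cost)(x, y, True)

# Unrolled differentiation, for comparison.
g_unrolled = jax.grad(reg_ot_cost)(x, y, False)
print(jnp.allclose(g_implicit, g_unrolled, atol=1e-2))
```

Setting the flag to `False` falls back to unrolling the Sinkhorn iterations through reverse-mode autodiff, which is the comparison point used in papers like this one.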
-
In this new paper, the authors propose a "framework based on the most general formulation of the Sinkhorn operator. It allows for any type of loss function, while both the target capacities and cost matrices are differentiated jointly. We further construct error bounds of the resulting algorithm for approximate inputs. Finally, we demonstrate that for a number of applications, simply replacing automatic differentiation with our algorithm directly improves the stability and accuracy of the obtained gradients. Moreover, we show that it is computationally more efficient, particularly when resources like GPU memory are scarce."
Here is the PyTorch implementation:
https://github.com/marvin-eisenberger/implicit-sinkhorn
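For intuition about the mechanism such frameworks build on, here is a generic sketch of implicit fixed-point differentiation in JAX. This is not the paper's exact algorithm: the map `T`, the iteration counts, and the linear-system solver are all illustrative placeholders.

```python
# Generic sketch of implicit fixed-point differentiation: for
# f* = T(f*, theta), the implicit function theorem lets us compute
# VJPs through f* with one linear solve against (I - dT/df)^T,
# instead of backpropagating through every iteration.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(T, theta, f0):
  # Iterate the map to (approximate) convergence.
  f = f0
  for _ in range(100):
    f = T(f, theta)
  return f

def fixed_point_fwd(T, theta, f0):
  f_star = fixed_point(T, theta, f0)
  return f_star, (theta, f_star)

def fixed_point_bwd(T, res, g):
  theta, f_star = res
  # Linearize T at the fixed point, once in f and once in theta.
  _, vjp_f = jax.vjp(lambda f: T(f, theta), f_star)
  _, vjp_theta = jax.vjp(lambda t: T(f_star, t), theta)
  # Solve w = g + (dT/df)^T w by fixed-point iteration; this converges
  # when T is a contraction in f, as Sinkhorn updates are.
  def step(w, _):
    return g + vjp_f(w)[0], None
  w, _ = jax.lax.scan(step, g, None, length=100)
  # At convergence, no gradient flows to the initialization f0.
  return vjp_theta(w)[0], jnp.zeros_like(f_star)

fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

# Toy usage: a contraction in f, parameterized by theta.
T = lambda f, theta: 0.5 * jnp.cos(f) + theta
grads = jax.grad(lambda th: fixed_point(T, th, jnp.zeros(3)).sum())(jnp.ones(3))
print(grads)
```

The point of the backward pass is that it costs a single linear solve at the fixed point, rather than storing and backpropagating through every Sinkhorn iteration, which is where the memory savings mentioned in the abstract come from.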