Neural solver implementation #564
pietrocipolla asked this question in Q&A
Hi!
I want to implement the OT neural solver, but I am new to JAX. In particular, I am interested in the ott.neural.methods.neuraldual.W2NeuralDual class. The tutorial does not make clear how to define the data loader class for the neural solvers. In my problem, I create a sample using the following function:

Then, I want to compute the distance between the unconditioned distribution of the data in columns output[:,1:2] and the distribution obtained by fixing the first variable (output[output[:,0]==m,1:2]). This problem can easily be solved with other OTT solvers, such as Sinkhorn. However, it is not clear to me how to implement the data loader for the neural solver. I hope you can help me.
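A minimal sketch, under assumptions: my understanding is that W2NeuralDual pulls each batch by calling next() on the four loaders it receives (train source, train target, valid source, valid target). If so, any Python iterator that endlessly yields arrays of shape (batch_size, dim) should work as a "data loader". InfiniteBatchSampler, make_loaders and the toy output array below are hypothetical stand-ins for illustration, not part of OTT.

```python
import numpy as np


class InfiniteBatchSampler:
    """Hypothetical loader: endlessly yields random mini-batches of a fixed sample.

    Assumes the neural solver only calls next() on it and expects 2-D arrays
    of shape (batch_size, dim).
    """

    def __init__(self, data, batch_size, seed=0):
        data = np.asarray(data, dtype=np.float32)
        if data.ndim == 1:
            data = data[:, None]  # promote (n,) to (n, 1)
        self.data = data
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Draw indices with replacement so the iterator never stops,
        # even if the sample has fewer than batch_size rows.
        idx = self.rng.integers(0, len(self.data), size=self.batch_size)
        return self.data[idx]


def make_loaders(output, m, batch_size=256, seed=0):
    """Hypothetical helper returning (train_source, train_target,
    valid_source, valid_target) loaders.

    Source: column 1 of every row (unconditioned distribution).
    Target: column 1 of the rows where column 0 equals m (conditional).
    """
    source = output[:, 1:2]
    target = output[output[:, 0] == m, 1:2]
    if len(target) == 0:
        raise ValueError(f"no rows with output[:, 0] == {m}")
    # Different seeds so training and validation draw different random batches.
    return (
        InfiniteBatchSampler(source, batch_size, seed),
        InfiniteBatchSampler(target, batch_size, seed + 1),
        InfiniteBatchSampler(source, batch_size, seed + 2),
        InfiniteBatchSampler(target, batch_size, seed + 3),
    )


# Toy stand-in for the sampling function: column 0 is a discrete
# variable in {0, 1, 2}, and column 1 depends on it.
rng = np.random.default_rng(42)
x0 = rng.integers(0, 3, size=10_000).astype(np.float32)
x1 = x0 + rng.normal(size=10_000).astype(np.float32)
output = np.stack([x0, x1], axis=1)

train_src, train_tgt, valid_src, valid_tgt = make_loaders(output, m=1, batch_size=128)
print(next(train_tgt).shape)  # (128, 1)
```

If the OTT API matches my reading of its documentation (worth checking against the installed version), the loaders would then be passed positionally, e.g. solver = W2NeuralDual(dim_data=1) followed by potentials = solver(train_src, train_tgt, valid_src, valid_tgt). Here the validation loaders resample the training data; a held-out split would give a more honest validation signal. Also note that output[:,0]==m only selects rows when the first variable takes exact discrete values; for a continuous first variable, a window such as np.abs(output[:, 0] - m) < eps would be needed.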