Learning the two-moons distribution with a normalizing flow #2
Ok well, turns out I had a colab notebook with everything in one place: https://colab.research.google.com/drive/1HPom85QIjugHaL2RkO-5TWle6ZeoVBWC?usp=sharing Can you see if it works for you? And if so, can you add your version of this to this repo? |
Ah, and instead of using the sklearn two moons dataset, you can use the pure TFP one from this notebook: The advantage is that it will also allow you to compute gradients ;-) |
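(The linked notebook isn't reproduced in this thread. As a rough, hypothetical stand-in, not the notebook's actual code, a two-moons-like distribution with a differentiable log_prob can be sketched in the TFP JAX substrate as a Gaussian mixture placed along two arcs:)

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def get_two_moons(sigma=0.05, n_components=64):
    # Small isotropic Gaussians along two half-circles (same geometry as
    # sklearn's make_moons); log_prob of the mixture is exact and differentiable.
    t = jnp.linspace(0.0, jnp.pi, n_components)
    upper = jnp.stack([jnp.cos(t), jnp.sin(t)], axis=-1)
    lower = jnp.stack([1.0 - jnp.cos(t), 0.5 - jnp.sin(t)], axis=-1)
    locs = jnp.concatenate([upper, lower], axis=0)
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=jnp.zeros(2 * n_components)),
        components_distribution=tfd.MultivariateNormalDiag(
            loc=locs, scale_diag=sigma * jnp.ones(2)))
```

With such a distribution, `jax.grad(two_moons.log_prob)` gives the score field directly, which is the advantage mentioned above.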
Ok thanks a lot! I will look into all of this :) |
Learning the two moons from tensorflow using RealNVP |
Learning the two moons from tensorflow using RealNVP + using the score |
I can't use @jax.jit on the get_batch function (from this notebook: https://colab.research.google.com/drive/1t4DaL02o31OCOFifDaQS2B1f_QN5-iRq?usp=sharing ). When I use it, I get this error: 'IndexError: tuple index out of range' |
Could you try the following?

```python
@jax.jit
def get_batch(batch_size, seed):
    two_moons = get_two_moons(sigma=0.05)
    batch = two_moons.sample(batch_size, seed=seed)
    score = jax.vmap(jax.grad(two_moons.log_prob))(batch)
    return batch, score
```

Maybe it's an issue coming from the fact that you build the distribution outside of the jitted function |
So, @Justinezgh, I think you are already pretty much all set up to start some fun research and experiments, so I want to show you some preliminary work on this stuff we did with @b-remy last year. We were testing a technique called denoising score matching to learn the score field (not the distribution itself), and we ran some tests against what a conventional Normalizing Flow could achieve. Here is a relevant plot: it shows that when training a Normalizing Flow just for density estimation, the score field can go all wonky. Also, if you think about the change of variables formula in a Normalizing Flow, the score will have two terms: one that comes from the inverse mapping, and one that comes from the Jacobian determinant. For a RealNVP, @b-remy also made this plot: |
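(To spell out the two terms mentioned above, using the standard change of variables with $f$ the data-to-base mapping; my notation, not from the original comment:)

$$\log p_X(x) = \log p_Z\big(f(x)\big) + \log\left|\det J_f(x)\right|,$$

$$\nabla_x \log p_X(x) = J_f(x)^\top \,\nabla_z \log p_Z(z)\Big|_{z=f(x)} + \nabla_x \log\left|\det J_f(x)\right|.$$

The second term involves second derivatives of $f$, which is why the smoothness of the coupling (and of the activations in the conditioner network) shows up directly in the score field.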
This makes me think that a first angle of attack could be to check that, for a given choice of normalizing flow architecture, the log density is indeed continuously differentiable. Thinking about the log-determinant term is probably a good idea too. You can also have a look at one of the seminal papers on score matching: https://www.cs.helsinki.fi/u/ahyvarin/papers/JMLR05.pdf |
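(A minimal sketch of that check, assuming `nf` is a flow from the TFP JAX substrate with a differentiable `log_prob`; the shifted Gaussian below is only a stand-in so the snippet runs, in practice it would be the trained RealNVP / spline flow:)

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

# Stand-in flow; replace with the trained normalizing flow.
nf = tfd.TransformedDistribution(
    tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2)),
    bijector=tfb.Shift(jnp.array([0.5, 0.25])))

# Evaluate the score field on a grid; kinks or jumps reveal a log density
# that is not continuously differentiable.
xs, ys = jnp.linspace(-1.5, 2.5, 50), jnp.linspace(-1.0, 1.5, 50)
grid = jnp.stack(jnp.meshgrid(xs, ys), axis=-1).reshape(-1, 2)
scores = jax.vmap(jax.grad(nf.log_prob))(grid)   # shape (2500, 2), e.g. for plt.quiver
```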
Impact of the number of coupling layers (affine coupling layers) on the score field |
Same but with Neural Spline Flows: https://colab.research.google.com/drive/1IFDmsNUTsHIjQpjnXKIAG3PUeyx6NLux?usp=sharing And I still have problems with @jax.jit |
```python
batch_size = 512

@jax.jit
def get_batch(seed):
    two_moons = get_two_moons(sigma=0.05)
    batch = two_moons.sample(batch_size, seed=seed)
    return batch
```

Simple fix for the jitting of get_batch: removing the batch_size argument, so that it is closed over as a Python constant instead of being traced |
@Justinezgh this is all super interesting. Two questions:
|
Hi @Justinezgh , note that if it is more convenient, you can also keep the batch_size argument by marking it as static:

```python
import functools

@functools.partial(jax.jit, static_argnums=(1,))
def get_batch(seed, batch_size):
    two_moons = get_two_moons(sigma=0.05)
    batch = two_moons.sample(batch_size, seed=seed)
    return batch
```
|
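(A quick usage note, mine and not from the comment: with `static_argnums`, JAX recompiles `get_batch` once per distinct `batch_size` value, which is fine as long as the batch size stays fixed during training:)

```python
import jax
batch = get_batch(jax.random.PRNGKey(0), 512)  # 512 is treated as a compile-time constant
```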
Sorry, I was too curious.... I quickly tried to train a regression network under a score matching loss to make sure things were not crazy. And it seems to work pretty well: the loss function nicely goes to zero instead of jumping around as in the NF examples.

**Training the grads of the NN**
For fun I also tried to train the same network, but making it output just a scalar and constraining its gradients, and that doesn't train at all:

**Training the grads of the NN with a C∞ neural network**
And for even more fun, I tried training the same model again by constraining the grads, replacing the relu activation by a smooth (C∞) one. And BAM! By magic it works \o/ And training goes super easily:

=> All code available here: #4 |
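(For reference, a minimal sketch of that score-matching regression in plain JAX. This is my paraphrase of the idea, not the code in #4; it assumes the `get_batch(batch_size, seed)` from earlier in the thread, returning a batch and its analytic score:)

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(2, 128, 128, 2)):
    # Small MLP s_theta: R^2 -> R^2
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    for w, b in params[:-1]:
        x = jnp.sin(x @ w + b)   # smooth activation, as in the C-infinity variant
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, batch, score):
    # Regress the network output onto the analytic score of the two moons
    return jnp.mean(jnp.sum((mlp(params, batch) - score) ** 2, axis=-1))

@jax.jit
def update(params, batch, score, lr=1e-3):
    grads = jax.grad(loss_fn)(params, batch, score)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

For the gradient-constrained variant, the network outputs a single scalar and its input gradient plays the role of the score, e.g. `jax.vmap(jax.grad(scalar_net, argnums=1), in_axes=(None, 0))(params, batch)` in place of `mlp(params, batch)` in the loss (with a hypothetical `scalar_net(params, x)` returning a scalar).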
Ahaha yep ^^ sorry, this had been bugging me all afternoon and I was dying to try it, it's pretty fun stuff :-) |
So, the next logical step is to build an NF that is C∞ by construction. @b-remy reminded me of this paper: https://arxiv.org/pdf/2110.00351.pdf where they actually propose a coupling layer that should be continuously differentiable, to place in a RealNVP. Probably worthwhile to take a look. |
So just to see, I tried to use the sin activation function for the NN of the affine coupling layer: For the NF with 3 coupling layers: Notebook: https://colab.research.google.com/drive/1ZU-w76vJ81-PArB9vr1x9fi7qpZ1AOnu?usp=sharing |
(just for the record, what I said there about MAF was stupid, you still have an Affine Coupling with a MAF) |
Just to be sure, is this the function that we want to place in a RealNVP? If yes, do we want to use f to define the shift and the scale part? Because both shift and scale are maps R^d -> R^(D-d), so do we have to do some kind of projection for f(x)? Like, we define f(x_i) := (1-c) * (g(x_i) - g(0)) / (g(1) - g(0)) + c * x_i, so f(x) \in R^d, and then we project to R^(D-d)? Actually I'm not sure that f was made to be used in a RealNVP, idk.. |
That's a good question. And yes, that's the coupling we might want to use, instead of an affine coupling. So you don't generate shift and scale parameters; instead you generate these a, b, c parameters, which are the outputs of some neural network that takes R^d inputs and returns R^(D-d) outputs, and the function g is a bijection in R^(D-d). You can have a look at how the Spline flows work; it's a bit different, but it illustrates how a parametrisation can deviate from affine. |
It may not be 100% trivial because I think you would have to define a TFP bijector to implement the mapping f. It shouldn't be too difficult, but it will take a bit of coding. Ah, and there is another approach we could take, I think... We could use FFJORD, and there is one easily usable in the TF version of TFP (so not in JAX, unfortunately). I think if the ODE function is sufficiently smooth, so is the ODE flow. |
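(Just to show the boilerplate such a bijector would need, here is a toy x -> x**3 mapping in the TFP JAX substrate; this is not the coupling from the paper, only the skeleton of a custom bijector:)

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors

class Cube(tfb.Bijector):
    """Toy x -> x**3 bijector illustrating the three methods a custom
    coupling function needs: forward, inverse and a log-det Jacobian."""

    def __init__(self, validate_args=False, name="cube"):
        super().__init__(validate_args=validate_args,
                         forward_min_event_ndims=0, name=name)

    def _forward(self, x):
        return x ** 3

    def _inverse(self, y):
        return jnp.sign(y) * jnp.abs(y) ** (1.0 / 3.0)

    def _forward_log_det_jacobian(self, x):
        # |d/dx x^3| = 3 x^2
        return jnp.log(3.0) + 2.0 * jnp.log(jnp.abs(x))
```

The actual coupling would carry the a, b, c parameters produced by the conditioner network, and could then be plugged into `tfb.RealNVP` via its `bijector_fn` argument instead of `shift_and_log_scale_fn`.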
+1 I was also thinking that Continuous Normalizing Flows (the flow of transformations being continuous here) such as Neural ODE (1806.07366) or FFJORD (1810.01367) would be an interesting approach to look at in parallel! |
I'm not sure that the function f(x) = (1-c)((g(x)-g(0))/(g(1)-g(0)))+cx has an analytical inverse. |
hummmmmmmmmmmmm that sounds surprising |
Ok, maybe for the exp it's hard to find an analytical inverse ^^' The monomial should be easier, and otherwise we could implement a general-purpose inverse function, with gradients computed by the implicit function theorem. |
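(A small sketch of that trick, using a hypothetical monotone map `f` with no closed-form inverse; scalar version, vmap over batches:)

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical smooth, strictly increasing map without an analytic inverse
    return x + 0.5 * jnp.sin(x)

@jax.custom_jvp
def f_inv(y):
    # Numerical inverse by Newton iterations
    x = y
    for _ in range(25):
        x = x - (f(x) - y) / jax.grad(f)(x)
    return x

@f_inv.defjvp
def f_inv_jvp(primals, tangents):
    (y,), (y_dot,) = primals, tangents
    x = f_inv(y)
    # Implicit function theorem: d f^{-1}/dy = 1 / f'(f^{-1}(y))
    return x, y_dot / jax.grad(f)(x)

# e.g. jax.vmap(jax.grad(f_inv))(jnp.linspace(-2.0, 2.0, 5))
```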
;-) wink wink @b-remy |
I've been looking at FFJORD, and we can indeed observe that working with a Continuous Normalizing Flow, which produces smooth transformations, yields a smoother score function!
Here I used maximum likelihood only, no score matching loss, because I have not figured out how to implement it with TensorFlow yet... Maybe we should open a specific issue dedicated to ODE flows, to discuss different loss functions or how the gradients are actually computed. And maybe consider implementing a JAX version, because taking gradients, or computing vjps, is not as easy in TF :-) |
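(For the record, roughly how the TF-side FFJORD bijector can be wired up; a minimal sketch under my assumptions, not the notebook used for the plots above:)

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

# Small MLP defining the dynamics dz/dt = h(t, z)
mlp = tf.keras.Sequential([tf.keras.layers.Dense(64, activation="tanh"),
                           tf.keras.layers.Dense(2)])

def state_time_derivative_fn(time, state):
    # Append t as an extra feature so the dynamics can depend on time
    t = time * tf.ones_like(state[..., :1])
    return mlp(tf.concat([state, t], axis=-1))

flow = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=tf.zeros(2)),
    bijector=tfb.FFJORD(state_time_derivative_fn=state_time_derivative_fn))

# Maximum likelihood training then minimizes
#   loss = -tf.reduce_mean(flow.log_prob(batch))
# with respect to mlp.trainable_variables.
```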
Yep @b-remy agreed, we can open a separate issue to discuss using an ODE flow for this :-) @Justinezgh do you have some news on building an invertible coupling? If it's not analytically possible, we can use an implicit function trick to define the gradients of a numerical inverse. @b-remy already has experience with this; it's a little bit more involved, but if we don't have analytic inverses it should work. |
I think the best I can do is rho(x) = x**2 :/ https://colab.research.google.com/drive/1kRA4ReFryVqFJfLxwtL7nXg-Uwsn1sal?usp=sharing (I can't specify the domain if a, b and c are symbols.) I was trying to compute f^-1 as a function of x, a, b, c in order to use it directly in the bijector, but I didn't manage to do it, and I don't think that sympy is JAX-compatible. Ok, so actually the best I can do is rho(x) = x**3 |
\o/ x^3 should work for our purposes! and maybe x^2 is actually enough... we just need one more order of smoothness than the typical affine coupling. Let's see what this gives us in practice in a bijector :-) |
Really awesome that you used sympy for solving this BTW! |
Any luck with implementing a bijector? Don't hesitate if you have questions ;-) |
I have "some" bugs :D https://colab.research.google.com/drive/1cmtlXbH-xX7s7m7MtiL4DWyD_UriSoIg?usp=sharing When I try to train the NF I have this error (1024 is the batch size) : 'ValueError: The arguments to _cofactor_solve must have shapes a=[..., m, m] and b=[..., m, m]; got a=(1024, 1, 1024, 1) and b=(1024, 1, 1024, 1)' So I tried with batch_size = 1 and I noticed that the loss is NaN. I tried the same thing with an easier bijector Exp() and I have the same pb for the loss. So I tried to print() everything in the NN to get a,b,c and for some reason the initialization part fails |
so several things,
|
did it help ^^' ? |
Yup! And so now I'm dealing with a new problem :D |
This looks good Justine, but I didn't quite get your point about x in a given range... To keep things simple for now, can we define a flow that remains between (0,1) ? |
@Justinezgh here are some examples I have lying around of building a normalizing flow in jax, and training it on the two moons distribution:
[1] full implementation of an NF in JAX+flax, but it is kind of outdated: https://github.com/EiffL/jax-nf (see this notebook in particular: https://github.com/EiffL/jax-nf/blob/master/notebooks/Vanilla-NVP.ipynb)
[2] notebook with an NF implementation in JAX+haiku: https://github.com/EiffL/Quarks2CosmosDataChallenge/blob/main/notebooks/PartII-GenerativeModels-Solution.ipynb (see "Step III: Latent Normalizing Flow" and ignore all the rest; it just shows you how to build an NF with jax and haiku)
So I would say, you can try to write a small notebook, using [1] as an example of how to generate examples from the two moons dataset, and [2] as an example of a slightly better implementation using haiku.
Learning objectives: