
Learning the two-moons distribution with a normalizing flow #2

Open
EiffL opened this issue Nov 8, 2021 · 40 comments

@EiffL
Contributor

EiffL commented Nov 8, 2021

@Justinezgh here are some examples I have lying around of building a normalizing flow in jax, and training it on the two moons distribution:

So I would say: you can try to rewrite a small notebook, using 1 as an example of how to generate samples from the two moons dataset, and 2 as an example of a slightly better implementation using haiku (a minimal sketch of such a flow follows the learning objectives below).

Learning objectives:

  • Learn how to implement a Normalizing Flow
  • Get familiar with TensorFlow Probability distributions
  • Get familiar with Jax and DeepMind's Haiku
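To make the objectives concrete, here is a minimal sketch of this kind of flow, assuming TFP-on-JAX and Haiku (this is not the exact code from the notebooks): a few RealNVP coupling layers over a 2D Gaussian base. make_flow has to be called inside an hk.transform'ed function so that Haiku tracks the coupling networks' parameters.

import jax.numpy as jnp
import haiku as hk
from tensorflow_probability.substrates import jax as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

def make_flow(n_layers=3):
  # One small conditioner network per coupling layer, created once so that
  # repeated bijector calls reuse the same Haiku parameters.
  nets = [hk.nets.MLP([64, 64, 2], name=f"coupling_{i}") for i in range(n_layers)]

  def make_shift_and_log_scale_fn(net):
    def fn(x, output_units, **kwargs):
      shift, log_scale = jnp.split(net(x), 2, axis=-1)
      return shift, log_scale
    return fn

  layers = []
  for net in nets:
    layers.append(tfb.RealNVP(num_masked=1,
                              shift_and_log_scale_fn=make_shift_and_log_scale_fn(net)))
    layers.append(tfb.Permute([1, 0]))  # swap the two coordinates between couplings
  return tfd.TransformedDistribution(
      distribution=tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2)),
      bijector=tfb.Chain(layers))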
@EiffL
Contributor Author

EiffL commented Nov 8, 2021

Ok well, turns out I had a colab notebook with everything in one place: https://colab.research.google.com/drive/1HPom85QIjugHaL2RkO-5TWle6ZeoVBWC?usp=sharing

Can you see if it is working for you? And if so, can you add your version of this to this repo?

@EiffL
Contributor Author

EiffL commented Nov 8, 2021

Ah, and instead of using the sklearn two moons dataset, you can use the pure TFP one from this notebook:
https://colab.research.google.com/drive/1yRsh1Kmb6O1J6Rx3v1hX7-oS9cQUyGiM?usp=sharing

The advantage is that it will also allow you to compute gradients ;-)

[image]
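For reference, here is a minimal sketch (not the notebook's exact construction) of a differentiable two moons built purely from TFP-on-JAX pieces, approximating each moon by a mixture of small Gaussians placed along an arc, so that jax.grad of its log_prob gives the score:

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def get_two_moons(sigma=0.05, n_components=64):
  # Place Gaussian components along the two arcs used by sklearn's make_moons.
  t = jnp.linspace(0.0, jnp.pi, n_components)
  upper = jnp.stack([jnp.cos(t), jnp.sin(t)], axis=-1)
  lower = jnp.stack([1.0 - jnp.cos(t), 0.5 - jnp.sin(t)], axis=-1)
  locs = jnp.concatenate([upper, lower], axis=0)
  return tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(logits=jnp.zeros(2 * n_components)),
      components_distribution=tfd.MultivariateNormalDiag(
          loc=locs, scale_diag=sigma * jnp.ones_like(locs)))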

@Justinezgh
Contributor

Ok thanks a lot! I will look into all of this :)

@Justinezgh
Contributor

Justinezgh commented Nov 15, 2021

Learning the two moons from TensorFlow using RealNVP:
https://colab.research.google.com/drive/1E2o54mt8KHlnWkwJCaEpzBunmTR3NmWC?usp=sharing

@Justinezgh
Contributor

Justinezgh commented Nov 15, 2021

Learning the two moons from TensorFlow using RealNVP + using the score:
https://colab.research.google.com/drive/1t4DaL02o31OCOFifDaQS2B1f_QN5-iRq?usp=sharing

@Justinezgh
Contributor

I can't use @jax.jit for the get_batch function (from this notebook: https://colab.research.google.com/drive/1t4DaL02o31OCOFifDaQS2B1f_QN5-iRq?usp=sharing). When I use it I get this error: 'IndexError: tuple index out of range'

@EiffL
Contributor Author

EiffL commented Nov 15, 2021

Could you try the following?

@jax.jit
def get_batch(batch_size, seed):
  two_moons = get_two_moons(sigma=0.05)  # build the distribution inside the jitted function
  batch = two_moons.sample(batch_size, seed=seed)
  score = jax.vmap(jax.grad(two_moons.log_prob))(batch)
  return batch, score

Maybe it's an issue coming from the fact that you build the distribution outside of the jitted function

@EiffL
Contributor Author

EiffL commented Nov 15, 2021

So, @Justinezgh, I think you are already pretty much all set up to start some fun research and experiments, so I want to show you some preliminary work we did on this with @b-remy last year.

We were testing a technique called denoising score matching to learn the score field (not the distribution itself), and we did some tests against what a conventional Normalizing Flow could achieve. Here is a relevant plot:
[image]
(from this notebook: https://github.com/b-remy/score-estimation-comparison/blob/normalizing_flows/notebooks/NF-DAE-SN-comparison.ipynb)

It shows that when training a Normalizing Flow just for density estimation, the score field can go all wonky. Also, if you think about the change of variable formula in a Normalizing Flow, the score will have two terms, one that comes from the inverse mapping, and one that comes from the Jacobian determinant. For a RealNVP, @b-remy also made this plot:
[image]
(https://github.com/b-remy/score-estimation-comparison/blob/normalizing_flows/notebooks/NFlows_where_come_from_the_failures.ipynb)
which shows that the determinant part seems to be responsible for most of the bad behavior, which probably implies that the particular shape of the RealNVP determinant is not very regular.
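To make the two terms explicit, here is a minimal sketch (helper names are mine, assuming a TFP-on-JAX bijector and base distribution) of splitting the score according to the change of variable formula log p(x) = log p_z(f^{-1}(x)) + log |det J_{f^{-1}}(x)|:

import jax

def score_terms(bijector, base_dist, x):
  # Term 1: gradient of the base log-density pulled back through the inverse map.
  term_inverse = jax.grad(lambda y: base_dist.log_prob(bijector.inverse(y)))(x)
  # Term 2: gradient of the log-det-Jacobian of the inverse map.
  term_logdet = jax.grad(lambda y: bijector.inverse_log_det_jacobian(y, event_ndims=1))(x)
  # Their sum is the full score d/dx log p(x) at a single point x.
  return term_inverse, term_logdet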

@EiffL
Contributor Author

EiffL commented Nov 15, 2021

This makes me think that a first angle of attack could be to check that, for a given choice of normalizing flow architecture, the log density is indeed correctly continuously differentiable. And thinking about the log determinant term is probably a good idea.

You can also have a look at one of the seminal papers on score matching: https://www.cs.helsinki.fi/u/ahyvarin/papers/JMLR05.pdf
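For reference, the objective in that paper can be written without access to the true score; here is a minimal JAX sketch (score_fn is a hypothetical model returning s_theta(x) for a single point) of the Hyvarinen loss E[tr(grad_x s_theta(x)) + 1/2 ||s_theta(x)||^2]:

import jax
import jax.numpy as jnp

def hyvarinen_sm_loss(score_fn, batch):
  # score_fn: R^d -> R^d, the model score at a single point.
  def per_sample(x):
    s = score_fn(x)
    jac = jax.jacfwd(score_fn)(x)  # d x d Jacobian of the score
    return jnp.trace(jac) + 0.5 * jnp.sum(s ** 2)
  return jnp.mean(jax.vmap(per_sample)(batch))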

@Justinezgh
Contributor

Justinezgh commented Nov 19, 2021

Impact of the number of (affine) coupling layers on the score field: https://colab.research.google.com/drive/1H0Q_hgb0Yjtqvyg9RKeqTvt5lSZBNiap?usp=sharing

@Justinezgh
Contributor

Justinezgh commented Nov 22, 2021

Same but with Neural Spline Flows : https://colab.research.google.com/drive/1IFDmsNUTsHIjQpjnXKIAG3PUeyx6NLux?usp=sharing

And I still have problems with @jax.jit

  • for get_batch(): 'Non-shape-like value: Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=0/1)> (type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>)'
  • for loss_fn() and update(): 'Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>, 2].
    If using jit, try using static_argnums or applying jit to smaller subfunctions.'

@EiffL
Contributor Author

EiffL commented Nov 22, 2021

batch_size=512

@jax.jit
def get_batch(seed):
  two_moons = get_two_moons(sigma= 0.05)
  batch = two_moons.sample(batch_size, seed=seed)
  return batch

A simple fix for the jitting of get_batch: remove the batch_size argument.

@EiffL
Contributor Author

EiffL commented Nov 22, 2021

@Justinezgh this is all super interesting. Two questions:

  • Have you checked (at least theoretically) that the log prob of a normalizing flow using a RealNVP is at least twice differentiable? The leaky-relu, for instance, shouldn't be, and I think it actually has zero second-order gradients (so zero gradients of the score) almost everywhere, which could explain why we are having difficulties training on the score.

  • Can you try to learn the score field with a simple regression network instead of a Normalizing Flow, i.e. directly training a function s_\theta(x) to learn the score field, with a dense neural network for instance (a sketch follows below)? If this works well, it means that there is nothing wrong in principle with the score matching loss, and that any difficulties must come from the particular architecture of the Normalizing Flow.
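A minimal sketch of the second point (assuming haiku and optax as in the notebooks; names are mine): regress a dense network s_theta(x) directly onto the true score pairs produced by get_batch.

import jax
import jax.numpy as jnp
import haiku as hk
import optax

def score_net_fn(x):
  return hk.nets.MLP([128, 128, 2], activation=jax.nn.relu)(x)

score_net = hk.without_apply_rng(hk.transform(score_net_fn))
optimizer = optax.adam(1e-3)

def loss_fn(params, batch, score):
  # Plain L2 regression of the predicted score field against the true one.
  pred = score_net.apply(params, batch)
  return jnp.mean(jnp.sum((pred - score) ** 2, axis=-1))

@jax.jit
def update(params, opt_state, batch, score):
  loss, grads = jax.value_and_grad(loss_fn)(params, batch, score)
  updates, opt_state = optimizer.update(grads, opt_state)
  return loss, optax.apply_updates(params, updates), opt_state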

@b-remy
Collaborator

b-remy commented Nov 22, 2021

Hi @Justinezgh , note that if it is more convenient, you can also keep the batch_size argument by specifying to jit that it is a static argument.

import functools

@functools.partial(jax.jit, static_argnums=(1,))
def get_batch(seed, batch_size):
  two_moons = get_two_moons(sigma= 0.05)
  batch = two_moons.sample(batch_size, seed=seed)
  return batch

@EiffL
Contributor Author

EiffL commented Nov 22, 2021

Sorry, I was too curious.... I quickly tried to train a regression network under a score matching loss to make sure things were not crazy. And it seems to work pretty well:
[image]

The loss function nicely goes to zero instead of jumping around as in the NF examples.
[image]

Training the grads of the NN

For fun I also tried to train the same network, but making it output just a scalar and constraining its gradients, and that doesn't train at all:
[image]

[image]

Training the grads of the NN with a C\infty neural network

And for even more fun, I tried training the same model again, still constraining the grads but replacing the relu activation by a sin function, as proposed in https://arxiv.org/abs/2006.09661

And BAM! By magic it works \o/
[image]
(note: for the background on the right I use exp(scalar output of the network))

And training goes super easily:
[image]

=> All code available here: #4
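For reference, a minimal sketch (names are mine, assuming haiku) of the "gradients of a scalar network" idea with sin activations, in the spirit of 2006.09661: the network outputs a scalar potential, and the gradient of that potential with respect to the input is trained to match the true score.

import jax
import jax.numpy as jnp
import haiku as hk

def potential_fn(x):
  # One scalar output per sample; sin activations keep the network C-infinity.
  return hk.nets.MLP([128, 128, 1], activation=jnp.sin)(x).squeeze(-1)

potential = hk.without_apply_rng(hk.transform(potential_fn))

def predicted_score(params, batch):
  # Gradient of the scalar output with respect to the input, one sample at a time.
  single = lambda x: potential.apply(params, x[None, :])[0]
  return jax.vmap(jax.grad(single))(batch)

def loss_fn(params, batch, score):
  return jnp.mean(jnp.sum((predicted_score(params, batch) - score) ** 2, axis=-1))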

@Justinezgh
Contributor

https://colab.research.google.com/drive/1OnL56FPKzinJrnL16xFdYXcBiKSPsOmy?usp=sharing :)

@EiffL
Contributor Author

EiffL commented Nov 22, 2021

Ahaha yep ^^ sorry, this had been bugging me all afternoon and I was dying to try it, it's pretty fun stuff :-)

@EiffL
Contributor Author

EiffL commented Nov 23, 2021

So, the next logical step is to build a NF that is C\infty by construction.

@b-remy reminded me of this paper: https://arxiv.org/pdf/2110.00351.pdf where they propose a coupling layer, meant to be placed in a RealNVP, that should be continuously differentiable. Probably worthwhile to take a look.

@Justinezgh
Contributor

Justinezgh commented Nov 23, 2021

So just to see, I tried to use the sin activation function for the NN of the affine coupling layer:

[image]

For the NF with 3 coupling layers:

[image]

Notebook: https://colab.research.google.com/drive/1ZU-w76vJ81-PArB9vr1x9fi7qpZ1AOnu?usp=sharing

@EiffL
Contributor Author

EiffL commented Nov 23, 2021

interesting interesting yeah, it doesn't seem to help directly :-/

So here is what they say in section 4 of (2110.00351):
[images: excerpts from section 4 of the paper]

So what they are saying is that with an affine coupling layer, you lose expressivity in the gradients of the log p. And they also say that the Neural Spline Flows have poor gradients because they are only C1.

So I think we could try the following: use a C\infty coupling layer and train under the Score Matching loss (because whether or not the flow can train under the SM loss on its own will tell us if the model is well adapted).

Alternatively, it might be possible to use a MAF instead of a RealNVP, because it's possible that if the masked autoencoder in the MAF layer is C\infty, the flow will be too.
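A minimal sketch (the nf_log_prob signature and the lam weighting are mine) of what training a flow under the score matching loss could look like: compare the gradient of the flow's log-density to the true score from get_batch, optionally together with the usual NLL term.

import jax
import jax.numpy as jnp

def flow_sm_loss(params, batch, true_score, nf_log_prob, lam=1.0):
  # nf_log_prob(params, x): scalar log-density of the flow at a single point x.
  model_score = jax.vmap(jax.grad(nf_log_prob, argnums=1), in_axes=(None, 0))(params, batch)
  nll = -jnp.mean(jax.vmap(nf_log_prob, in_axes=(None, 0))(params, batch))
  sm = jnp.mean(jnp.sum((model_score - true_score) ** 2, axis=-1))
  return nll + lam * sm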

@EiffL
Contributor Author

EiffL commented Nov 23, 2021

(just for the record, what I said there about MAF was stupid, you still have an Affine Coupling with a MAF)

@Justinezgh
Contributor

Justinezgh commented Nov 23, 2021

Just to be sure, is this the function that we want to place in a RealNVP?

[image: the coupling function f from the paper]

If yes, do we want to use f to define the shift and the scale part? Because both shift and scale map R^d -> R^(D-d), so do we have to do some kind of projection for f(x)? Like, we define f(x_i) := (1-c).((g(x_i)-g(0))/(...)) + c.x_i, so f(x) \in R^d, and then we project into R^(D-d)?

Actually I'm not sure that f was made to be used in a RealNVP, idk..

@EiffL
Contributor Author

EiffL commented Nov 24, 2021

That's a good question. And yes that's the coupling we might want to use, instead of an affine coupling.

So you don't generate shift and scale parameters; instead you generate these a, b, c parameters, which are the outputs of some neural network that takes R^d inputs and returns R^(D-d) outputs, and the function g is a bijection in R^(D-d).

You can have a look at how the Spline flows work; it's a bit different, but it illustrates how a parametrisation can deviate from affine.
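Just to fix ideas, here is a rough structural sketch of such a coupling (this is not the paper's exact parametrisation; make_g is a placeholder for the g-based mapping discussed above): a conditioner network maps the first d coordinates to per-dimension parameters (a, b, c), which define an elementwise bijection applied to the remaining D - d coordinates.

import jax.numpy as jnp

def coupling_forward(x, d, conditioner, make_g):
  # Split the input: x1 conditions the transformation of x2.
  x1, x2 = x[..., :d], x[..., d:]
  # Conditioner: R^d -> 3 * (D - d) outputs, one (a, b, c) triplet per transformed dimension.
  a, b, c = jnp.split(conditioner(x1), 3, axis=-1)
  g = make_g(a, b, c)  # elementwise bijection parametrised by (a, b, c)
  y2 = g(x2)
  return jnp.concatenate([x1, y2], axis=-1)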

@EiffL
Contributor Author

EiffL commented Nov 24, 2021

It may not be 100% trivial because I think you would have to define a TFP bijector to implement the mapping f. It shouldn't be too difficult, but will take a bit of coding.

Ah, and there is another approach we could take, I think: we could use an FFJORD, and there is one easily usable in the TF version of TFP (so not in Jax, unfortunately). I think that if the ODE function is sufficiently smooth, so is the ODE flow.

@b-remy
Collaborator

b-remy commented Nov 24, 2021

+1 I was also thinking that Continuous Normalizing Flows (the flow of transformations being continuous here) such as Neural ODE (1806.07366) or FFJORD (1810.01367) would be an interesting approach to look at in parallel!
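Just to sketch the mechanics (in JAX rather than TF, with a hypothetical smooth velocity field v(z, t) that maps data at t=0 to the base at t=1): the log-density of a continuous flow follows from integrating the trace of the Jacobian of v along the trajectory.

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def cnf_log_prob(v, base_log_prob, x):
  # Exact trace of the 2x2 Jacobian; fine in 2D (FFJORD uses a Hutchinson estimator instead).
  def dynamics(state, t):
    z, _ = state
    dz = v(z, t)
    # d/dt log|det dz/dx| = Tr(dv/dz) along the trajectory.
    dlogdet = jnp.trace(jax.jacfwd(lambda u: v(u, t))(z))
    return dz, dlogdet

  zs, logdets = odeint(dynamics, (x, 0.0), jnp.array([0.0, 1.0]))
  # log p(x) = log p_base(z(1)) + log|det dz(1)/dx|
  return base_log_prob(zs[-1]) + logdets[-1]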

@Justinezgh
Contributor

Justinezgh commented Dec 2, 2021

I'm not sure that the function f(x) = (1-c)((g(x)-g(0))/(g(1)-g(0)))+cx has an analytical inverse.
At least for rho(x) = exp(-1/alpha*x**beta).

@EiffL
Contributor Author

EiffL commented Dec 4, 2021

hummmmmmmmmmmmm that sounds surprising

@EiffL
Contributor Author

EiffL commented Dec 4, 2021

ok, maybe the exp is hard to find an analytical inverse for ^^' the monomial should be easier, and otherwise we could implement a general-purpose inverse function, with gradients computed by the implicit function theorem.
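A minimal sketch of such a general-purpose inverse (names are mine; it assumes f is scalar, strictly increasing on [lo, hi] and applied elementwise): bisection for the forward pass, with the gradient supplied by the implicit function theorem instead of differentiating through the loop.

import jax
import jax.numpy as jnp

def make_numerical_inverse(f, lo=0.0, hi=1.0, n_steps=50):
  def _bisect(x):
    # Plain bisection for y such that f(y) = x, assuming f is increasing on [lo, hi].
    def body(_, bounds):
      a, b = bounds
      m = 0.5 * (a + b)
      go_right = f(m) < x
      return jnp.where(go_right, m, a), jnp.where(go_right, b, m)
    a, b = jax.lax.fori_loop(0, n_steps, body,
                             (jnp.full_like(x, lo), jnp.full_like(x, hi)))
    return 0.5 * (a + b)

  @jax.custom_vjp
  def f_inv(x):
    return _bisect(x)

  def fwd(x):
    y = _bisect(x)
    return y, y

  def bwd(y, cotangent):
    # Implicit function theorem: d f^{-1}(x)/dx = 1 / f'(f^{-1}(x)).
    return (cotangent / jax.vmap(jax.grad(f))(y),)

  f_inv.defvjp(fwd, bwd)
  return f_inv

# e.g. make_numerical_inverse(lambda y: y ** 3) gives the cube root on [0, 1].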

@EiffL
Contributor Author

EiffL commented Dec 4, 2021

;-) wink wink @b-remy

@b-remy
Collaborator

b-remy commented Dec 6, 2021

I've been looking at ffjord, and we can indeed observe that working with a Continuous Normalizing Flow, which makes smooth transformations, yields a smoother score function!

[image]
https://colab.research.google.com/drive/1nCs0UH8CfToW6L4ZNehzERBdIx84Eg6k?usp=sharing

Here I used maximum likelihood only, no score matching loss because I have not figured out how to implement it with tensorflow yet...

Maybe we should open a specific issue dedicated to ODE flows, to discuss different loss functions or how the gradients are actually computed. And maybe consider implementing a JAX version because taking gradients, or computing vjp, is not as easy in TF :-)

@EiffL
Contributor Author

EiffL commented Dec 6, 2021

Yep @b-remy agreed, we can open a separate issue to discuss using an ODE flow for this :-)
We can keep this as plan B, in case plan A of using a custom coupling layer doesn't work.

@Justinezgh do you have some news on building an invertible coupling? If it's not analytically possible, we can use an implicit function trick to define the gradients of a numerical inverse. @b-remy already has experience with this; it's a little bit more involved, but if we don't have analytic inverses it should work.

@Justinezgh
Contributor

Justinezgh commented Dec 6, 2021

I think the best I can do is rho(x) = x**2 :/

https://colab.research.google.com/drive/1kRA4ReFryVqFJfLxwtL7nXg-Uwsn1sal?usp=sharing

(I can't specify the domain if a, b and c are symbols)

I was trying to compute f^-1 as a function of x, a, b, c in order to use it directly in the bijector, but I didn't manage to do it, and I don't think that sympy expressions can easily be converted to jax.
-> Ok, I think I just managed to do it.

Ok, so the best I can do is rho(x) = x**3

@EiffL
Contributor Author

EiffL commented Dec 6, 2021

\o/ x^3 should work for our purposes! and maybe x^2 is actually enough... we just need one more order of smoothness than the typical affine coupling.

Let's see what this gives us in practice in a bijector :-)

@EiffL
Contributor Author

EiffL commented Dec 7, 2021

Really awesome that you used sympy for solving this BTW!

@EiffL
Contributor Author

EiffL commented Dec 8, 2021

Any luck with implementing a bijector? Don't hesitate if you have questions ;-)

@Justinezgh
Contributor

I have "some" bugs :D

https://colab.research.google.com/drive/1cmtlXbH-xX7s7m7MtiL4DWyD_UriSoIg?usp=sharing

When I try to train the NF I get this error (1024 is the batch size): 'ValueError: The arguments to _cofactor_solve must have shapes a=[..., m, m] and b=[..., m, m]; got a=(1024, 1, 1024, 1) and b=(1024, 1, 1024, 1)'

So I tried with batch_size = 1 and I noticed that the loss is NaN. I tried the same thing with an easier bijector, Exp(), and I have the same problem with the loss. So I tried to print() everything in the NN to get a, b, c, and for some reason the initialization part fails.

@EiffL
Copy link
Contributor Author

EiffL commented Dec 8, 2021

So, several things:

  • batch_size=1 is for sure going to give you unstable training
  • An exp bijector is almost surely going to explode during training, so you might get NaNs very quickly
  • This way of computing the log det is probably going to be unstable: jnp.log(jnp.abs(jnp.linalg.det(jax.jacfwd(f, argnums=0)(x, self.a, self.b, self.c)))). Is there an analytic Jacobian for this bijector?
  • The reason why it fails for batch sizes > 1 is probably the jax.jacfwd: you need to make sure you compute a batched Jacobian matrix; it should give you [1024, 1, 1], not [1024, 1, 1024, 1] (see the sketch after this list)
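On the last point, here is a minimal sketch (following the notebook's f(x, a, b, c) signature, and assuming x, a, b, c all carry the batch dimension) of computing a batched log-det by vmapping jacfwd over the batch, so each sample gets its own small Jacobian instead of one big [1024, 1, 1024, 1] one:

import jax
import jax.numpy as jnp

def forward_log_det_jacobian(f, x, a, b, c):
  # Jacobian of f for a single sample, vmapped over the leading batch dimension.
  per_sample_jac = jax.jacfwd(f, argnums=0)
  jacs = jax.vmap(per_sample_jac)(x, a, b, c)    # shape [batch, k, k]
  return jnp.log(jnp.abs(jnp.linalg.det(jacs)))  # shape [batch]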

@EiffL
Contributor Author

EiffL commented Dec 9, 2021

did it help ^^' ?

@Justinezgh
Contributor

Justinezgh commented Dec 9, 2021

yup!
I computed the gradients with Sympy: https://colab.research.google.com/drive/1URrqY8TVf0EbtO2DHpjqEnR4jIvs2j-P?usp=sharing
I don't know if it's faster to have the Jacobian for both f and f^-1 or to use the fact that forward_log_det_jacobian is the negative of inverse_log_det_jacobian, evaluated at f^{-1}(y).

And so now I'm dealing with a new problem :D
[image]
I just have to find a way to have x \in [-1/2a+b, 1/2a+b].

@EiffL
Contributor Author

EiffL commented Dec 10, 2021

This looks good Justine, but I didn't quite get your point about x in a given range... To keep things simple for now, can we define a flow that stays within (0, 1)?
