
Implementation of the MVAE model #80

Open
miguelsvasco opened this issue Jul 26, 2019 · 3 comments

@miguelsvasco

First of all, thank you for the code! Just two slight remarks:

  • The provided implementation of the MVAE model includes KL-divergence terms between the distribution of each modality-specific encoder and the distribution of the Product-of-Experts (PoE) encoder:

D_KL(q(z|x,y) || q(z|x)) + D_KL(q(z|x,y) || q(z|y))

However, the original formulation of the MVAE model (in the paper Multimodal Generative Models for Scalable Weakly-Supervised Learning) does not include such terms, only a KL-divergence term between the distribution of the PoE encoder and the prior:

ELBO(x, y) = E_{q(z|x,y)}[ log p(x|z) + log p(y|z) ] - D_KL(q(z|x,y) || p(z))

When I remove the kl_x and kl_y terms from the regularizer and train, the model seems unable to perform cross-modality inference:

[image: cross-modality inference results]

  • The authors of the paper also mention a subsampling training procedure, where the total loss function for a batch is the sum of three different losses: (1) a joint loss, where both image and text are given as input to the model, (2) an image loss, where only the image is given as input, and (3) a text loss, where only the text is given as input. Would it be possible to implement this with the Pixyz framework?
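The structure of that subsampled objective can be sketched with plain numpy, assuming diagonal-Gaussian experts and omitting the reconstruction terms (which would require decoders); the variable names here are illustrative, not from either codebase. The key fact is that a product of Gaussians is again Gaussian, with precision equal to the sum of the experts' precisions:

```python
import numpy as np

def poe(mus, logvars):
    """Product of Gaussian experts, including an N(0, I) prior expert.

    The product of Gaussians is Gaussian: its precision is the sum of the
    experts' precisions, and its mean is the precision-weighted average.
    """
    mus = [np.zeros_like(mus[0])] + list(mus)          # prepend prior expert
    logvars = [np.zeros_like(logvars[0])] + list(logvars)
    precisions = [np.exp(-lv) for lv in logvars]
    total_prec = sum(precisions)
    mu = sum(m * p for m, p in zip(mus, precisions)) / total_prec
    return mu, -np.log(total_prec)                     # (mean, log-variance)

def kl_to_prior(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over dimensions."""
    return 0.5 * np.sum(np.exp(logvar) + mu ** 2 - 1.0 - logvar)

# Hypothetical unimodal posterior parameters for one example
mu_x, lv_x = np.array([1.0, -1.0]), np.array([0.0, 0.0])
mu_y, lv_y = np.array([0.5, 0.5]), np.array([0.0, 0.0])

# Subsampled objective: one regularizer term per subset {x, y}, {x}, {y}
total_kl = 0.0
for experts in ([(mu_x, lv_x), (mu_y, lv_y)], [(mu_x, lv_x)], [(mu_y, lv_y)]):
    mu, lv = poe([m for m, _ in experts], [l for _, l in experts])
    total_kl += kl_to_prior(mu, lv)
```

Because the prior is itself treated as an expert, the same `poe` function handles the joint term and both unimodal terms.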
@masa-su
Owner

masa-su commented Aug 6, 2019

@miguelsvasco
Hi, thank you for your detailed comments and I'm sorry for my late reply.

However, the original formulation of the MVAE model (in the paper Multimodal Generative Models for Scalable Weakly-Supervised Learning), does not consider such terms, only a KL divergence term between the distribution of the POE encoder and the prior

Yes, this loss function comes not from MVAE but from JMVAE (it was originally proposed in this paper as JMVAE-kl). Though the PoE encoder is not used in the original paper on JMVAE, we wanted to see whether the PoE encoder works well with the JMVAE loss. Anyway, I'm sorry for the confusion.

When I remove the kl_x and kl_y terms from the regularizer and train, the model seems unable to perform cross-modality inference:

This might be due to the "unimodal" inferences of the PoE encoder, q(z|x) and q(z|y), not being trained. Without that, the z inferred from a unimodal input (especially a label or attribute) might collapse (a similar issue is referred to in our preprint paper as the "missing modality difficulty").
In the JMVAE, these are trained by pushing them close to the "bimodal" inference q(z|x,y), which corresponds to the additional KL terms you pointed out.

Would that be possible to implement with the Pixyz framework?

Yes, but you should use the Model class instead of the VAE class, because the loss function becomes more complex.
The implementation of the original MVAE model with Pixyz is as follows:
[screenshot: Pixyz implementation of the original MVAE model, 2019-08-06]
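The screenshot is not legible in this export. As a rough substitute, here is a minimal sketch of the original MVAE objective (PoE posterior, subsampled three-term loss) in plain PyTorch, not Pixyz; the architecture and all names are placeholders rather than the notebook's code, and Gaussian likelihoods stand in for whatever decoders the notebook uses:

```python
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """Maps an input to the (mu, logvar) of a diagonal Gaussian expert."""
    def __init__(self, in_dim, z_dim):
        super().__init__()
        self.net = nn.Linear(in_dim, 2 * z_dim)

    def forward(self, v):
        mu, logvar = self.net(v).chunk(2, dim=-1)
        return mu, logvar

def poe(mus, logvars):
    """Product of Gaussian experts with an implicit N(0, I) prior expert."""
    mus = torch.stack([torch.zeros_like(mus[0])] + list(mus))
    logvars = torch.stack([torch.zeros_like(logvars[0])] + list(logvars))
    prec = torch.exp(-logvars)
    mu = (mus * prec).sum(0) / prec.sum(0)
    return mu, -torch.log(prec.sum(0))

def neg_elbo(decoders, targets, mus, logvars):
    """Negative ELBO with Gaussian likelihoods (up to additive constants)."""
    mu, logvar = poe(mus, logvars)
    z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterize
    recon = sum(((dec(z) - t) ** 2).sum() for dec, t in zip(decoders, targets))
    kl = 0.5 * (torch.exp(logvar) + mu ** 2 - 1.0 - logvar).sum()
    return recon + kl

# Toy dimensions and modules (placeholders, not the notebook's architecture)
x_dim, y_dim, z_dim = 8, 4, 2
enc_x, enc_y = GaussianEncoder(x_dim, z_dim), GaussianEncoder(y_dim, z_dim)
dec_x, dec_y = nn.Linear(z_dim, x_dim), nn.Linear(z_dim, y_dim)

x, y = torch.randn(5, x_dim), torch.randn(5, y_dim)
mu_x, lv_x = enc_x(x)
mu_y, lv_y = enc_y(y)

# Subsampled MVAE objective: joint term + image-only term + text-only term
loss = (neg_elbo([dec_x, dec_y], [x, y], [mu_x, mu_y], [lv_x, lv_y])
        + neg_elbo([dec_x], [x], [mu_x], [lv_x])
        + neg_elbo([dec_y], [y], [mu_y], [lv_y]))
loss.backward()
```

Note how the unimodal terms train q(z|x) and q(z|y) through the shared PoE, which is what the subsampling procedure is for.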

Given your comments, I renamed the previous notebook from mvae_poe.ipynb to jmvae_poe.ipynb (to avoid confusion) and added a new notebook, mvae.ipynb, which contains the implementation of the original MVAE model.

Thank you!

@sgalkina

@masa-su
Thank you for the framework.
For the MVAE implementation you provided above, how should the model be trained in the semi-supervised case? Let's say that for the MNIST dataset only a share of the labels is available. Should two Model objects be created that share the networks but have different loss functions, for (1) the image and the label available and (2) only the label available?

@masa-su
Copy link
Owner

masa-su commented Oct 15, 2019

@sgalkina
Thank you for your comment!
I don't know what kind of loss functions you are going to implement for the supervised and unsupervised cases, but you can use the replace_var method of the Distribution class to share the same network across different losses, e.g., the supervised and unsupervised losses.

For an example of its usage, please see the implementation of the M2 model, which is the well-known semi-supervised VAE model.
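Pixyz aside, the underlying idea of one network appearing in two losses can be sketched in plain PyTorch; this is only an analogy for what replace_var achieves in Pixyz (the same Distribution reused under different variable names), and all names and loss forms here are illustrative:

```python
import torch
import torch.nn as nn

# One shared network used by both losses; gradients from both terms
# accumulate in the same parameters, as in the M2-style setup.
enc = nn.Linear(4, 2)

def supervised_loss(x, y):
    return ((enc(x) - y) ** 2).mean()   # fit labeled pairs

def unsupervised_loss(x):
    return (enc(x) ** 2).mean()         # e.g. a prior-matching term

x_lab, y_lab = torch.randn(3, 4), torch.randn(3, 2)
x_unl = torch.randn(7, 4)

loss = supervised_loss(x_lab, y_lab) + unsupervised_loss(x_unl)
loss.backward()  # both terms contribute gradients to the shared `enc`
```

A single optimizer over `enc.parameters()` then trains the shared network from both the labeled and the unlabeled batches.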

If you have any trouble understanding how to use it, please feel free to ask!
