Reference implementation for the paper "Multi-Object Representation Learning with Iterative Variational Inference". This repository contains:
- An IODINE implementation in TensorFlow v1.
- Configurations used in the paper (checkpoints available in Cloud Storage) for:
  - CLEVR
  - Multi-dSprites
  - Tetrominoes
- A notebook for running and inspecting the model and plotting the results.
- Clone the DeepMind research repository:
git clone https://github.com/deepmind/deepmind-research.git
cd deepmind-research
- Download the checkpoints from Google Cloud Storage (GCP). A shell script is provided:
./iodine/download_checkpoints.sh
On platforms without wget, the files can be downloaded from this webpage, and the unzipped checkpoints/ folder should be placed in deepmind-research/iodine/checkpoints.
- Prepare a Python 3 environment (virtualenv is recommended):
python3 -m venv iodine_venv
source iodine_venv/bin/activate
- Install the dependencies:
pip3 install -r iodine/requirements.txt
- The multi_object_datasets package installed via requirements.txt provides Python code to open the data files, but not the data files themselves. Download the desired datasets either manually from Google Cloud Storage or with the commands below:
pushd iodine/multi_object_datasets
# CLEVR
wget https://storage.googleapis.com/multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords
# Multi-dSprites
wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_grayscale.tfrecords
# Tetrominoes
wget https://storage.googleapis.com/multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords
# Get back to the location containing the 'iodine' directory
popd
See the multi_object_datasets repository for further details.
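Where a Python script is more convenient than wget, the same three files can be fetched with the standard library. This is a minimal sketch, not part of the repository: the dictionary keys and the helper names dataset_url and download_datasets are hypothetical, while the base URL and file paths are copied from the wget commands, and the default target directory mirrors the pushd iodine/multi_object_datasets step.

```python
import os
import urllib.request

# Base URL and file paths copied from the wget commands in this README.
BASE_URL = "https://storage.googleapis.com/multi-object-datasets"
DATASET_FILES = {
    # Hypothetical short names; only the values come from the README.
    "clevr": "clevr_with_masks/clevr_with_masks_train.tfrecords",
    "multi_dsprites": "multi_dsprites/multi_dsprites_colored_on_grayscale.tfrecords",
    "tetrominoes": "tetrominoes/tetrominoes_train.tfrecords",
}


def dataset_url(name):
    """Return the download URL for one dataset (hypothetical helper)."""
    return f"{BASE_URL}/{DATASET_FILES[name]}"


def download_datasets(names, dest_dir="iodine/multi_object_datasets"):
    """Download the requested .tfrecords files into dest_dir.

    Files that already exist are skipped, so the function can be re-run safely.
    Returns the local paths in the order the names were given.
    """
    os.makedirs(dest_dir, exist_ok=True)
    paths = []
    for name in names:
        url = dataset_url(name)
        path = os.path.join(dest_dir, os.path.basename(url))
        if not os.path.exists(path):
            urllib.request.urlretrieve(url, path)
        paths.append(path)
    return paths
```

For example, download_datasets(["tetrominoes"]), run from the directory that contains iodine, fetches only the Tetrominoes file. The files are large, so downloading just the datasets you need may be worthwhile.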
- Make sure that you have CUDA 10 and cuDNN 7 installed.
Use the Jupyter notebook Eval.ipynb to load and run one of the checkpoints. It also contains code to plot the outputs and latent traversals.
To train your own model, use the Sacred experiment defined in main.py. The configurations used in the paper for the different datasets are available as named configs in configuration.py.
- CLEVR6:
python3 -m iodine.main -f with clevr6
- Multi-dSprites:
python3 -m iodine.main -f with multi_dsprites
- Tetrominoes:
python3 -m iodine.main -f with tetrominoes
It is recommended to add an observer to your run so that Sacred records the details of the run. To add a FileStorageObserver, add -F my_storage_dir; for a MongoObserver, add -m my_db_name.
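When launching several runs from a script, it can help to assemble the command line programmatically. The sketch below only reproduces the pieces documented in this README (-f, a named config after with, -F for a FileStorageObserver, -m for a MongoObserver, and key=value overrides); the helper name sacred_command is hypothetical, and placing the observer flags before with is an assumption that matches common Sacred usage.

```python
def sacred_command(named_config, updates=None, file_storage=None, mongo_db=None):
    """Build the argument list for one `python3 -m iodine.main` run (hypothetical helper)."""
    argv = ["python3", "-m", "iodine.main", "-f"]
    if file_storage:
        # -F <dir> adds a FileStorageObserver.
        argv += ["-F", file_storage]
    if mongo_db:
        # -m <db_name> adds a MongoObserver.
        argv += ["-m", mongo_db]
    # Named config (clevr6, multi_dsprites or tetrominoes), then key=value overrides.
    argv += ["with", named_config]
    for key, value in (updates or {}).items():
        argv.append(f"{key}={value}")
    return argv
```

The resulting list can be handed to subprocess.run(..., check=True) from the directory that contains iodine, e.g. subprocess.run(sacred_command("clevr6", file_storage="my_storage_dir"), check=True). Building a list rather than a single string avoids shell quoting issues with paths.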
The experiment has a configuration that can be printed and adjusted from the command line, for example:
# print configuration
python3 -m iodine.main -f print_config with clevr6
# run experiment after adjusting batch_size and the size of the shuffle buffer
python3 -m iodine.main -f with clevr6 batch_size=2 data.shuffle_buffer=100
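Dotted keys such as data.shuffle_buffer address an entry inside a nested configuration section. The following is a simplified, hypothetical illustration of that mapping, not Sacred's actual parser (which handles more cases): it turns key=value strings like those in the command above into a nested dictionary, the form that Sacred's Experiment.run accepts through its config_updates argument. Parsing values with ast.literal_eval and falling back to plain strings is an assumption chosen for the sketch.

```python
import ast


def parse_updates(assignments):
    """Turn ["a=1", "b.c=2"] into {"a": 1, "b": {"c": 2}} (simplified sketch).

    Values are parsed as Python literals where possible (numbers, True/False,
    quoted strings) and kept as raw strings otherwise, e.g. unquoted paths.
    """
    config = {}
    for assignment in assignments:
        key, _, raw = assignment.partition("=")
        try:
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            value = raw
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            # Create intermediate sections such as "data" on first use.
            node = node.setdefault(part, {})
        node[leaf] = value
    return config
```

For the command above, parse_updates(["batch_size=2", "data.shuffle_buffer=100"]) yields {"batch_size": 2, "data": {"shuffle_buffer": 100}}, i.e. batch_size is a top-level entry while shuffle_buffer lives in the data section.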
Each run stores checkpoints and summaries in the directory specified by checkpoint_dir, to which a suffix based on the run_id is appended. If an observer is added, the run_id is set automatically; otherwise, it should be set manually, e.g. run_id=5.
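Judging from the example path iodine/checkpoints/clevr6_1 used for run_id=1 in the TensorBoard command, the suffix appears to be an underscore followed by the run_id. The sketch below assumes that format and assumes the clevr6 config sets checkpoint_dir to iodine/checkpoints/clevr6; the helper run_directory is hypothetical and only useful for locating a run's output, e.g. to point TensorBoard at it.

```python
def run_directory(checkpoint_dir, run_id):
    """Return checkpoint_dir with an assumed "_<run_id>" suffix (hypothetical helper)."""
    if run_id is None:
        # Without an observer no run_id is assigned, so it must be set manually.
        raise ValueError("set run_id manually, e.g. run_id=5, when no observer is added")
    return f"{checkpoint_dir.rstrip('/')}_{run_id}"
```

Under these assumptions, run_directory("iodine/checkpoints/clevr6", 1) gives iodine/checkpoints/clevr6_1, the directory used in the TensorBoard example.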
Summaries can be viewed with TensorBoard, e.g. for clevr6 (assuming run_id=1):
tensorboard --logdir iodine/checkpoints/clevr6_1
To continue a previous run, pass continue_run=True together with the path of its checkpoints:
python3 -m iodine.main -f with clevr6 continue_run=True checkpoint_dir=iodine/checkpoints/clevr6_1
The main experiment defined in main.py uses Sacred, and the configurations for the different datasets are added as named configs, which can be found in configuration.py.
The model implementation can be found in the modules directory and is based on TensorFlow and Sonnet:
- iodine.py: The main IODINE module that assembles the decoder, refinement network, distributions and factor regressor.
- decoder.py: The ComponentDecoder, a wrapper around networks that takes care of splitting the output channels into means and masks.
- refinement.py: The refinement component, which assembles the encoder network, LSTM and refinement head.
- networks.py: Different standard networks such as CNN, BroadcastCNN, and LSTM.
- distribution.py: Definition of the latent and pixel distributions.
- factor_eval.py: Contains the factor regressor, which predicts the true factors from the inferred object latents.
- data.py: Dataset wrappers around multi_object_datasets that take care of shuffling, batching and preprocessing.
- plotting.py: Helper functions for plotting results.
- utils.py: General helper functions.
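To make the decoder.py description more concrete, here is a NumPy sketch of the general idea behind splitting decoder output channels into means and masks in a spatial mixture model like IODINE's: each of K components predicts per-pixel color means plus one mask logit, the mask logits are normalized with a softmax across components, and the reconstruction is the mask-weighted sum of the means. The channel layout (means first, mask logit last), the absence of a batch dimension, and the function name split_and_combine are assumptions for illustration; this is not the ComponentDecoder code.

```python
import numpy as np


def split_and_combine(decoder_output):
    """Split (K, H, W, C + 1) decoder output into means, masks and a reconstruction.

    Assumed layout: the first C channels are pixel means and the last channel
    is an unnormalized mask logit for each of the K components.
    """
    means = decoder_output[..., :-1]        # (K, H, W, C)
    mask_logits = decoder_output[..., -1:]  # (K, H, W, 1)
    # Numerically stable softmax over the component axis: masks sum to 1 per pixel.
    shifted = mask_logits - mask_logits.max(axis=0, keepdims=True)
    weights = np.exp(shifted)
    masks = weights / weights.sum(axis=0, keepdims=True)
    # Mixture mean: each pixel blends the component means by their mask weights.
    reconstruction = (masks * means).sum(axis=0)  # (H, W, C)
    return means, masks, reconstruction
```

Normalizing across components, rather than per component, is what makes the masks compete for pixels, so each pixel is explained by a soft assignment to the K object slots.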
DISCLAIMER
This is not an officially supported Google product.