This folder contains the code for 'Scalable Variational Approaches for Bayesian Causal Discovery'.
To install, use conda:
conda env create -f environment.yml
If this fails for some reason, the key packages are
jax jaxlib ott-jax cdt sklearn matplotlib optax dm-haiku tensorflow_probability torch wandb cython fuzzywuzzy python-Levenshtein sumu lingam
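After installing the packages by hand, it can help to check which ones are still missing. The following is a minimal sketch, not part of this repository: `missing_modules` is a hypothetical helper, and the import names (for example `ott` for ott-jax, `haiku` for dm-haiku, `Levenshtein` for python-Levenshtein) are assumptions about how each listed package is imported.

```python
import importlib.util

# Assumed import names for the packages listed above (distribution names and
# import names differ for some of them, e.g. dm-haiku is imported as haiku).
IMPORT_NAMES = [
    "jax", "jaxlib", "ott", "cdt", "sklearn", "matplotlib", "optax",
    "haiku", "tensorflow_probability", "torch", "wandb", "Cython",
    "fuzzywuzzy", "Levenshtein", "sumu", "lingam",
]


def missing_modules(names):
    """Return the names that cannot be found, without importing them."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(IMPORT_NAMES)
    if missing:
        print("Missing modules:", " ".join(missing))
    else:
        print("All key modules were found.")
```
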
You may have to recompile the Cython module for the Hungarian algorithm. To do so, run the following two commands, in order, from the c_modules directory:
cython -3 mine.pyx
g++ -shared -pthread -fPIC -fwrapv -O3 -Wall -fno-strict-aliasing -o mine.so mine.c
Run with the --use_wandb flag to write results to a new Weights & Biases project. Otherwise, the results will be printed to stdout.
In utils.py you may need to uncomment line 11 and replace the path there with the path to your Rscript binary.
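If you are unsure where Rscript is installed, the sketch below looks it up on your PATH using only the Python standard library. `find_rscript` is a hypothetical helper and is not part of this repository. How utils.py uses the path is not shown here, so copy the printed value into that file by hand.

```python
import shutil


def find_rscript(name="Rscript"):
    """Return the absolute path of the named binary on PATH, or None if absent."""
    return shutil.which(name)


if __name__ == "__main__":
    path = find_rscript()
    if path is None:
        print("Rscript was not found on PATH; is R installed?")
    else:
        print("Rscript binary:", path)
```
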
To run the BCD Nets and GOLEM experiments in Figure 1 for one random seed, use arguments such as
python main.py -s 0 --n_data 100 --dim 32 --degree 1 --num_steps 30000 --do_ev_noise --sem_type linear-gauss --batch_size 256 --print_golem_solution
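To repeat the single-seed run over several seeds, one option is to generate the command lines programmatically. This is a minimal sketch under stated assumptions: `build_commands` is a hypothetical helper, the flags are copied from the command above, and the choice of seeds 0 to 2 is illustrative only. It prints the commands rather than launching them.

```python
import shlex

# Flags copied from the single-seed command; only the seed (-s) varies.
BASE_ARGS = [
    "--n_data", "100", "--dim", "32", "--degree", "1",
    "--num_steps", "30000", "--do_ev_noise",
    "--sem_type", "linear-gauss", "--batch_size", "256",
    "--print_golem_solution",
]


def build_commands(seeds):
    """Return one argument list per seed, ready for printing or subprocess.run."""
    return [["python", "main.py", "-s", str(seed), *BASE_ARGS] for seed in seeds]


if __name__ == "__main__":
    for argv in build_commands(range(3)):
        # Printed, not executed; use subprocess.run(argv, check=True) to launch.
        print(shlex.join(argv))
```
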
To run the baselines, run
python main.py --eval_eid --run_baselines --n_data 100 --dim 32 --sem_type linear-gauss --only_baselines --degree 2 --do_ev_noise --n_baseline_seeds 3
To run GOLEM, run
python main.py --eval_eid --print_golem_solution --n_data 100 --dim 32 --sem_type linear-gauss --degree 2 --do_ev_noise --num_steps 10
To run on the Sachs dataset, include the --use_sachs flag.