Example: AVICI for custom data-generating processes

This README walks through a running example of training an AVICI model for causal discovery on a custom data-generating process. The components we provide here can be used as a starting point for a new project. As an illustrative example, we implement AVICI trained on SCMs with sinusoidal functions and random tree graphs.

This folder contains the following three files:


1) func.py: Python file defining the data-generating processes

This file contains all custom classes that make up the generative model of our domain. All functions should be written using standard numpy rather than jax.numpy. This is faster and also avoids resource conflicts between the CPU workers that continually update the training data buffers and the hardware accelerators that jax uses for the actual network training.

The provided func.py implements two example classes, one for sampling random trees and one for SCMs with sinusoidal functions, neither of which is already implemented in avici.synthetic. Each custom data-generating process must subclass one of the following two abstract base classes and implement the __call__ function with the correct signature:

a) GraphModel

Subclasses of GraphModel implement functionality for sampling training graphs and can be used to define (part of) the causal graph distribution p(G). Each child class has to implement __call__ accepting two arguments:

  • rng (np.random.Generator) – numpy pseudorandom number generator
  • n_vars (int) – number of nodes in the graph

Returns:

  • ndarray – binary adjacency matrix of shape [n_vars, n_vars]

Example:

import numpy as onp
from avici.synthetic import GraphModel

class DummyGraph(GraphModel):
   def __call__(self, rng, n_vars):
       return onp.zeros((n_vars, n_vars))
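
As a slightly less trivial illustration, the following sketch samples a random directed tree. It is not the implementation shipped in func.py: the function name sample_random_tree is hypothetical, the edge convention (g[i, j] == 1 meaning i -> j) is an assumption, and the construction (each node picks one parent among the nodes placed before it) is a simple choice that does not sample uniformly over all trees. It is written as a plain function so that it runs without avici installed.

```python
import numpy as onp


def sample_random_tree(rng, n_vars):
    """Sample a random directed tree as a binary adjacency matrix.

    Assumed convention (not confirmed by the source): g[i, j] == 1 encodes i -> j.
    """
    g = onp.zeros((n_vars, n_vars), dtype=int)
    # Random node ordering; the first node in the ordering becomes the root.
    order = rng.permutation(n_vars)
    for k in range(1, n_vars):
        # Each later node picks exactly one parent among the earlier nodes.
        parent = order[rng.integers(k)]
        g[parent, order[k]] = 1
    return g


rng = onp.random.default_rng(0)
g = sample_random_tree(rng, n_vars=6)
```

Inside a GraphModel subclass, __call__(self, rng, n_vars) could simply return the result of such a function.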

b) MechanismModel

Subclasses of MechanismModel implement functionality for sampling observational and interventional data given a causal graph. These classes can be used to define (part of) the data-generating distribution p(D | G). Each child class has to implement __call__ accepting four arguments:

  • rng (np.random.Generator) – numpy pseudorandom number generator
  • g (ndarray) – binary adjacency matrix of shape [n_vars, n_vars] as generated by a GraphModel subclass
  • n_observations_obs (int) – number of observational data points to be sampled
  • n_observations_int (int) – number of interventional data points to be sampled

Returns:

  • avici.synthetic.Data – namedtuple containing x_obs, x_int and boolean is_count_data. The data matrices x_obs and x_int must have shapes [n_observations_obs, n_vars, 2] and [n_observations_int, n_vars, 2], respectively. The first value in the last axis (i.e. x_int[..., 0]) contains the values and the second value (i.e. x_int[..., 1]) contains either 0 or 1, indicating which nodes were intervened upon in which observations.
    Accordingly, x_obs[..., 1] has only zeros as it always contains observational data.
    is_count_data is used to determine how the data is standardized. (Default for all real-valued data should be False, which implies the usual z-standardization.)
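
To make the "usual z-standardization" concrete, here is a minimal sketch that standardizes the value channel per variable and leaves the intervention indicators untouched. The helper z_standardize and its eps parameter are hypothetical; the standardization performed inside avici may differ in its details (for example, in whether observational and interventional data are pooled).

```python
import numpy as onp


def z_standardize(x, eps=1e-8):
    """Standardize x[..., 0] per variable; keep the 0/1 mask in x[..., 1] as is."""
    values, mask = x[..., 0], x[..., 1]
    mean = values.mean(axis=0, keepdims=True)  # per-variable mean over observations
    std = values.std(axis=0, keepdims=True)    # per-variable standard deviation
    return onp.stack([(values - mean) / (std + eps), mask], axis=-1)


rng = onp.random.default_rng(0)
x = onp.stack([rng.normal(3.0, 2.0, size=(50, 4)), onp.zeros((50, 4))], axis=-1)
x[0, 1, 1] = 1.0  # mark one entry as intervened upon
z = z_standardize(x)
```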

Example:

import numpy as onp
from avici.synthetic import MechanismModel, Data

class DummyMechanism(MechanismModel):
   def __call__(self, rng, g, n_observations_obs, n_observations_int):
       n_vars = g.shape[-1]
       return Data(
           x_obs=onp.zeros((n_observations_obs, n_vars, 2)),
           x_int=onp.zeros((n_observations_int, n_vars, 2)),
           is_count_data=False,
       )
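
The following sketch shows how the body of such a __call__ might fill the two channels for an SCM with sinusoidal functions. It is an illustration under assumptions, not the func.py provided here: the edge convention (g[i, j] == 1 meaning i -> j), the frequency range, the noise scale, the single-target hard interventions, and the helper names are all made up for this example, and Data is a local stand-in with the same three fields as avici.synthetic.Data so that the code runs without avici.

```python
from collections import namedtuple

import numpy as onp

# Local stand-in mirroring the fields of avici.synthetic.Data (assumption for this sketch)
Data = namedtuple("Data", ["x_obs", "x_int", "is_count_data"])


def topological_order(g):
    """Kahn's algorithm; assumes g[i, j] == 1 encodes the edge i -> j."""
    in_degree = g.sum(axis=0).astype(int)
    ready = [j for j in range(g.shape[0]) if in_degree[j] == 0]
    order = []
    while ready:
        i = ready.pop()
        order.append(i)
        for j in onp.flatnonzero(g[i]):
            in_degree[j] -= 1
            if in_degree[j] == 0:
                ready.append(int(j))
    return order


def ancestral_sample(rng, g, w, n, interv_mask):
    """Sample x_j = sum_i sin(w_ij * x_i) + noise, parents before children."""
    x = onp.zeros((n, g.shape[0]))
    for j in topological_order(g):
        parents = onp.flatnonzero(g[:, j])
        x[:, j] = onp.sin(x[:, parents] * w[parents, j]).sum(axis=1)
        x[:, j] += 0.1 * rng.normal(size=n)  # hypothetical noise scale
        # Hard intervention: overwrite the value wherever the mask is 1.
        rows = interv_mask[:, j].astype(bool)
        x[rows, j] = rng.normal(scale=2.0, size=rows.sum())
    # Channel 0 holds the values, channel 1 the intervention indicators.
    return onp.stack([x, interv_mask], axis=-1)


def sample_dataset(rng, g, n_observations_obs, n_observations_int):
    n_vars = g.shape[-1]
    w = rng.uniform(1.0, 3.0, size=(n_vars, n_vars))  # hypothetical frequencies
    obs_mask = onp.zeros((n_observations_obs, n_vars))
    int_mask = onp.zeros((n_observations_int, n_vars))
    # One randomly chosen intervention target per interventional observation.
    targets = rng.integers(n_vars, size=n_observations_int)
    int_mask[onp.arange(n_observations_int), targets] = 1.0
    return Data(
        x_obs=ancestral_sample(rng, g, w, n_observations_obs, obs_mask),
        x_int=ancestral_sample(rng, g, w, n_observations_int, int_mask),
        is_count_data=False,
    )


rng = onp.random.default_rng(0)
g = onp.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])  # chain 0 -> 1 -> 2
data = sample_dataset(rng, g, n_observations_obs=5, n_observations_int=4)
```

Note that the same function parameters w are shared between the observational and interventional samples, so both come from one SCM.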

Both GraphModel and MechanismModel subclasses can be initialized with and store an arbitrary number of arguments for later use inside __call__, like function parameters or other sampling functions. For MechanismModel, this is also where additional details on the interventions ought to be specified, e.g., how many nodes are intervened upon and in what fashion.


2) domain.yaml: YAML file defining the training data distribution

This YAML file is the configuration file that defines the distribution over datasets our structure learning model is trained on. The file has to be structured in the following way:

---
train_n_vars: [5, 10]
test_n_vars: [20]
test_n_datasets: 10
additional_modules:
  - "./func.py"
data:
  - n_observations_obs: 300
    n_observations_int: 100
    graph:
      - __class__: ErdosRenyi
        edges_per_var: [ 1.0, 2.0, 3.0 ]
    mechanism:
      - __class__: LinearAdditive
        param:
          - __class__: SignedUniform
            low: 1.0
            high: 3.0
        bias:  ...
        noise:  ...
        noise_scale:  ...
        n_interv_vars:  ...
        interv_dist:  ...
      - ...

The top-level keywords specify the following:

  • train_n_vars – list of integers specifying the numbers of variables in the causal graphs and datasets during training
  • test_n_vars – list of integers specifying the numbers of variables used for validation
  • test_n_datasets – number of validation datasets
  • additional_modules – list of paths (relative or absolute) defining additional data-generating processes (e.g., our func.py file)
  • data – nested combination of dicts and lists specifying the full data-generating distribution

The data entry specifies the distribution over training datasets. During training, we continually generate fresh data for data buffers of the different numbers of variables according to this distribution. The configuration of the data field maintains the following invariants:

  1. If any (nested) part of the data tree is a list, one configuration of it is selected uniformly at random in each new sample. For example, in the above configuration, all graphs are Erdos-Renyi, in which the expected number of edges per node is either 1, 2, or 3, selected randomly for each new dataset. Internally, the nested dict of lists is expanded into a single list of all possible combinations of dicts, so be careful not to specify too many combinations (>1000).

  2. Each (list) element in the top level of data needs to specify: graph, mechanism, n_observations_obs, and n_observations_int (satisfying the avici.synthetic.SyntheticSpec signature). The integers n_observations_obs and n_observations_int specify the number of data points generated for each dataset. At training time, these observations are subbatched further depending on the optimization parameters.

  3. Each (list) element in the top level of data.graph needs to define a GraphModel subclass, and each (list) element of data.mechanism a MechanismModel subclass. The class name is specified via the __class__ key. All other arguments the class expects at initialization time (via __init__) are specified alongside. Please refer to the signature of avici.synthetic.LinearAdditive to verify this in the above example.

    The class arguments may be (lists of) classes themselves, defined recursively in the same way. For example, avici.synthetic.Distribution subclasses specify how the weights and noise of the linear function SCM LinearAdditive are sampled. Likewise, avici.synthetic.NoiseModel subclasses specify the noise scale in the SCM.

  4. All classes not available inside avici.synthetic need to be defined in other files and specified via their path in the additional_modules field. When specified this way, they can be used in the configuration exactly like all other members already provided in avici.synthetic.
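
The expansion described in invariant 1 can be pictured with the following sketch. It reflects one reading of that rule and is not avici's actual implementation; the function name expand and the truncated example configuration are hypothetical. Every list is treated as a choice point and every dict as a combination of its expanded values.

```python
import itertools
import random


def expand(node):
    """Return every concrete configuration encoded by a nested dict/list tree."""
    if isinstance(node, list):
        # A list is a choice point: gather the expansions of each option.
        return [cfg for option in node for cfg in expand(option)]
    if isinstance(node, dict):
        keys = list(node)
        # A dict combines one expansion per key (Cartesian product).
        per_key = [expand(node[k]) for k in keys]
        return [dict(zip(keys, combo)) for combo in itertools.product(*per_key)]
    return [node]  # scalar leaf


data = [{
    "n_observations_obs": 300,
    "n_observations_int": 100,
    "graph": [{"__class__": "ErdosRenyi", "edges_per_var": [1.0, 2.0, 3.0]}],
    "mechanism": [{
        "__class__": "LinearAdditive",
        "param": [{"__class__": "SignedUniform", "low": 1.0, "high": 3.0}],
    }],
}]

specs = expand(data)                     # one dict per possible combination
spec = random.Random(0).choice(specs)    # uniform choice for a new dataset
```

Here the only list with more than one option is edges_per_var, so the tree expands to three concrete configurations. Because the count is a product over all choice points, a few extra lists can quickly exceed the recommended limit.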

The easiest way to understand how domain.yaml is configured is to look at a few examples. The following configurations define the training distributions of the models trained in Lorch et al. (2022), whose checkpoints are available for download via avici.load_pretrained: linear.yaml, rff.yaml, gene.yaml. These config files directly correspond to the tables given in Appendix A of the paper.

Currently, we provide the following data-generating processes in avici.synthetic:

  • GraphModel subclasses:

    • ErdosRenyi
    • ScaleFree
    • ScaleFreeTranspose
    • WattsStrogatz
    • SBM
    • GRG
    • Yeast
    • Ecoli
  • MechanismModel subclasses:

    • LinearAdditive
    • RFFAdditive
    • GRNSergio
  • Distribution subclasses:

    • Gaussian
    • Laplace
    • Cauchy
    • Uniform
    • SignedUniform
    • RandInt
    • Beta
  • NoiseModel subclasses:

    • SimpleNoise
    • HeteroscedasticRFFNoise

These classes can be used in a domain.yaml configuration out-of-the-box and without further specifications.


3) train.py: Python training script

This is the main training script. Our provided script automatically performs multi-device training. Hence, if you run this script on a machine with multiple GPUs, all accelerators will be used via jax.pmap and corresponding functions.

Given our above domain configuration, we can train a first (small) model to check the script by changing directory to example-custom/ and running

python train.py --config "./domain.yaml"

where --config specifies an (absolute or relative) path to our YAML domain configuration. The above call uses --smoke_test true by default, which sets the network and training parameters to small dummy values for testing. For further information about the other command line arguments, run python train.py --help.

To train a large model with the same hyperparameters as Lorch et al. (2022), run

python train.py --config "./domain.yaml" --smoke_test false

Each distinct value of n_vars in the training data distribution requires a separate jax.jit compilation. Therefore, it is normal that the first few steps of training take relatively long. After each n_vars has been seen, update steps are fast (~0.5sec/step for the full model on Quadro RTX 6000 GPUs).

The script automatically generates checkpoints, which can be used both for continuing training and for downstream predictions. By default, the checkpoints are stored in ./checkpoints/. To resume training from a checkpoint, simply re-run train.py with the same checkpoint directory and the code will automatically detect the most recent checkpoint.

Analogous to the pretrained checkpoints we provide for automatic download, the checkpoints created during training with this script can be loaded using the avici.load_pretrained function:

import avici
model = avici.load_pretrained(checkpoint_dir="path/to/checkpoint", expects_counts=False)