For a quick tryout of the ZigZag method, check the Colab notebooks below!
Whereas the ability of deep networks to produce useful predictions on many kinds of data has been amply demonstrated, estimating the reliability of these predictions remains challenging. Sampling approaches such as MC-Dropout and Deep Ensembles have emerged as the most popular ones for this purpose. Unfortunately, they require many forward passes at inference time, which slows them down. Sampling-free approaches can be faster but often suffer from other drawbacks, such as lower reliability of uncertainty estimates, difficulty of use, and limited applicability to different types of tasks and data.
In this work, we introduce a sampling-free approach that is generic and easy to deploy, while producing reliable uncertainty estimates on par with state-of-the-art methods at a significantly lower computational cost. It is predicated on training the network to produce the same output with and without additional information about it. At inference time, when no prior information is given, we use the network's own prediction as the additional information. We then take the distance between the predictions with and without prior information as our uncertainty measure.
ZigZagging: At inference time, we make two forward passes. First, we feed the network the input x with a zero placeholder as the additional input, which yields an initial prediction y0. Second, we feed it the same input x together with y0, which yields a second prediction y1. The distance between y0 and y1 is our uncertainty measure.
Motivation: The second pass effectively asks the network to reconstruct its second input, so we expect a low reconstruction error for in-distribution data and a high one for out-of-distribution data, which is what enables uncertainty estimation. When the network receives an input x together with its correct label y, it has been trained to minimize the difference between the two outputs, so a small discrepancy signals in-distribution data. When y is incorrect, the pair (x, y) is effectively out-of-distribution and prompts an unpredictable response, which we use to gauge uncertainty. This mechanism addresses both epistemic uncertainty, when x is out-of-distribution, and aleatoric uncertainty, when y is erroneous.
Integrating ZigZag into standard models is straightforward and requires only a minimal modification of the first layer so that it accepts an additional input. The model can then make two kinds of predictions: first without any prior information, and then with its own previous output as the additional input.
Original architecture:
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.out_features = out_features
        self.input = nn.Linear(in_features, 8)
        self.hidden1 = nn.Linear(8, 16)
        self.hidden2 = nn.Linear(16, 8)
        self.output = nn.Linear(8, out_features)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.input(x))
        x = self.activation(self.hidden1(x))
        x = self.activation(self.hidden2(x))
        return self.output(x)
Modified architecture:
class ZigZagMLP(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.out_features = out_features
        # Modifying the first layer: it now also takes the output-sized second input
        self.input = nn.Linear(in_features + out_features, 8)
        self.hidden1 = nn.Linear(8, 16)
        self.hidden2 = nn.Linear(16, 8)
        self.output = nn.Linear(8, out_features)
        self.activation = nn.ReLU()

    def forward(self, x, y=None):
        # Adding the second input: on the first pass no prediction exists yet,
        # so zeros are used as the placeholder
        if y is None:
            batch = x.shape[0]
            y = torch.zeros(batch, self.out_features, device=x.device, dtype=x.dtype)
        x = torch.cat([x, y], dim=1)
        x = self.activation(self.input(x))
        x = self.activation(self.hidden1(x))
        x = self.activation(self.hidden2(x))
        return self.output(x)
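Putting the two passes together, a minimal usage sketch (the input sizes and batch below are illustrative, and the mean absolute difference is just one reasonable choice of distance):

# Two-step inference: first pass without prior information, second pass with it
model = ZigZagMLP(in_features=4, out_features=1)
x = torch.randn(32, 4)
y0 = model(x)        # first pass: zeros are used as the placeholder input
y1 = model(x, y0)    # second pass: the first prediction is fed back
uncertainty = (y0 - y1).abs().mean(dim=1)  # per-sample uncertainty score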
Installation:
python -m venv zigzag
source zigzag/bin/activate
pip install torch==2.2.1+cpu torchvision torchaudio -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install numpy matplotlib
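A quick sanity check that the install worked:

python -c "import torch; print(torch.__version__)"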
Uncertainty Estimation for Regression: The task is to regress a real-valued target from input features while estimating the uncertainty of each prediction.
We trained three different models—ZigZag, an ensemble of simple MLPs, and MC Dropout—on various UCI datasets including Boston Housing, Yacht, Power Plant, and Energy to predict their respective target variables. We split the data into in- and out-of-distribution samples based on the values of a specific feature. We then used AUC metrics to evaluate how well the resulting uncertainty measures separate in- from out-of-distribution samples.
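As an illustrative sketch of this evaluation, assuming scikit-learn is available and that x_in and x_out are hypothetical tensors holding the in- and out-of-distribution splits (the mean absolute difference below is one reasonable distance, not necessarily the paper's exact measure):

import torch
from sklearn.metrics import roc_auc_score

def zigzag_uncertainty(model, x):
    # One scalar per sample: distance between the two forward passes
    model.eval()
    with torch.no_grad():
        y0 = model(x)
        y1 = model(x, y0)
        return (y0 - y1).abs().mean(dim=1)

u_in = zigzag_uncertainty(model, x_in)    # expected to be low
u_out = zigzag_uncertainty(model, x_out)  # expected to be high
scores = torch.cat([u_in, u_out]).numpy()
labels = [0] * len(u_in) + [1] * len(u_out)
print("ROC-AUC:", roc_auc_score(labels, scores))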
MNIST vs FashionMNIST: We train the networks on MNIST and compute the accuracy and calibration metric (rAULC). We then use the uncertainty measure they produce to classify images from the MNIST and FashionMNIST test sets as in- or out-of-distribution, from which we compute the OOD metrics, ROC- and PR-AUC. We use a standard architecture with several convolutional and pooling layers, followed by fully connected layers with LeakyReLU activations.
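For classification, one natural choice is to feed the softmax probabilities of the first pass back as the second input; the sketch below assumes a ZigZag-style classifier whose forward accepts an optional second input, as in ZigZagMLP above, with mnist_images and fashion_images as hypothetical test batches (the paper's exact implementation may differ):

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score

def classification_uncertainty(model, x):
    # L1 distance between the class probabilities of the two passes
    model.eval()
    with torch.no_grad():
        p0 = F.softmax(model(x), dim=1)
        p1 = F.softmax(model(x, p0), dim=1)
        return (p0 - p1).abs().sum(dim=1)

u_id = classification_uncertainty(model, mnist_images)     # in-distribution
u_ood = classification_uncertainty(model, fashion_images)  # out-of-distribution
scores = torch.cat([u_id, u_ood]).numpy()
labels = [0] * len(u_id) + [1] * len(u_ood)
print("ROC-AUC:", roc_auc_score(labels, scores))
print("PR-AUC:", average_precision_score(labels, scores))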
If you find this code useful, please consider citing our paper:
Durasov, Nikita, et al. "ZigZag: Universal Sampling-free Uncertainty Estimation Through Two-Step Inference." TMLR 2024.
@article{durasov2024zigzag,
title = {ZigZag: Universal Sampling-free Uncertainty Estimation Through Two-Step Inference},
author = {Nikita Durasov and Nik Dorndorf and Hieu Le and Pascal Fua},
journal = {Transactions on Machine Learning Research},
issn = {2835-8856},
year = {2024}
}