\section{Defining calibration at the DeepProbLog model level}
In section \ref{subsection_calibration_definition}, we defined perfectly calibrated (NN) classifiers and several metrics that measure how close a given (NN) classifier comes to perfect calibration. A definition of a perfectly calibrated DeepProbLog model follows directly from that definition. Given a KB and a query $Q$ with output variable $X$, DeepProbLog uses the KB to return a value $V$ for $X$ together with a probability of correctness (or confidence) $p$. Hence, if $X$ is a discrete variable, we can regard the DeepProbLog model as a (multiclass) classifier with prediction confidences, which allows us to apply all definitions and metrics of section \ref{subsection_calibration_definition} to it.
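\par
As a sketch of what this amounts to, and assuming the notation of section \ref{subsection_calibration_definition} carries over, write $V^{*}$ for the correct value of $X$ (a symbol introduced here for illustration only). A DeepProbLog model would then be perfectly calibrated if
\[
\Pr\left(V = V^{*} \mid p = c\right) = c \qquad \forall c \in [0, 1].
\]
For example, among all queries that the model answers with confidence $p = 0.8$, about 80\% of the returned values should be correct.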
\section{The role of DeepProbLog model neural networks' calibratedness in whole-model calibratedness}
We know that (Deep)ProbLog inference is based on sound weighted \#SAT solving, so the confidence $p$ of an answer is fully determined by the probabilities that enter inference. This implies that if $p$ is accurate for all queries $Q$, the learned unknown probabilities $\alpha_{i}$ in the model must be accurate and the members of the set $N$ of the model's NNs must, on average, be well-calibrated. To summarise, to calibrate a DeepProbLog model we must estimate the unknown probabilities $\alpha_{i}$ as well as possible and calibrate its constituent NNs. We will do the latter in our experiments, in which we take as our null hypothesis that DeepProbLog's model training process miscalibrates the model NNs, so that separate calibration is required.
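\par
To illustrate how NN miscalibration can carry over to the whole model, consider the following sketch based on single-digit MNIST addition. It assumes one digit classifier whose softmax output for an image $I$ is written $\hat{p}(d \mid I)$; this notation is introduced here for illustration only. For the query asking whether the digits in images $I_1$ and $I_2$ sum to $s$, sound inference yields
\[
p(s \mid I_1, I_2) = \sum_{d_1 + d_2 = s} \hat{p}(d_1 \mid I_1)\,\hat{p}(d_2 \mid I_2), \qquad d_1, d_2 \in \{0, \dots, 9\}.
\]
Suppose the classifier reports a confidence of $0.99$ for its top digit on both images, while it is in fact right only 90\% of the time at that confidence level. The model then reports a confidence of at least $0.99^{2} \approx 0.98$ for the corresponding sum, whereas both digits are correct with a probability of only about $0.9^{2} = 0.81$ if the errors on the two images are treated as independent. The overconfidence of the NN thus reappears, compounded, in the query confidence.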
\section{DeepProbLog application neural network calibration implementation scheme}
If DeepProbLog's model training process does not calibrate the model NNs, then general calibration methods must be applied to the model NNs as a pre-, post- or intermediary processing step. We will apply the temperature scaling method discussed in section \ref{subsubsection_platt_temperature_scaling}. As \cite{guo2017calibration} showed, this method calibrates NNs effectively while being computationally cheap. It therefore seems the best choice in the context of DeepProbLog, which is already an inherently computationally expensive framework. \par
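As a brief reminder of what is being fitted, and assuming the formulation of \cite{guo2017calibration}, temperature scaling divides the logit vector $\mathbf{z}_i$ of input $i$ by a single scalar $T > 0$ before applying the softmax function $\sigma_{\mathrm{SM}}$. The calibrated confidence $\hat{q}_i$ is then
\[
\hat{q}_i = \max_{k} \, \sigma_{\mathrm{SM}}\left(\mathbf{z}_i / T\right)^{(k)}.
\]
$T$ is chosen to minimise the negative log-likelihood on a held-out validation set while the NN weights stay fixed. Because dividing by a positive $T$ does not change the arg max of the softmax output, the class predicted by the NN itself is unaffected and only its confidences change. The accuracy of the DeepProbLog model as a whole may still change, since query probabilities combine the full NN output distributions. Only one parameter per NN has to be fitted, which is why the method is cheap. \par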
In our experiments, we will derive datasets for each NN with
\begin{enumerate}
\item the inputs to the NN generated during training query inference and
\item the correct targets derived from the inputs in an exogenous, example-specific fashion.
\end{enumerate}
We will call these datasets raw NN datasets. Having them allows us to apply temperature scaling to the individual model NNs. We will apply it either after each model training iteration (calibration as intermediary processing) or after the entire training process has finished (calibration as post-processing). Whenever calibration is applied, we exclude from the model training set the part of the training query set that was used to generate the raw NN datasets. \par
DeepProbLog defines a wrapper class \texttt{deepproblog.network.Network} around standard PyTorch network modules, which makes them usable in its models. We subclassed it to create the abstract class \texttt{deepproblog.calibrated\_network.CalibratedNetwork}, which defines a model NN that can be calibrated both during and after model training. We provided the concrete class \\ \texttt{deepproblog.calibrated\_network.TemperatureScalingNetwork} as an implementation of \texttt{deepproblog.calibrated\_network.CalibratedNetwork} that applies temperature scaling. To monitor the properties of individual model NNs during training, we defined the abstract base class \\ \texttt{deepproblog.networks\_evolution\_collector.NetworksEvolutionCollector}. \\ We modified DeepProbLog's model training code so that the callback methods of this abstract base class are called at certain points in the training cycle on every instance of it that is passed to the model training object or function. For instance, we defined a concrete subclass \texttt{deepproblog.calibrated\_network.NetworkECECollector} to monitor the ECE of individual model NNs during training. \par
Listing \ref{lst:MNIST_addition_source_code} in section \ref{managing_querying_DeepProbLog_from_python} of Appendix A demonstrates the usage of DeepProbLog KBs in Python. Listing \ref{lst:MNIST_addition_calibration_source_code} is a version of the same example that applies post-training model NN calibration using the classes introduced above.
\begin{lstlisting}[language=Python, caption={deepproblog/src/deepproblog/examples/MNIST/basic\_addition\_calibration.py, the basic MNIST single-digit addition example with post-training calibration}, numbers=left, label={lst:MNIST_addition_calibration_source_code}, captionpos=b]
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from deepproblog.calibrated_network import TemperatureScalingNetwork, NetworkECECollector
from deepproblog.dataset import DataLoader
from deepproblog.engines import ExactEngine
from deepproblog.evaluate import get_confusion_matrix
from deepproblog.model import Model
from deepproblog.train import train_model
from deepproblog.examples.MNIST.data import MNISTOperator, MNIST_train, MNIST_test, RawMNISTValidationDataset
from deepproblog.examples.MNIST.network import MNIST_Net
if __name__ == "__main__":
    # General DeepProbLog flow - step 1
    # Load and create dataset objects (of (sub)type deepproblog.dataset.Dataset)
    train_dataset = MNISTOperator(
        dataset_name="train",
        function_name="addition",
        operator=sum,
        size=1,
        arity=2
    )
    test_dataset = MNISTOperator(
        dataset_name="test",
        function_name="addition",
        operator=sum,
        size=1,
        arity=2
    )
    # General DeepProbLog flow - step 2
    # Create data loader objects based on the dataset objects
    # (of (sub)type deepproblog.dataset.DataLoader)
    batch_size = 2
    train_dataloader = DataLoader(train_dataset, batch_size, False)
    # For calibrating the model, create a regular PyTorch data loader
    # based on a raw NN dataset. Note that we use data that is held out
    # from the training and test queries, because this is a requirement
    # of temperature scaling.
    validation_loader_for_calibration = TorchDataLoader(RawMNISTValidationDataset(), batch_size)
    # General DeepProbLog flow - step 3
    # Create PyTorch NN objects
    MNIST_net_pytorch = MNIST_Net()
    # General DeepProbLog flow - step 4
    # Create DeepProbLog network objects (type deepproblog.network.Network)
    # based on the PyTorch NNs.
    # For calibrating the model, use a subclass of
    # deepproblog.calibrated_network.CalibratedNetwork, such as
    # deepproblog.calibrated_network.TemperatureScalingNetwork to apply
    # temperature scaling, and define a
    # deepproblog.calibrated_network.NetworkECECollector (a
    # deepproblog.networks_evolution_collector.NetworksEvolutionCollector)
    # to monitor model NN ECE evolution during training.
    networks_evolution_collectors = {}
    MNIST_net = TemperatureScalingNetwork(
        MNIST_net_pytorch,
        "mnist_net",
        validation_loader_for_calibration,
        batching=True,
        calibrate_after_each_train_iteration=False
    )
    MNIST_net.optimizer = torch.optim.Adam(MNIST_net_pytorch.parameters(), lr=1e-3)
    networks_evolution_collectors["calibration_collector"] = NetworkECECollector(
        {"mnist_net": validation_loader_for_calibration},
        iteration_collect_iter=100
    )
    # General DeepProbLog flow - step 5
    # Construct a DeepProbLog model object (type deepproblog.model.Model)
    # based on the KB file and the DeepProbLog Network objects
    model = Model("models/addition.pl", [MNIST_net])
    # General DeepProbLog flow - step 6
    # Create an engine object and add it to the model. The standard
    # DeepProbLog distribution provides an approximate
    # (deepproblog.engines.ApproximateEngine) and an exact
    # (deepproblog.engines.ExactEngine) inference engine. Which of the
    # two is appropriate depends on the use case.
    model.set_engine(ExactEngine(model), cache=True)
    # General DeepProbLog flow - step 7
    # Couple tensor source objects to the model using its add_tensor_source method
    model.add_tensor_source("train", MNIST_train)
    model.add_tensor_source("test", MNIST_test)
    # General DeepProbLog flow - step 8
    # Use the deepproblog.train.train_model function to train the
    # DeepProbLog model (which means optimizing the unknown model
    # probabilities/parameters and the model's NNs' weights on the
    # training queries).
    # For calibrating the model, we pass along the networks evolution
    # collectors to monitor model NN ECE evolution during training.
    train = train_model(
        model,
        train_dataloader,
        1,
        networks_evolution_collectors,
        log_iter=100,
        profile=0
    )
    model.save_state("snapshot/basic_MNIST_model.pth")
    accuracy = get_confusion_matrix(model, test_dataset, verbose=0).accuracy()
    print(f"Done.\nThe model accuracy was {accuracy}.")
    # Report the first and last ECE value collected for every model NN
    for networks_evolution_collector in train.networks_evolution_collectors.values():
        ece_histories = networks_evolution_collector.collection_as_dict()["ece_history"]
        for network_name in ece_histories:
            initial_ECE = ece_histories[network_name][0]
            final_ECE = ece_histories[network_name][-1]
            print(f"Model NN {network_name} initial ECE was {initial_ECE}")
            print(f"Model NN {network_name} final ECE was {final_ECE}")
\end{lstlisting}
\section{Experimental design and setup}
\label{experimental_design_and_setup_section}
We want to study the effect of calibration on DeepProbLog models in a representative set of use cases. In section \ref{literature_review_applications_of_interest} we established symbolic knowledge \& structured data processing, visual information processing and natural language processing as the top three use cases for symbolic AI systems. The examples shipped with DeepProbLog at the time of writing fall into exactly these categories:
\begin{itemize}
\item Symbolic knowledge \& structured data processing
\begin{itemize}
\item The addition program induction example
\item The sorting program induction example
\item The word algebra problem solving program induction example
\end{itemize}
\item Visual information processing
\begin{itemize}
\item The MNIST addition example
\item The noisy MNIST addition example
\item The hand-written formula (HWF) example
\item The coins example
\item The poker example
\end{itemize}
\item Natural language processing
\begin{itemize}
\item The CLUTRR family relationships induction example
\end{itemize}
\end{itemize}
For this reason, we will carry out our experiments based on these example applications shipped with DeepProbLog:
\begin{itemize}
\item The MNIST addition example\par
Given two pictures of handwritten numbers, the DeepProbLog model is queried for their sum.
\item The Noisy MNIST addition example\par
Given two pictures of handwritten numbers, the DeepProbLog model is queried for their sum. However, now label noise is applied to the output variable of the training queries.
\item The Hand-written formula (HWF) example\par
Given three pictures of two numbers and an operator, the DeepProbLog model is queried for the value of the resulting formula. This experiment is based on \cite{li2020hwf}.
\item The coins example\par
Given a picture of two coins, the DeepProbLog model is queried for whether their visible sides are the same.
\item The poker example\par
Given a picture of a hand of cards in poker, the DeepProbLog model is queried for the (most likely) outcome of a poker game when the player has this hand.
\item The program induction examples\par
Given three incomplete DeepProbLog programs for
\begin{itemize}
\item adding numbers
\item sorting numbers
\item solving word algebra problems (WAPs)
\end{itemize}
the missing parts are completed with learned neural predicates. This example is based on \cite{riedel2016forth}.
\item The CLUTRR family relationships induction example\par
Given a text and a list of entities in this text (in this case people), the DeepProbLog model is queried for the relationships (in this case family relationships) between the entities. This example is based on \cite{sinha2019clutrr}.
\end{itemize}
For each of these examples, we will investigate
\begin{itemize}
\item The difference in final model accuracy across the calibration models
\item The difference in average model NN ECE before and after calibration
\item The co-evolution of model loss and average model NN ECE
\end{itemize}
By calibration models we mean the following three calibration strategies:
\begin{itemize}
\item No calibration at all
\item Calibration of every model NN after each training iteration
\item Calibration of every model NN after model training is finished
\end{itemize}\par
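For reference, the quantities compared above can be written down as follows. This is a sketch that assumes the binned ECE estimator of \cite{guo2017calibration}, with $M$ confidence bins $B_m$ over $n$ samples, and that assumes the average model NN ECE is an unweighted mean over the set $N$ of model NNs. The averaging actually used may differ.
\[
\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \left| \mathrm{acc}(B_m) - \mathrm{conf}(B_m) \right|,
\qquad
\overline{\mathrm{ECE}} = \frac{1}{|N|} \sum_{\nu \in N} \mathrm{ECE}_{\nu}
\]
Here $\mathrm{acc}(B_m)$ is the fraction of correct predictions in bin $B_m$, $\mathrm{conf}(B_m)$ is the mean confidence in that bin and $\mathrm{ECE}_{\nu}$ is the ECE of model NN $\nu$. \par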
As a final experiment, we will compare the calibratedness (ECE) of a DenseNet trained on the CIFAR-100 dataset that was calibrated once with that of one that was calibrated twice in a row.\par
The results of these investigations will show us whether applying temperature scaling (and, by extension, probably other calibration methods) can improve the calibratedness and/or accuracy of DeepProbLog models. If calibration tends to improve calibratedness, which is what we expect, we can infer that DeepProbLog's training algorithm does not naturally calibrate models. If, counterintuitively, it tends to worsen calibratedness, our conclusion will depend on the final experiment. If calibrating a representative standalone network twice in a row worsens its calibratedness, we have evidence that DeepProbLog's training algorithm does indeed result in calibrated models. If it does not, further research (with a plethora of possible hypotheses) into why calibration worsens DeepProbLog model calibratedness is needed and we will not be able to draw strong conclusions immediately. We will then only be able to suggest some starting points for that research.
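\par
One way to reason about the double-calibration experiment is sketched below. It assumes that both calibration rounds use temperature scaling and that the logits $\mathbf{z}$ of the network stay fixed between the rounds. The symbols $T_1$ and $T_2$ for the temperatures of the first and second round are introduced here for illustration only. Applying the second temperature on top of the first gives
\[
\sigma_{\mathrm{SM}}\left(\frac{\mathbf{z} / T_1}{T_2}\right) = \sigma_{\mathrm{SM}}\left(\frac{\mathbf{z}}{T_1 T_2}\right),
\]
so two rounds are equivalent to a single round with temperature $T_1 T_2$. If, in addition, both rounds fit the temperature on the same validation set and the fit converges, a $T_1$ that already minimises the validation negative log-likelihood would leave $T_2 \approx 1$, and the second round would be expected to change the ECE only slightly. The experiment checks empirically whether this holds.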
\section{Experiment reproducibility}
To make our experiments reproducible, we will distribute all code used to carry them out along with this text and in the following GitHub repositories:
\begin{itemize}
\item \url{https://github.com/Joshua-Schroijen/deepproblog}: \par Contains the source code of the DeepProbLog examples extended with calibration. The Python script \\ \texttt{src/deepproblog/examples/calibration\_evaluator.py} generates the raw results of all DeepProbLog experiments discussed here in a new directory named \\ \texttt{src/deepproblog/examples/calibration\_evaluator\_results}.
\item \url{https://github.com/Joshua-Schroijen/base_nn_calibration_research.git}: \par
Contains a modified version of the calibration demonstration distributed by \cite{guo2017calibration} that carries out the last double calibration experiment discussed here.
\item \url{https://github.com/Joshua-Schroijen/text-callibration-in-deepproblog}: \par Contains the \LaTeX{} source code of this text, our raw results obtained by running the \texttt{calibration\_evaluator.py} script and a Python script \\ \texttt{Present\_calibration\_evaluation\_results/generate\_assets.py} that generates all the tables and figures used in this chapter.
\end{itemize}
ZIP file snapshots of these repositories are included in our distributions along with this text.