---
title: "README.rmd"
author: "Brian Cho"
date: "6/1/2022"
output: github_document
always_allow_html: true
---
# Distillery: Interpretable Machine Learning Methods and Surrogate Model Methods
`Distillery` provides several methods for model distillation and interpretability
for general black-box machine learning models. The package implements the partial
dependence plot (PDP), individual conditional expectation (ICE), and accumulated
local effect (ALE) methods, all of which are model-agnostic (they work with any
supervised machine learning model). It also provides a novel method for building
a surrogate model that approximates the behavior of the original model.
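To illustrate why these methods are model-agnostic, note that a one-dimensional PDP can be computed with nothing but a model's `predict` function. The following is a minimal base-R sketch of the idea, not the package's implementation:

```r
# Minimal sketch of a 1-D partial dependence computation:
# PDP_j(v) = average prediction when feature j is forced to the value v.
pdp_1d <- function(model, data, feature, grid) {
  sapply(grid, function(v) {
    data_mod <- data
    data_mod[[feature]] <- v                  # set feature j to v for every row
    mean(predict(model, newdata = data_mod))  # average over the data
  })
}

# Works with any model that has a predict() method, e.g. a linear model:
fit  <- lm(mpg ~ wt + hp, data = mtcars)
grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
pdp_1d(fit, mtcars, "wt", grid)
```

Because only `predict` is used, the same recipe applies unchanged to a random forest or any other supervised learner.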
Below, we provide a simple example that outlines how to use this package. For
further details on surrogate distillation, advanced interpretability features,
or local surrogate methods, see the articles provided on this page.
For documentation, see [https://forestry-labs.github.io/Distillery/index.html](https://forestry-labs.github.io/Distillery/index.html)
## A Simple Example: Predicting Carapace Width of Leptograpsus Crabs
Throughout this section, we provide a tutorial on using the package with a random
forest that predicts the carapace width of Leptograpsus crabs. We demonstrate how
to plot PDP, ICE, and ALE curves for machine learning interpretability, and
show how to build a surrogate model that approximates the behavior of the initial
random forest predictor.
### General Prediction Wrapper
First we load the crabs data set. It contains physical measurements of both sexes
and two colour forms of the species *Leptograpsus variegatus*, collected at
Fremantle, Western Australia.
```{r}
library(MASS)
library(Distillery)
library(Rforestry)
library(ggplot2)

set.seed(491)
data <- MASS::crabs
levels(data$sex) <- list(Male = "M", Female = "F")
levels(data$sp) <- list(Orange = "O", Blue = "B")
colnames(data) <- c("Species", "Sex", "Index", "Frontal Lobe",
                    "Rear Width", "Carapace Length", "Carapace Width",
                    "Body Depth")
```
We can train a random forest to estimate the carapace width of the crabs from the
other features. To use the interpretability features, we must wrap the estimator
we want to interpret in a `Predictor` object. This class standardizes the
predictions, tracks the outcome feature, and stores the training data.
```{r}
# Split into training and test sets
set.seed(491)
test_ind <- sample(1:nrow(data), nrow(data) %/% 5)
train_reg <- data[-test_ind, ]
test_reg <- data[test_ind, ]

# Train a random forest on the training set
forest <- forestry(x = train_reg[, -which(names(train_reg) == "Carapace Width")],
                   y = train_reg[, which(names(train_reg) == "Carapace Width")])

# Create a predictor wrapper for the forest; this gives us a standard
# interface for querying any trained estimator
forest_predictor <- Predictor$new(model = forest,
                                  data = train_reg,
                                  y = "Carapace Width",
                                  task = "regression")
```
### Interpretability Wrapper
Once we have initialized a `Predictor` object for the forest, we can pass it to the
`Interpreter` class. By default, the `Interpreter` class subsamples the training data
to at most 1000 observations in order to speed up the interpretability computations.
This class provides the names and classes of the features, the indices of the
sampled data points, and lists of univariate and bivariate PDP functions, and stores
additional information for plot settings.
```{r}
forest_interpret <- Interpreter$new(predictor = forest_predictor)
print(forest_interpret)
```
The PDP functions are stored in two lists, one for univariate PDP functions and one
for bivariate PDP functions. For any feature, we can retrieve its PDP function by
selecting the entry in the list with that feature's name. We can call these PDP
functions directly on values of the corresponding feature, and they return the PDP
curve's values. Univariate functions take a vector of values. Bivariate functions
take a data frame or matrix with two columns, where each row provides a pair of
values and each column corresponds to one feature.
```{r}
# univariate PDP
one_feat <- train_reg$`Frontal Lobe`[1:10]
preds_pdp <- forest_interpret$pdp.1d$`Frontal Lobe`(one_feat)
print(preds_pdp)

# bivariate PDP
two_feat <- cbind(train_reg$`Frontal Lobe`[1:10],
                  train_reg$`Rear Width`[1:10])
preds_pdp_2d <- forest_interpret$pdp.2d$`Frontal Lobe`$`Rear Width`(two_feat)
print(preds_pdp_2d)
```
### Basic Plotting
For univariate and bivariate interpretability methods, we can use the `plot` method
of the `Interpreter` class. For univariate summaries of the model's behavior, we have
three main options: PDP, ICE, and ALE curves. To plot a specific curve for a given
set of features, we simply specify the `method` parameter of the `plot` function, as
shown below:
```{r}
# plotting PDP functions
plot(forest_interpret,
     method = "pdp",
     features = c("Frontal Lobe", "Rear Width"))

plot(forest_interpret,
     method = "ice",
     features = c("Frontal Lobe", "Rear Width"))

# "pdp+ice" is the default, used when method is not specified
plot(forest_interpret,
     method = "pdp+ice",
     features = c("Frontal Lobe", "Rear Width"))

plot(forest_interpret,
     method = "ale",
     features = c("Frontal Lobe", "Rear Width"))
```
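For intuition about the `"pdp+ice"` default, an ICE curve is the per-observation analogue of the PDP: instead of averaging, one prediction curve is drawn per row, and averaging the ICE curves recovers the PDP. A base-R sketch of the idea (not the package's implementation):

```r
# Sketch: ICE keeps one curve per observation rather than averaging.
ice_1d <- function(model, data, feature, grid) {
  # returns a matrix: one row per observation, one column per grid value
  sapply(grid, function(v) {
    d <- data
    d[[feature]] <- v
    predict(model, newdata = d)
  })
}

fit    <- lm(mpg ~ wt + hp, data = mtcars)
grid   <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
curves <- ice_1d(fit, mtcars, "wt", grid)

# Averaging the ICE curves column-wise recovers the PDP
pdp <- colMeans(curves)
```

Plotting the individual curves alongside their average is what the `"pdp+ice"` option shows.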
For bivariate summary plots, the package provides two distinct methods. Given one
continuous and one categorical feature, the `plot` function draws conditional PDP
curves, separating the mean predictions by the value of the categorical feature. For
two continuous features, the `plot` function draws a heatmap. The pairs of features
to plot are specified as a two-column data frame of feature names, where each row
represents a single pair.
```{r}
plot(forest_interpret,
     features.2d = data.frame(feat.1 = c("Frontal Lobe", "Frontal Lobe"),
                              feat.2 = c("Sex", "Rear Width")))
```
For more advanced plotting features, such as clustering ICE curves or specifying the number of points plotted, please refer to the article "Advanced Plotting Features".
### Local Surrogates
Even with a heatmap or conditional plots, two-dimensional summaries can be difficult
to interpret. The `localSurrogate` function provides a local summary of how changes
in a pair of features affect the model's predictions by fitting a simple decision-tree
summary.
```{r}
local.surr <- localSurrogate(forest_interpret,
                             features.2d = data.frame(feat.1 = c("Frontal Lobe",
                                                                 "Frontal Lobe"),
                                                      feat.2 = c("Sex",
                                                                 "Rear Width")))
plot(local.surr$models$`Frontal Lobe.Sex`)
plot(local.surr$models$`Frontal Lobe.Rear Width`)
```
For additional details on the `localSurrogate` method, such as specifying the depth
or number of trees in the weak learner, please refer to the article "Local Surrogate".
### Distillation: Creating the Default Surrogate Model
This package also implements a novel algorithm that builds a linear combination of
the univariate PDP curves to produce a surrogate model. To do this, we call the
`distill` method on an interpreter object, which returns a surrogate model. We can
then make predictions with the surrogate and compare them with the original
predictions of the random forest, as shown below.
```{r}
forest_surrogate <- distill(forest_interpret)

predictions_forest <- predict(forest,
                              test_reg[, -which(names(test_reg) == "Carapace Width")])

# surrogate predictions are returned as a one-column data frame
predictions_surrogate <- predict(forest_surrogate,
                                 test_reg[, -which(names(test_reg) == "Carapace Width")])

plot.comparison <- data.frame(original = predictions_forest,
                              surrogate = predictions_surrogate[, 1])
ggplot(data = plot.comparison, aes(x = original, y = surrogate)) +
  geom_point() + geom_abline(col = "red")
```
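To build intuition for what a PDP-based surrogate looks like, one can write it in the additive form g(x) = w₀ + Σⱼ wⱼ · PDPⱼ(xⱼ) and fit the weights by least squares against the original model's predictions. The sketch below does exactly that in base R under simplifying assumptions; it is an illustration of the idea, not the package's `distill` implementation:

```r
# Hedged sketch: a surrogate built as a linear combination of univariate
# PDP curves, with weights fit by least squares against the original
# model's own predictions.
pdp_feature <- function(model, data, feature, values) {
  sapply(values, function(v) {
    d <- data
    d[[feature]] <- v
    mean(predict(model, newdata = d))
  })
}

fit <- lm(mpg ~ wt + hp, data = mtcars)  # stand-in for any black-box model

# Evaluate each feature's PDP curve at the training points
Z <- data.frame(
  pdp_wt = pdp_feature(fit, mtcars, "wt", mtcars$wt),
  pdp_hp = pdp_feature(fit, mtcars, "hp", mtcars$hp),
  target = predict(fit, newdata = mtcars)  # original model's predictions
)

# The surrogate: target ~ w0 + w1 * PDP_wt + w2 * PDP_hp
surrogate <- lm(target ~ pdp_wt + pdp_hp, data = Z)
```

For an additive model such as `lm`, this recombination reproduces the predictions exactly; for a random forest it is only an approximation, which is why comparing surrogate and original predictions on held-out data, as above, is informative.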
For additional details on creating the distilled surrogate models, please refer to
the article "Distillation Methods".