Refactor bayesian neural network example (#670)
- Improve axis labels
- Remove `mutable=True` argument (soon to be deprecated)
- Fix shape warning (see the shape sketch after this list)
- Fix minibatch sampling issue (cf. #654)
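
The shape fix concerns the prediction grid: `np.mgrid[-3:3:100j, -3:3:100j]` has shape `(2, 100, 100)`, so the old `grid.shape[1]` yields only 100 placeholder labels, while the flattened grid has one row per grid point. A minimal NumPy sketch of the shapes involved (assuming the notebook's 100x100 grid):

```python
import numpy as np

grid = np.mgrid[-3:3:100j, -3:3:100j]
grid_2d = grid.reshape(2, -1).T

print(grid.shape)     # (2, 100, 100) -> grid.shape[1] is only 100
print(grid_2d.shape)  # (10000, 2)    -> one row per grid point

# The placeholder output needs one entry per observation,
# hence grid_2d.shape[0] rather than grid.shape[1]:
dummy_out = np.ones(grid_2d.shape[0], dtype=np.int8)
print(dummy_out.shape)  # (10000,)
```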
ElisabethBrockhausQC authored Jun 15, 2024
1 parent 9003ee2 commit 86b37ad
Showing 2 changed files with 139 additions and 169 deletions.
274 changes: 120 additions & 154 deletions examples/variational_inference/bayesian_neural_network_advi.ipynb

Large diffs are not rendered by default.

34 changes: 19 additions & 15 deletions examples/variational_inference/bayesian_neural_network_advi.myst.md
@@ -69,7 +69,6 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pymc as pm
 import pytensor
-import pytensor.tensor as pt
 import seaborn as sns
 from sklearn.datasets import make_moons
@@ -103,7 +102,7 @@ ax.scatter(X[Y == 0, 0], X[Y == 0, 1], color="C0", label="Class 0")
 ax.scatter(X[Y == 1, 0], X[Y == 1, 1], color="C1", label="Class 1")
 sns.despine()
 ax.legend()
-ax.set(xlabel="X", ylabel="Y", title="Toy binary classification data set");
+ax.set(xlabel="X1", ylabel="X2", title="Toy binary classification data set");
 ```

 ### Model specification
@@ -127,11 +126,11 @@ def construct_nn(ann_input, ann_output):
         "hidden_layer_1": np.arange(n_hidden),
         "hidden_layer_2": np.arange(n_hidden),
         "train_cols": np.arange(X_train.shape[1]),
-        # "obs_id": np.arange(X_train.shape[0]),
+        "obs_id": np.arange(X_train.shape[0]),
     }
     with pm.Model(coords=coords) as neural_network:
-        ann_input = pm.Data("ann_input", X_train, mutable=True, dims=("obs_id", "train_cols"))
-        ann_output = pm.Data("ann_output", Y_train, mutable=True, dims="obs_id")
+        ann_input = pm.Data("ann_input", X_train, dims=("obs_id", "train_cols"))
+        ann_output = pm.Data("ann_output", Y_train, dims="obs_id")

         # Weights from input to hidden layer
         weights_in_1 = pm.Normal(
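
A note on the `pm.Data` change: in recent PyMC releases, `pm.Data` returns a mutable shared variable by default, so `mutable=True` is redundant and the argument is slated for deprecation. A minimal sketch of the refactored data containers, with stand-in arrays in place of the notebook's training data:

```python
import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 2))     # stand-in features
Y_train = rng.integers(0, 2, size=100)  # stand-in binary labels

coords = {
    "obs_id": np.arange(X_train.shape[0]),
    "train_cols": np.arange(X_train.shape[1]),
}
with pm.Model(coords=coords) as model:
    # No mutable=True needed: the containers can still be swapped later via pm.set_data
    ann_input = pm.Data("ann_input", X_train, dims=("obs_id", "train_cols"))
    ann_output = pm.Data("ann_output", Y_train, dims="obs_id")
```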
@@ -215,14 +214,15 @@ pred = ppc.posterior_predictive["out"].mean(("chain", "draw")) > 0.5

 ```{code-cell} ipython3
 fig, ax = plt.subplots()
-ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0")
-ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1")
+ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0", label="Predicted 0")
+ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1", label="Predicted 1")
 sns.despine()
-ax.set(title="Predicted labels in testing set", xlabel="X", ylabel="Y");
+ax.legend()
+ax.set(title="Predicted labels in testing set", xlabel="X1", ylabel="X2");
 ```

 ```{code-cell} ipython3
-print(f"Accuracy = {(Y_test == pred.values).mean() * 100}%")
+print(f"Accuracy = {(Y_test == pred.values).mean() * 100:.2f}%")
 ```

 Hey, our neural network did all right!
@@ -240,16 +240,21 @@ jupyter:
 ---
 grid = pm.floatX(np.mgrid[-3:3:100j, -3:3:100j])
 grid_2d = grid.reshape(2, -1).T
-dummy_out = np.ones(grid.shape[1], dtype=np.int8)
+dummy_out = np.ones(grid_2d.shape[0], dtype=np.int8)
 ```

 ```{code-cell} ipython3
 ---
 jupyter:
   outputs_hidden: true
 ---
+coords_eval = {
+    "train_cols": np.arange(grid_2d.shape[1]),
+    "obs_id": np.arange(grid_2d.shape[0]),
+}
 with neural_network:
-    pm.set_data(new_data={"ann_input": grid_2d, "ann_output": dummy_out})
+    pm.set_data(new_data={"ann_input": grid_2d, "ann_output": dummy_out}, coords=coords_eval)
     ppc = pm.sample_posterior_predictive(trace)
 ```
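
Why the extra `coords` argument: the evaluation grid has far more rows than the training set, so swapping in the new data must also resize the `obs_id` dimension, and `pm.set_data` accepts `coords` for exactly that. An annotated sketch of the resizing step, assuming the model and grid from the cells above:

```python
coords_eval = {
    "train_cols": np.arange(grid_2d.shape[1]),  # still 2 input columns
    "obs_id": np.arange(grid_2d.shape[0]),      # now 10000 grid points, not the training size
}
with neural_network:
    # Passing coords updates the dim lengths together with the data,
    # avoiding a length mismatch on the "obs_id" dimension:
    pm.set_data(new_data={"ann_input": grid_2d, "ann_output": dummy_out}, coords=coords_eval)
    ppc = pm.sample_posterior_predictive(trace)
```

The placeholder `dummy_out` is not used by the predictions themselves; it only keeps the observed container's shape consistent with the new input.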

@@ -268,7 +273,7 @@ contour = ax.contourf(
 ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0")
 ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1")
 cbar = plt.colorbar(contour, ax=ax)
-_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel="X", ylabel="Y")
+_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel="X1", ylabel="X2")
 cbar.ax.set_ylabel("Posterior predictive mean probability of class label = 0");
 ```

@@ -285,7 +290,7 @@ contour = ax.contourf(
 ax.scatter(X_test[pred == 0, 0], X_test[pred == 0, 1], color="C0")
 ax.scatter(X_test[pred == 1, 0], X_test[pred == 1, 1], color="C1")
 cbar = plt.colorbar(contour, ax=ax)
-_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel="X", ylabel="Y")
+_ = ax.set(xlim=(-3, 3), ylim=(-3, 3), xlabel="X1", ylabel="X2")
 cbar.ax.set_ylabel("Uncertainty (posterior predictive standard deviation)");
 ```

@@ -300,8 +305,7 @@ So far, we have trained our model on all data at once. Obviously this won't scale
 Fortunately, ADVI can be run on mini-batches as well. It just requires some setting up:

 ```{code-cell} ipython3
-minibatch_x = pm.Minibatch(X_train, batch_size=50)
-minibatch_y = pm.Minibatch(Y_train, batch_size=50)
+minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)
 neural_network_minibatch = construct_nn(minibatch_x, minibatch_y)
 with neural_network_minibatch:
     approx = pm.fit(40000, method=pm.ADVI())
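
The minibatch fix is about joint slicing: the old pattern created two independent `pm.Minibatch` objects, each drawing its own random row indices, so inputs and labels could end up misaligned (the issue referenced as #654). Passing both arrays to a single `pm.Minibatch` call slices them with shared indices. A small self-contained check with stand-in arrays, assuming `pm.draw` is used to evaluate one joint draw of both tensors:

```python
import numpy as np
import pymc as pm

X_train = np.arange(20).reshape(10, 2)  # row i is [2*i, 2*i + 1]
Y_train = np.arange(10)                 # label i belongs to row i

# One joint call: both tensors are sliced with the same random indices on each draw
minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=5)

xb, yb = pm.draw([minibatch_x, minibatch_y])
assert (xb[:, 0] == 2 * yb).all()  # rows and labels stay paired
```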
