Regression example
kirill-fedyanin committed Mar 10, 2020
1 parent 8fd7eb4 commit 0e2d1e2
Showing 4 changed files with 111 additions and 63 deletions.
32 changes: 17 additions & 15 deletions alpaca/model/mlp.py
@@ -7,10 +7,13 @@


 class BaseMLP(nn.Module):
-    def __init__(self, layer_sizes, activation, postprocessing=lambda x: x):
+    def __init__(self, layer_sizes, activation, postprocessing=lambda x: x, device=None):
         super(BaseMLP, self).__init__()

-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if device is None:
+            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        else:
+            self.device = device

         self.layer_sizes = layer_sizes
         self.fcs = []
@@ -23,19 +26,18 @@ def __init__(self, layer_sizes, activation, postprocessing=lambda x: x):
         self.double()
         self.to(self.device)

-    def forward(self, x, dropout_rate=0, train=False, dropout_mask=None):
-        out = torch.DoubleTensor(x).to(self.device) if isinstance(x, np.ndarray) else x
-        out = self.activation(self.fcs[0](out))
+    def forward(self, x, dropout_rate=0, dropout_mask=None):
+        x = self.activation(self.fcs[0](x))

         for layer_num, fc in enumerate(self.fcs[1:-1]):
-            out = self.activation(fc(out))
+            x = self.activation(fc(x))
             if dropout_mask is None:
-                out = nn.Dropout(dropout_rate)(out)
+                x = nn.Dropout(dropout_rate)(x)
             else:
-                out = out*dropout_mask(out, dropout_rate, layer_num)
-        out = self.fcs[-1](out)
-        out = self.postprocessing(out)
-        return out if train else out.detach()
+                x = x*dropout_mask(x, dropout_rate, layer_num)
+        x = self.fcs[-1](x)
+        x = self.postprocessing(x)
+        return x

     def fit(
             self, train_set, val_set, epochs=10000, verbose=True,
@@ -53,7 +55,7 @@ def fit(
             labels = labels.to(self.device)

             # Forward pass
-            outputs = self(points, train=True, dropout_rate=dropout_rate)
+            outputs = self(points, dropout_rate=dropout_rate)
             loss = self.criterion(outputs, labels)

             # Backward and optimize
@@ -96,15 +98,15 @@ def _print_status(self, epoch, epochs, loss, val_loss):


 class MLP(BaseMLP):
-    def __init__(self, layer_sizes, l2_reg=1e-5, postprocessing=None, loss=nn.MSELoss,
-                 optimizer=None, activation=None):
+    def __init__(self, layer_sizes, postprocessing=None, loss=nn.MSELoss,
+                 optimizer=None, activation=None, **kwargs):
         if postprocessing is None:
             postprocessing = lambda x: x

         if activation is None:
             activation = F.celu

-        super(MLP, self).__init__(layer_sizes, activation=activation, postprocessing=postprocessing)
+        super(MLP, self).__init__(layer_sizes, activation=activation, postprocessing=postprocessing, **kwargs)

         self.criterion = loss()
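Taken together, the mlp.py changes let callers pin the device explicitly and drop the train flag from forward(), which now returns an attached tensor instead of detaching by default. A minimal usage sketch, assuming this alpaca revision is installed (shapes are illustrative):

import torch
from alpaca.model.mlp import MLP

layers = (8, 256, 128, 64, 1)

# Device is resolved automatically ('cuda' when available) unless overridden
# via the new keyword, which MLP passes through to BaseMLP via **kwargs.
model = MLP(layers, device='cpu')

x = torch.rand(4, 8, dtype=torch.double)  # BaseMLP casts itself to double
out = model(x)             # stays attached to the graph; no train= flag anymore
print(out.shape)           # torch.Size([4, 1])
features = out.detach()    # callers now detach explicitly when needed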
3 changes: 3 additions & 0 deletions alpaca/uncertainty_estimator/mcdue.py
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+from .masks import build_mask


 class MCDUE:
@@ -30,6 +31,8 @@ def __init__(self, net, nn_runs=25, dropout_rate=.5, dropout_mask=None, keep_run
         self.net = net
         self.nn_runs = nn_runs
         self.dropout_rate = dropout_rate
+        if isinstance(dropout_mask, str):
+            dropout_mask = build_mask(dropout_mask)
         self.dropout_mask = dropout_mask
         self.keep_runs = keep_runs
         self._mcd_runs = np.array([])
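With this addition, MCDUE accepts a mask by name rather than requiring a constructed mask object. A short sketch under the same assumptions (mask names such as 'mc_dropout' must be registered with build_mask, as the notebook below suggests):

from alpaca.model.mlp import MLP
from alpaca.uncertainty_estimator.mcdue import MCDUE

net = MLP((8, 64, 1), device='cpu')

# Previously dropout_mask had to be a mask object (or None for plain MC
# dropout); a registered name is now resolved internally via build_mask.
estimator = MCDUE(net, nn_runs=25, dropout_rate=0.5, dropout_mask='mc_dropout')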
137 changes: 90 additions & 47 deletions examples/regression_uq.ipynb
@@ -1,25 +1,51 @@
 {
  "cells": [
   {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import torch\n",
+    "from torch import nn\n",
+    "from torch.utils.data import TensorDataset, DataLoader\n",
+    "\n",
+    "from sklearn.metrics import r2_score \n",
+    "from alpaca.uncertainty_estimator import build_estimator\n",
+    "from alpaca.model.mlp import MLP \n",
+    "from alpaca.dataloader.builder import build_dataset\n",
+    "from alpaca.analysis.metrics import ndcg\n",
+    "\n",
+    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n",
+     "is_executing": false
+    }
+   }
+  },
+  {
    "cell_type": "code",
    "execution_count": null,
    "outputs": [],
    "source": [
     "# Load dataset\n",
-    "mnist = build_dataset('mnist', val_size=10_000)\n",
+    "mnist = build_dataset('kin8nm', val_split=1_000)\n",
     "x_train, y_train = mnist.dataset('train')\n",
     "x_val, y_val = mnist.dataset('val')\n",
-    "x_shape = (-1, 1, 28, 28)\n",
-    "\n",
-    "train_ds = TensorDataset(torch.FloatTensor(x_train.reshape(x_shape)), torch.LongTensor(y_train))\n",
-    "val_ds = TensorDataset(torch.FloatTensor(x_val.reshape(x_shape)), torch.LongTensor(y_val))\n",
+    "x_train.shape, y_val.shape\n",
+    "train_ds = TensorDataset(torch.DoubleTensor(x_train), torch.DoubleTensor(y_train))\n",
+    "val_ds = TensorDataset(torch.DoubleTensor(x_val), torch.DoubleTensor(y_val))\n",
     "train_loader = DataLoader(train_ds, batch_size=512)\n",
     "val_loader = DataLoader(val_ds, batch_size=512)\n"
    ],
    "metadata": {
     "collapsed": false,
     "pycharm": {
-     "name": "#%%\n"
+     "name": "#%%\n",
+     "is_executing": false
     }
    }
   },
@@ -29,29 +55,60 @@
    "outputs": [],
    "source": [
     "# Train model\n",
-    "model = SimpleConv()\n",
-    "criterion = nn.CrossEntropyLoss()\n",
+    "layers = (8, 256, 128, 64, 1)\n",
+    "model = MLP(layers)\n",
+    "model.to(device)"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n",
+     "is_executing": false
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "criterion = nn.MSELoss()\n",
     "optimizer = torch.optim.Adam(model.parameters())\n",
     "\n",
-    "for x_batch, y_batch in train_loader: # Train for one epoch\n",
-    "    prediction = model(x_batch)\n",
-    "    optimizer.zero_grad()\n",
-    "    loss = criterion(prediction, y_batch)\n",
-    "    loss.backward()\n",
-    "    optimizer.step()\n",
-    "print('Train loss on last batch', loss.item())\n",
-    "\n",
-    "# Check accuracy\n",
+    "model.train()\n",
+    "for epochs in range(10):\n",
+    "    for x_batch, y_batch in train_loader: # Train for one epoch\n",
+    "        predictions = model(x_batch.to(device))\n",
+    "        loss = criterion(predictions, y_batch.to(device))\n",
+    "        optimizer.zero_grad()\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "print('Train loss on last batch', loss.item())\n"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n",
+     "is_executing": false
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# Check model effectiveness \n",
     "model.eval()\n",
     "x_batch, y_batch = next(iter(val_loader))\n",
     "\n",
-    "class_preds = F.softmax(model(x_batch), dim=-1).detach().numpy()\n",
-    "predictions = np.argmax(class_preds, axis=-1)\n",
-    "print('Accuracy', accuracy_score(predictions, y_batch))\n"
+    "predictions = model(x_batch.to(device)).detach().cpu().numpy()\n",
+    "print('R2', r2_score(predictions, y_batch))"
    ],
    "metadata": {
     "collapsed": false,
     "pycharm": {
-     "name": "#%%\n"
+     "name": "#%%\n",
+     "is_executing": false
     }
    }
   },
@@ -61,13 +118,15 @@
    "outputs": [],
    "source": [
     "# Calculate uncertainty estimation\n",
-    "estimator = build_estimator(\"bald_masked\", model, dropout_mask='mc_dropout', num_classes=10)\n",
-    "estimations = estimator.estimate(x_batch)\n"
+    "estimator = build_estimator(\"mcdue_masked\", model, dropout_mask='mc_dropout')\n",
+    "estimations = estimator.estimate(x_batch.to(device))\n",
+    "print(estimations[:10])"
    ],
    "metadata": {
     "collapsed": false,
     "pycharm": {
-     "name": "#%%\n"
+     "name": "#%%\n",
+     "is_executing": false
     }
    }
   },
@@ -77,15 +136,15 @@
    "outputs": [],
    "source": [
     "# Calculate NDCG score for the uncertainty\n",
-    "errors = [metrics.log_loss(target.reshape(-1, 1), pred.reshape((-1, 10)), labels=list(range(10))) for pred, target in zip(class_preds, y_batch.numpy())]\n",
-    "\n",
+    "errors = np.abs(estimations - y_batch.reshape((-1)).numpy()) \n",
     "score = ndcg(np.array(errors), estimations)\n",
     "print(\"Quality score is \", score)"
    ],
    "metadata": {
     "collapsed": false,
     "pycharm": {
-     "name": "#%%\n"
+     "name": "#%%\n",
+     "is_executing": false
     }
    }
   },
@@ -99,7 +158,8 @@
    "metadata": {
     "collapsed": false,
     "pycharm": {
-     "name": "#%%\n"
+     "name": "#%%\n",
+     "is_executing": false
     }
    }
   }
@@ -125,24 +185,7 @@
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
-   "source": [
-    "\n",
-    "%%\n",
-    "import numpy as np\n",
-    "import torch\n",
-    "from torch import nn\n",
-    "from torch.utils.data import TensorDataset, DataLoader\n",
-    "import torch.nn.functional as F\n",
-    "\n",
-    "from sklearn.metrics import accuracy_score\n",
-    "import matplotlib.pyplot as plt\n",
-    "from sklearn import metrics\n",
-    "\n",
-    "from alpaca.uncertainty_estimator import build_estimator\n",
-    "from alpaca.model.cnn import SimpleConv\n",
-    "from alpaca.dataloader.builder import build_dataset\n",
-    "from alpaca.analysis.metrics import ndcg\n"
-   ],
+   "source": [],
    "metadata": {
     "collapsed": false
    }
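The notebook now scores uncertainty with ndcg over absolute errors. As a rough guide to what that measures (an illustrative sketch of the metric's idea, not the actual alpaca.analysis.metrics.ndcg source): a ranking of points by estimated uncertainty earns a high score when it matches the ranking by true absolute error.

import numpy as np

def ndcg_sketch(errors: np.ndarray, estimations: np.ndarray) -> float:
    # Rank points from most to least uncertain, as claimed by the estimator.
    order = np.argsort(estimations)[::-1]
    # Standard log2 position discounts for DCG.
    discounts = 1.0 / np.log2(np.arange(2, len(errors) + 2))
    dcg = float(np.sum(errors[order] * discounts))        # gain = true |error|
    ideal = float(np.sum(np.sort(errors)[::-1] * discounts))
    return dcg / ideal if ideal > 0 else 0.0

A score near 1 means the most uncertain points are also the ones the model gets most wrong, which is exactly what this regression example checks.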
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@ def load_requirements(filename):

 setuptools.setup(
     name="alpaca-ml",
-    version="0.0.8",
+    version="0.0.10",
     author="Maxim Panov and Evgenii Tsymbalov and Kirill Fedyanin",
     author_email="[email protected]",
     description="Active learning utilities for machine learning applications",
