Skip to content

Commit

Permalink
Add notebook with validation data fitting example
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Nov 9, 2023
1 parent 46385b5 commit 8c73cc8
Showing 1 changed file with 206 additions and 0 deletions.
206 changes: 206 additions & 0 deletions notebooks/fitting_with_validation_data.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Fitting with validation data \n",
"\n",
"This notebook shows how using validation data can improve the normalizing flow fit.\n",
"\n",
"We create a synthetic example with very little training data and a flow with a very large number of layers. We show that using validation data prevents the flow from overfitting in spite of having too many parameters. "
],
"metadata": {
"collapsed": false
},
"id": "ea05f26c641401d0"
},
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2023-11-09T21:11:10.440521400Z",
"start_time": "2023-11-09T21:11:08.267904100Z"
}
},
"outputs": [],
"source": [
"import torch\n",
"from normalizing_flows.flows import Flow\n",
"from normalizing_flows.bijections import RealNVP"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# Create some synthetic training and validation data\n",
"torch.manual_seed(0)\n",
"\n",
"event_shape = (10,)\n",
"n_train = 100\n",
"n_val = 20\n",
"n_test = 10000\n",
"\n",
"x_train = torch.randn(n_train, *event_shape) * 2 + 4\n",
"x_val = torch.randn(n_val, *event_shape) * 2 + 4\n",
"x_test = torch.randn(n_test, *event_shape) * 2 + 4"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-11-09T21:11:10.456694900Z",
"start_time": "2023-11-09T21:11:10.445522900Z"
}
},
"id": "21b252329b5695cf"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fitting NF: 100%|██████████| 500/500 [00:15<00:00, 32.71it/s, Training loss (batch): 1.7106]\n"
]
}
],
"source": [
"# Train without validation data\n",
"torch.manual_seed(0)\n",
"flow0 = Flow(RealNVP(event_shape, n_layers=20))\n",
"flow0.fit(x_train, show_progress=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-11-09T21:11:25.777575300Z",
"start_time": "2023-11-09T21:11:10.457694Z"
}
},
"id": "b8c5703314f84814"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fitting NF: 100%|██████████| 500/500 [00:23<00:00, 21.42it/s, Training loss (batch): 1.7630, Validation loss: 2.8325]\n"
]
}
],
"source": [
"# Train with validation data and keep the best weights\n",
"torch.manual_seed(0)\n",
"flow1 = Flow(RealNVP(event_shape, n_layers=20))\n",
"flow1.fit(x_train, show_progress=True, x_val=x_val)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-11-09T21:11:49.164216Z",
"start_time": "2023-11-09T21:11:25.775746200Z"
}
},
"id": "95d4d4e0447f1d4"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fitting NF: 39%|███▉ | 194/500 [00:11<00:18, 16.57it/s, Training loss (batch): 1.9825, Validation loss: 2.1353]\n"
]
}
],
"source": [
"# Train with validation data, early stopping, and keep the best weights\n",
"torch.manual_seed(0)\n",
"flow2 = Flow(RealNVP(event_shape, n_layers=20))\n",
"flow2.fit(x_train, show_progress=True, x_val=x_val, early_stopping=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-11-09T21:12:00.931776800Z",
"start_time": "2023-11-09T21:11:49.165794100Z"
}
},
"id": "2a6ff6eaea4e1323"
},
{
"cell_type": "markdown",
"source": [
"The normalizing flow has a lot of parameters and thus overfits without validation data. The test loss is much lower when using validation data. We may stop training early after no observable validation loss improvement for a certain number of epochs (default: 50). In this experiment, validation loss does not improve after these epochs, as evidenced by the same test loss as observed without early stopping."
],
"metadata": {
"collapsed": false
},
"id": "84366140ce6804fe"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test loss values\n",
"\n",
"Without validation data: 55.78230667114258\n",
"With validation data, no early stopping: 24.563425064086914\n",
"With validation data, early stopping: 24.563425064086914\n"
]
}
],
"source": [
"print(\"Test loss values\")\n",
"print()\n",
"print(f\"Without validation data: {torch.mean(-flow0.log_prob(x_test))}\")\n",
"print(f\"With validation data, no early stopping: {torch.mean(-flow1.log_prob(x_test))}\")\n",
"print(f\"With validation data, early stopping: {torch.mean(-flow2.log_prob(x_test))}\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-11-09T21:12:01.263469Z",
"start_time": "2023-11-09T21:12:00.925959700Z"
}
},
"id": "bfaca2ae85997ee3"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 8c73cc8

Please sign in to comment.