-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add notebook with validation data fitting example
- Loading branch information
1 parent
46385b5
commit 8c73cc8
Showing
1 changed file
with
206 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Fitting with validation data \n", | ||
"\n", | ||
"This notebook shows how using validation data can improve the normalizing flow fit.\n", | ||
"\n", | ||
"We create a synthetic example with very little training data and a flow with a very large number of layers. We show that using validation data prevents the flow from overfitting in spite of having too many parameters. " | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "ea05f26c641401d0" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "initial_id", | ||
"metadata": { | ||
"collapsed": true, | ||
"ExecuteTime": { | ||
"end_time": "2023-11-09T21:11:10.440521400Z", | ||
"start_time": "2023-11-09T21:11:08.267904100Z" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from normalizing_flows.flows import Flow\n", | ||
"from normalizing_flows.bijections import RealNVP" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"outputs": [], | ||
"source": [ | ||
"# Create some synthetic training and validation data\n", | ||
"torch.manual_seed(0)\n", | ||
"\n", | ||
"event_shape = (10,)\n", | ||
"n_train = 100\n", | ||
"n_val = 20\n", | ||
"n_test = 10000\n", | ||
"\n", | ||
"x_train = torch.randn(n_train, *event_shape) * 2 + 4\n", | ||
"x_val = torch.randn(n_val, *event_shape) * 2 + 4\n", | ||
"x_test = torch.randn(n_test, *event_shape) * 2 + 4" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-11-09T21:11:10.456694900Z", | ||
"start_time": "2023-11-09T21:11:10.445522900Z" | ||
} | ||
}, | ||
"id": "21b252329b5695cf" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Fitting NF: 100%|██████████| 500/500 [00:15<00:00, 32.71it/s, Training loss (batch): 1.7106]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Train without validation data\n", | ||
"torch.manual_seed(0)\n", | ||
"flow0 = Flow(RealNVP(event_shape, n_layers=20))\n", | ||
"flow0.fit(x_train, show_progress=True)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-11-09T21:11:25.777575300Z", | ||
"start_time": "2023-11-09T21:11:10.457694Z" | ||
} | ||
}, | ||
"id": "b8c5703314f84814" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Fitting NF: 100%|██████████| 500/500 [00:23<00:00, 21.42it/s, Training loss (batch): 1.7630, Validation loss: 2.8325]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Train with validation data and keep the best weights\n", | ||
"torch.manual_seed(0)\n", | ||
"flow1 = Flow(RealNVP(event_shape, n_layers=20))\n", | ||
"flow1.fit(x_train, show_progress=True, x_val=x_val)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-11-09T21:11:49.164216Z", | ||
"start_time": "2023-11-09T21:11:25.775746200Z" | ||
} | ||
}, | ||
"id": "95d4d4e0447f1d4" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Fitting NF: 39%|███▉ | 194/500 [00:11<00:18, 16.57it/s, Training loss (batch): 1.9825, Validation loss: 2.1353]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Train with validation data, early stopping, and keep the best weights\n", | ||
"torch.manual_seed(0)\n", | ||
"flow2 = Flow(RealNVP(event_shape, n_layers=20))\n", | ||
"flow2.fit(x_train, show_progress=True, x_val=x_val, early_stopping=True)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-11-09T21:12:00.931776800Z", | ||
"start_time": "2023-11-09T21:11:49.165794100Z" | ||
} | ||
}, | ||
"id": "2a6ff6eaea4e1323" | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"The normalizing flow has a lot of parameters and thus overfits without validation data. The test loss is much lower when using validation data. We may stop training early after no observable validation loss improvement for a certain number of epochs (default: 50). In this experiment, validation loss does not improve after these epochs, as evidenced by the same test loss as observed without early stopping." | ||
], | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"id": "84366140ce6804fe" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Test loss values\n", | ||
"\n", | ||
"Without validation data: 55.78230667114258\n", | ||
"With validation data, no early stopping: 24.563425064086914\n", | ||
"With validation data, early stopping: 24.563425064086914\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\"Test loss values\")\n", | ||
"print()\n", | ||
"print(f\"Without validation data: {torch.mean(-flow0.log_prob(x_test))}\")\n", | ||
"print(f\"With validation data, no early stopping: {torch.mean(-flow1.log_prob(x_test))}\")\n", | ||
"print(f\"With validation data, early stopping: {torch.mean(-flow2.log_prob(x_test))}\")" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"ExecuteTime": { | ||
"end_time": "2023-11-09T21:12:01.263469Z", | ||
"start_time": "2023-11-09T21:12:00.925959700Z" | ||
} | ||
}, | ||
"id": "bfaca2ae85997ee3" | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |