From 8c73cc8e77fd0d77ff47b53bded294066791066d Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Thu, 9 Nov 2023 13:12:57 -0800 Subject: [PATCH] Add notebook with validation data fitting example --- notebooks/fitting_with_validation_data.ipynb | 206 +++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 notebooks/fitting_with_validation_data.ipynb diff --git a/notebooks/fitting_with_validation_data.ipynb b/notebooks/fitting_with_validation_data.ipynb new file mode 100644 index 0000000..fa60bad --- /dev/null +++ b/notebooks/fitting_with_validation_data.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Fitting with validation data \n", + "\n", + "This notebook shows how using validation data can improve the normalizing flow fit.\n", + "\n", + "We create a synthetic example with very little training data and a flow with a very large number of layers. We show that using validation data prevents the flow from overfitting in spite of having too many parameters. " + ], + "metadata": { + "collapsed": false + }, + "id": "ea05f26c641401d0" + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:10.440521400Z", + "start_time": "2023-11-09T21:11:08.267904100Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from normalizing_flows.flows import Flow\n", + "from normalizing_flows.bijections import RealNVP" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "# Create some synthetic training and validation data\n", + "torch.manual_seed(0)\n", + "\n", + "event_shape = (10,)\n", + "n_train = 100\n", + "n_val = 20\n", + "n_test = 10000\n", + "\n", + "x_train = torch.randn(n_train, *event_shape) * 2 + 4\n", + "x_val = torch.randn(n_val, *event_shape) * 2 + 4\n", + "x_test = torch.randn(n_test, *event_shape) * 2 + 4" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:10.456694900Z", + "start_time": "2023-11-09T21:11:10.445522900Z" + } + }, + "id": "21b252329b5695cf" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 100%|██████████| 500/500 [00:15<00:00, 32.71it/s, Training loss (batch): 1.7106]\n" + ] + } + ], + "source": [ + "# Train without validation data\n", + "torch.manual_seed(0)\n", + "flow0 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow0.fit(x_train, show_progress=True)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:25.777575300Z", + "start_time": "2023-11-09T21:11:10.457694Z" + } + }, + "id": "b8c5703314f84814" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 100%|██████████| 500/500 [00:23<00:00, 21.42it/s, Training loss (batch): 1.7630, Validation loss: 2.8325]\n" + ] + } + ], + "source": [ + "# Train with validation data and keep the best weights\n", + "torch.manual_seed(0)\n", + "flow1 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow1.fit(x_train, show_progress=True, x_val=x_val)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:11:49.164216Z", + "start_time": "2023-11-09T21:11:25.775746200Z" + } + }, + "id": "95d4d4e0447f1d4" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 39%|███▉ | 194/500 [00:11<00:18, 16.57it/s, Training loss (batch): 1.9825, Validation loss: 2.1353]\n" + ] + } + ], + "source": [ + "# Train with validation data, early stopping, and keep the best weights\n", + "torch.manual_seed(0)\n", + "flow2 = Flow(RealNVP(event_shape, n_layers=20))\n", + "flow2.fit(x_train, show_progress=True, x_val=x_val, early_stopping=True)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:12:00.931776800Z", + "start_time": "2023-11-09T21:11:49.165794100Z" + } + }, + "id": "2a6ff6eaea4e1323" + }, + { + "cell_type": "markdown", + "source": [ + "The normalizing flow has a lot of parameters and thus overfits without validation data. The test loss is much lower when using validation data. We may stop training early after no observable validation loss improvement for a certain number of epochs (default: 50). In this experiment, validation loss does not improve after these epochs, as evidenced by the same test loss as observed without early stopping." + ], + "metadata": { + "collapsed": false + }, + "id": "84366140ce6804fe" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test loss values\n", + "\n", + "Without validation data: 55.78230667114258\n", + "With validation data, no early stopping: 24.563425064086914\n", + "With validation data, early stopping: 24.563425064086914\n" + ] + } + ], + "source": [ + "print(\"Test loss values\")\n", + "print()\n", + "print(f\"Without validation data: {torch.mean(-flow0.log_prob(x_test))}\")\n", + "print(f\"With validation data, no early stopping: {torch.mean(-flow1.log_prob(x_test))}\")\n", + "print(f\"With validation data, early stopping: {torch.mean(-flow2.log_prob(x_test))}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-09T21:12:01.263469Z", + "start_time": "2023-11-09T21:12:00.925959700Z" + } + }, + "id": "bfaca2ae85997ee3" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}