-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from davidnabergoj/docs
Docs
- Loading branch information
Showing
64 changed files
with
1,270 additions
and
271 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Computing the log determinant of the Jacobian\n", | ||
"\n", | ||
"We show how to compute and retrieve the log determinant of the Jacobian of a bijective transformation. We use Real NVP as an example." | ||
], | ||
"id": "624f99599895f0fd" | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T16:39:40.646799Z", | ||
"start_time": "2024-08-13T16:39:38.868039Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"import torch\n", | ||
"from torchflows import Flow\n", | ||
"from torchflows.architectures import RealNVP\n", | ||
"\n", | ||
"torch.manual_seed(0)\n", | ||
"\n", | ||
"batch_shape = (5, 7)\n", | ||
"event_shape = (2, 3)\n", | ||
"x = torch.randn(size=(*batch_shape, *event_shape))\n", | ||
"z = torch.randn(size=(*batch_shape, *event_shape))\n", | ||
"\n", | ||
"bijection = RealNVP(event_shape=event_shape)\n", | ||
"flow = Flow(bijection)\n", | ||
"\n", | ||
"_, log_det_forward = flow.bijection.forward(x)\n", | ||
"_, log_det_inverse = flow.bijection.inverse(z)" | ||
], | ||
"id": "3f74b61a9929dd3b", | ||
"outputs": [], | ||
"execution_count": 1 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T16:39:40.662420Z", | ||
"start_time": "2024-08-13T16:39:40.653696Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"print(f'{log_det_forward.shape = }')\n", | ||
"print(f'{log_det_inverse.shape = }')" | ||
], | ||
"id": "3c49e132d9c041c2", | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"log_det_forward.shape = torch.Size([5, 7])\n", | ||
"log_det_inverse.shape = torch.Size([5, 7])\n" | ||
] | ||
} | ||
], | ||
"execution_count": 2 | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Image modeling with normalizing flows\n", | ||
"\n", | ||
"When working with images, we can use specialized multiscale flow architectures. We can also use standard normalizing flows, which internally work with a flattened image. Note that multiscale architectures expect input images with shape `(channels, height, width)`." | ||
], | ||
"id": "df68afe10da259a1" | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:20:05.803231Z", | ||
"start_time": "2024-08-13T17:20:03.001656Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"from torchvision.datasets import MNIST\n", | ||
"import torch\n", | ||
"\n", | ||
"torch.manual_seed(0)\n", | ||
"\n", | ||
"# pip install torchvision\n", | ||
"dataset = MNIST(root='./data', download=True, train=True)\n", | ||
"train_data = dataset.data.float()[:, None]\n", | ||
"train_data = train_data[torch.randperm(len(train_data))]\n", | ||
"train_data = (train_data - torch.mean(train_data)) / torch.std(train_data)\n", | ||
"x_train, x_val = train_data[:1000], train_data[1000:1200]\n", | ||
"\n", | ||
"print(f'{x_train.shape = }')\n", | ||
"print(f'{x_val.shape = }')\n", | ||
"\n", | ||
"image_shape = train_data.shape[1:]\n", | ||
"print(f'{image_shape = }')" | ||
], | ||
"id": "b4d5e1888ff6a0e7", | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"x_train.shape = torch.Size([1000, 1, 28, 28])\n", | ||
"x_val.shape = torch.Size([200, 1, 28, 28])\n", | ||
"image_shape = torch.Size([1, 28, 28])\n" | ||
] | ||
} | ||
], | ||
"execution_count": 1 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:20:06.058329Z", | ||
"start_time": "2024-08-13T17:20:05.891695Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"from torchflows.flows import Flow\n", | ||
"from torchflows.architectures import RealNVP, MultiscaleRealNVP\n", | ||
"\n", | ||
"real_nvp = Flow(RealNVP(image_shape))\n", | ||
"multiscale_real_nvp = Flow(MultiscaleRealNVP(image_shape))" | ||
], | ||
"id": "744513899ffa6a46", | ||
"outputs": [], | ||
"execution_count": 2 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:26:11.651540Z", | ||
"start_time": "2024-08-13T17:20:06.378393Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)\n", | ||
"multiscale_real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)" | ||
], | ||
"id": "7a439e2565ce5a25", | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Fitting NF: 30%|███ | 151/500 [00:18<00:42, 8.30it/s, Training loss (batch): -0.2608, Validation loss: 1.3448 [best: 0.1847 @ 100]] \n", | ||
"Fitting NF: 30%|███ | 152/500 [05:47<13:14, 2.28s/it, Training loss (batch): -0.3050, Validation loss: 0.9754 [best: 0.1744 @ 101]] \n" | ||
] | ||
} | ||
], | ||
"execution_count": 3 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:26:11.699539Z", | ||
"start_time": "2024-08-13T17:26:11.686539Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": "", | ||
"id": "c38fc6cc58bdc0b2", | ||
"outputs": [], | ||
"execution_count": null | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.