-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/main'
# Conflicts: # docs/conf.py
- Loading branch information
Showing
66 changed files
with
1,453 additions
and
254 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Computing the log determinant of the Jacobian\n", | ||
"\n", | ||
"We show how to compute and retrieve the log determinant of the Jacobian of a bijective transformation. We use Real NVP as an example." | ||
], | ||
"id": "624f99599895f0fd" | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T16:39:40.646799Z", | ||
"start_time": "2024-08-13T16:39:38.868039Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"import torch\n", | ||
"from torchflows import Flow\n", | ||
"from torchflows.architectures import RealNVP\n", | ||
"\n", | ||
"torch.manual_seed(0)\n", | ||
"\n", | ||
"batch_shape = (5, 7)\n", | ||
"event_shape = (2, 3)\n", | ||
"x = torch.randn(size=(*batch_shape, *event_shape))\n", | ||
"z = torch.randn(size=(*batch_shape, *event_shape))\n", | ||
"\n", | ||
"bijection = RealNVP(event_shape=event_shape)\n", | ||
"flow = Flow(bijection)\n", | ||
"\n", | ||
"_, log_det_forward = flow.bijection.forward(x)\n", | ||
"_, log_det_inverse = flow.bijection.inverse(z)" | ||
], | ||
"id": "3f74b61a9929dd3b", | ||
"outputs": [], | ||
"execution_count": 1 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T16:39:40.662420Z", | ||
"start_time": "2024-08-13T16:39:40.653696Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"print(f'{log_det_forward.shape = }')\n", | ||
"print(f'{log_det_inverse.shape = }')" | ||
], | ||
"id": "3c49e132d9c041c2", | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"log_det_forward.shape = torch.Size([5, 7])\n", | ||
"log_det_inverse.shape = torch.Size([5, 7])\n" | ||
] | ||
} | ||
], | ||
"execution_count": 2 | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Image modeling with normalizing flows\n", | ||
"\n", | ||
"When working with images, we can use specialized multiscale flow architectures. We can also use standard normalizing flows, which internally work with a flattened image. Note that multiscale architectures expect input images with shape `(channels, height, width)`." | ||
], | ||
"id": "df68afe10da259a1" | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:20:05.803231Z", | ||
"start_time": "2024-08-13T17:20:03.001656Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"from torchvision.datasets import MNIST\n", | ||
"import torch\n", | ||
"\n", | ||
"torch.manual_seed(0)\n", | ||
"\n", | ||
"# pip install torchvision\n", | ||
"dataset = MNIST(root='./data', download=True, train=True)\n", | ||
"train_data = dataset.data.float()[:, None]\n", | ||
"train_data = train_data[torch.randperm(len(train_data))]\n", | ||
"train_data = (train_data - torch.mean(train_data)) / torch.std(train_data)\n", | ||
"x_train, x_val = train_data[:1000], train_data[1000:1200]\n", | ||
"\n", | ||
"print(f'{x_train.shape = }')\n", | ||
"print(f'{x_val.shape = }')\n", | ||
"\n", | ||
"image_shape = train_data.shape[1:]\n", | ||
"print(f'{image_shape = }')" | ||
], | ||
"id": "b4d5e1888ff6a0e7", | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"x_train.shape = torch.Size([1000, 1, 28, 28])\n", | ||
"x_val.shape = torch.Size([200, 1, 28, 28])\n", | ||
"image_shape = torch.Size([1, 28, 28])\n" | ||
] | ||
} | ||
], | ||
"execution_count": 1 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:20:06.058329Z", | ||
"start_time": "2024-08-13T17:20:05.891695Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"from torchflows.flows import Flow\n", | ||
"from torchflows.architectures import RealNVP, MultiscaleRealNVP\n", | ||
"\n", | ||
"real_nvp = Flow(RealNVP(image_shape))\n", | ||
"multiscale_real_nvp = Flow(MultiscaleRealNVP(image_shape))" | ||
], | ||
"id": "744513899ffa6a46", | ||
"outputs": [], | ||
"execution_count": 2 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:26:11.651540Z", | ||
"start_time": "2024-08-13T17:20:06.378393Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": [ | ||
"real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)\n", | ||
"multiscale_real_nvp.fit(x_train, x_val=x_val, early_stopping=True, show_progress=True)" | ||
], | ||
"id": "7a439e2565ce5a25", | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Fitting NF: 30%|███ | 151/500 [00:18<00:42, 8.30it/s, Training loss (batch): -0.2608, Validation loss: 1.3448 [best: 0.1847 @ 100]] \n", | ||
"Fitting NF: 30%|███ | 152/500 [05:47<13:14, 2.28s/it, Training loss (batch): -0.3050, Validation loss: 0.9754 [best: 0.1744 @ 101]] \n" | ||
] | ||
} | ||
], | ||
"execution_count": 3 | ||
}, | ||
{ | ||
"metadata": { | ||
"ExecuteTime": { | ||
"end_time": "2024-08-13T17:26:11.699539Z", | ||
"start_time": "2024-08-13T17:26:11.686539Z" | ||
} | ||
}, | ||
"cell_type": "code", | ||
"source": "", | ||
"id": "c38fc6cc58bdc0b2", | ||
"outputs": [], | ||
"execution_count": null | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.