diff --git a/docs/conf.py b/docs/conf.py index 06ff1fe..009f213 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -7,7 +7,7 @@ author = 'David Nabergoj' release = '1.0' -version = '1.0.4' +version = '1.0.5' # -- General configuration diff --git a/docs/notebooks/MNIST.ipynb b/docs/notebooks/MNIST.ipynb new file mode 100644 index 0000000..4d4b7a5 --- /dev/null +++ b/docs/notebooks/MNIST.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-08-30T15:11:39.705430Z", + "start_time": "2024-08-30T15:11:36.646915Z" + } + }, + "source": [ + "import matplotlib.pyplot as plt\n", + "import torch\n", + "from torchflows.flows import Flow\n", + "from torchflows.architectures import MultiscaleRealNVP\n", + "\n", + "from torchvision.datasets import MNIST" + ], + "outputs": [], + "execution_count": 1 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T15:11:40.777876Z", + "start_time": "2024-08-30T15:11:39.710431Z" + } + }, + "cell_type": "code", + "source": [ + "torch.manual_seed(0)\n", + "\n", + "train_dataset = MNIST(root='./data', train=True, download=True)\n", + "data_constraint = 0.9\n", + "\n", + "x = train_dataset.data.float()\n", + "y = (x + torch.rand_like(x)) / 256.\n", + "y = (2 * y - 1) * data_constraint\n", + "y = (y + 1) / 2\n", + "y = y.log() - (1. - y).log()\n", + "y = y[:, None]\n", + "y = torch.concat([y, torch.randn_like(y)], dim=1) # Auxiliary Gaussian channel dimensions\n", + "# y = (y - torch.mean(y)) / torch.std(y)\n", + "# y = torch.nn.functional.pad(y, [2, 2, 2, 2])\n", + "\n", + "train_data = y[:50000]\n", + "validation_data = y[50000:]" + ], + "id": "51636bf2dddb1fa1", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T15:11:40.919287Z", + "start_time": "2024-08-30T15:11:40.905414Z" + } + }, + "cell_type": "code", + "source": "y.shape", + "id": "6d49b5c9a93f2699", + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([60000, 2, 28, 28])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T15:56:58.351957Z", + "start_time": "2024-08-30T15:11:40.996413Z" + } + }, + "cell_type": "code", + "source": [ + "torch.manual_seed(0)\n", + "flow = Flow(MultiscaleRealNVP(event_shape=train_data.shape[1:])).cuda()\n", + "flow.fit(x_train=train_data, x_val=validation_data, show_progress=True, early_stopping=True, batch_size='adaptive')" + ], + "id": "d630e341d5471a8f", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting NF: 73%|███████▎ | 367/500 [45:17<16:24, 7.40s/it, Training loss (batch): 0.6879, Validation loss: 0.7043 [best: 0.2577 @ 316]]\n" + ] + } + ], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-30T16:48:10.812621Z", + "start_time": "2024-08-30T16:48:10.516652Z" + } + }, + "cell_type": "code", + "source": [ + "torch.manual_seed(0)\n", + "x_flow = flow.sample((100,))[:, 0].detach().cpu()\n", + "\n", + "plt.matshow(x_flow[8], cmap='gray');" + ], + "id": "87a588361a15ede6", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaMAAAGkCAYAAACckEpMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjJUlEQVR4nO3de2zV9f3H8Vdb6OHWHqhAL1KglOtEMKtQ64XhaLgsM6CYeOEPcAYjK2bInIZNRTaTbvwSJS4I/2wwI3hhEYhkY+HWEuRiqBCmYgMNCtgL0NmeXqC0Pd/fH4SyjuvnQ8/3c3r6fCQn2tPz4vs53357XhzO97xPnOd5ngAAcCje9QIAAKCMAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBABwjjICADhHGQEAnKOMAADOdZoyWrlypYYOHaoePXooNzdXn3/+uesl+e6NN95QXFxcu8vo0aNdL8sXu3fv1iOPPKKMjAzFxcVp06ZN7b7veZ5ef/11paenq2fPnsrPz9exY8fcLDaCbrYf5s2bd9UxMn36dDeLjaDCwkJNmDBBSUlJGjhwoGbNmqXS0tJ2t7lw4YIKCgp0xx13qE+fPpo9e7aqqqocrTgybmU/TJ48+apj4vnnn3e04uvrFGX00UcfafHixVq6dKm++OILjR8/XtOmTdOZM2dcL813d911lyoqKtoue/bscb0kXzQ0NGj8+PFauXLlNb+/fPlyvfPOO1q9erUOHDig3r17a9q0abpw4YLPK42sm+0HSZo+fXq7Y+SDDz7wcYX+KC4uVkFBgfbv369t27apublZU6dOVUNDQ9ttXnzxRX366afasGGDiouLVV5erscee8zhqjverewHSZo/f367Y2L58uWOVnwDXicwceJEr6CgoO3r1tZWLyMjwyssLHS4Kv8tXbrUGz9+vOtlOCfJ27hxY9vX4XDYS0tL8/7v//6v7bqamhovEAh4H3zwgYMV+uN/94Pned7cuXO9mTNnOlmPS2fOnPEkecXFxZ7nXfr5d+/e3duwYUPbbY4ePepJ8vbt2+dqmRH3v/vB8zzvJz/5iferX/3K3aJuUdQ/M7p48aJKSkqUn5/fdl18fLzy8/O1b98+hytz49ixY8rIyNCwYcM0Z84cnTx50vWSnDtx4oQqKyvbHSPBYFC5ubld8hgpKirSwIEDNWrUKC1YsEDV1dWulxRxtbW1kqSUlBRJUklJiZqbm9sdE6NHj9bgwYNj+pj43/1w2bp169S/f3+NHTtWS5YsUWNjo4vl3VA31wu4mXPnzqm1tVWpqantrk9NTdU333zjaFVu5Obmau3atRo1apQqKiq0bNkyPfTQQ/ryyy+VlJTkennOVFZWStI1j5HL3+sqpk+frscee0xZWVkqKyvTb3/7W82YMUP79u1TQkKC6+VFRDgc1qJFi/TAAw9o7Nixki4dE4mJierbt2+728byMXGt/SBJTz/9tIYMGaKMjAwdOXJEr7zyikpLS/XJJ584XO3Vor6McMWMGTPa/n/cuHHKzc3VkCFD9PHHH+vZZ591uDJEiyeffLLt/++++26NGzdO2dnZKioq0pQpUxyuLHIKCgr05ZdfdpnXT6/nevvhueeea/v/u+++W+np6ZoyZYrKysqUnZ3t9zKvK+r/ma5///5KSEi46iyYqqoqpaWlOVpVdOjbt69Gjhyp48ePu16KU5ePA46Rqw0bNkz9+/eP2WNk4cKF2rJli3bt2qVBgwa1XZ+WlqaLFy+qpqam3e1j9Zi43n64ltzcXEmKumMi6ssoMTFROTk52rFjR9t14XBYO3bsUF5ensOVuVdfX6+ysjKlp6e7XopTWVlZSktLa3eMhEIhHThwoMsfI6dPn1Z1dXXMHSOe52nhwoXauHGjdu7cqaysrHbfz8nJUffu3dsdE6WlpTp58mRMHRM32w/XcvjwYUmKvmPC9RkUt+LDDz/0AoGAt3btWu/rr7/2nnvuOa9v375eZWWl66X56te//rVXVFTknThxwvvss8+8/Px8r3///t6ZM2dcLy3i6urqvEOHDnmHDh3yJHlvvfWWd+jQIe+7777zPM/z/vjHP3p9+/b1Nm/e7B05csSbOXOml5WV5Z0/f97xyjvWjfZDXV2d99JLL3n79u3zTpw44W3fvt378Y9/7I0YMcK7cOGC66V3qAULFnjBYNArKiryKioq2i6NjY1tt3n++ee9wYMHezt37vQOHjzo5eXleXl5eQ5X3fFuth+OHz/u/f73v/cOHjzonThxwtu8ebM3bNgwb9KkSY5XfrVOUUae53l//vOfvcGDB3uJiYnexIkTvf3797teku+eeOIJLz093UtMTPTuvPNO74knnvCOHz/uelm+2LVrlyfpqsvcuXM9z7t0evdrr73mpaameoFAwJsyZYpXWlrqdtERcKP90NjY6E2dOtUbMGCA1717d2/IkCHe/PnzY/IvbdfaB5K8NWvWtN3m/Pnz3i9/+UuvX79+Xq9evbxHH33Uq6iocLfoCLjZfjh58qQ3adIkLyUlxQsEAt7w4cO93/zmN15tba3bhV9DnOd5nn/PwwAAuFrUv2YEAIh9lBEAwDnKCADgHGUEAHCOMgIAOEcZAQCc61Rl1NTUpDfeeENNTU2ul+IU++EK9sUl7Icr2BeXdLb90KneZxQKhRQMBlVbW6vk5GTXy3GG/XAF++IS9sMV7ItLOtt+6FTPjAAAsYkyAgA4F3WfZxQOh1VeXq6kpCTFxcW1+14oFGr3366K/XAF++IS9sMV7ItLomE/eJ6nuro6ZWRkKD7+xs99ou41o9OnTyszM9P1MgAAHeTUqVM3/ZylqHtmdPnjs++5556Y/ZjkWPG/z1xvVZT9/QdAhLS2turw4cNtj+s3EnVldPkBLiEhwZcysn1A9YPtg7Zf9ynay8hmO34eD9G+vmgX7fsv2tfnp1u5XxE7gWHlypUaOnSoevToodzcXH3++eeR2hQAoJOLSBl99NFHWrx4sZYuXaovvvhC48eP17Rp03TmzJlIbA4A0MlFpIzeeustzZ8/X88884x+9KMfafXq1erVq5f++te/RmJzAIBOrsPL6OLFiyopKVF+fv6VjcTHKz8/X/v27bvq9k1NTQqFQu0uAICupcPL6Ny5c2ptbVVqamq761NTU1VZWXnV7QsLCxUMBtsunNYNAF2P8wkMS5YsUW1tbdvl1KlTrpcEAPBZh5/a3b9/fyUkJKiqqqrd9VVVVUpLS7vq9oFAQIFAoKOXAQDoRDr8mVFiYqJycnK0Y8eOtuvC4bB27NihvLy8jt4cACAGRORNr4sXL9bcuXN17733auLEiVqxYoUaGhr0zDPPRGJzAIBOLiJl9MQTT+js2bN6/fXXVVlZqXvuuUdbt2696qQGAACkKByUevkDoXJycmJqHJCfuzkW7xP8F4vHkc19ivaxXNE8dqilpUUlJSW39AF/zs+mAwCAMgIAOEcZAQCco4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCco4wA
AM5RRgAA5yLyERIdxY/JstE88dZPsTg5OT7e/O9a4XDYOGObs1lft252v7Ktra2+ZGwm7dvsB8lun0f7pPmu/HjEMyMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcC6qB6X6wWZIo1/DFm0HIEb7+vzalp9DMW0GhNrcJ5vhpZJ/g2Zt2A6ntRErQ0Vvl+3vRiT3H8+MAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBABwjjICADhHGQEAnKOMAADOUUYAAOcoIwCAc5QRAMC5Lj8o1WZgoF8DOP0c9GkzrNJ2kKbNsM/u3bsbZ7p1Mz+8W1pajDNS9A8V7dWrl3GmubnZONPY2GicSU5ONs5IUl1dnXHG5ufbo0cP4wzM8cwIAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyL6kGpfgwKjYuL8yUT7WyGivq5rfT0dOOMzc/pzJkzxhlboVDIOJOYmGi1rdraWl+2ZTPQNjs72zgjSf/5z3+MMz/88INx5sKFC8YZyX6orR+i8TGMZ0YAAOcoIwCAcx1eRm+88Ybi4uLaXUaPHt3RmwEAxJCIvFBw1113afv27Vc24uPrEQCAziciLdGtWzelpaVF4o8GAMSgiLxmdOzYMWVkZGjYsGGaM2eOTp48ed3bNjU1KRQKtbsAALqWDi+j3NxcrV27Vlu3btWqVat04sQJPfTQQ9f9vPrCwkIFg8G2S2ZmZkcvCQAQ5eK8CL+Zp6amRkOGDNFbb72lZ5999qrvNzU1qampqe3rUCikzMxM5eTkKCEhIZJLkxSd59u74Od+iI83/zsQ7zO6xPZ9Ri0tLb5sy+bhZOzYscYZifcZXRbNj2EtLS0qKSlRbW2tkpOTb3jbiJ9Z0LdvX40cOVLHjx+/5vcDgYACgUCklwEAiGIRf59RfX29ysrKrP5mCwDoGjq8jF566SUVFxfr22+/1d69e/Xoo48qISFBTz31VEdvCgAQIzr8n+lOnz6tp556StXV1RowYIAefPBB7d+/XwMGDOjoTQEAYkSHl9GHH37Y0X8kACDGRfVoBD/OEmltbTXO2JwN5ucZLzbbqq+vN87Yni306KOPGmdszjzLy8szzrz77rvGGUlWz/wHDhxonLn//vuNM5KUmppqnLFZn82ZZ0eOHDHOSNKgQYOMM5999pnVtmz891nCt8rmZK7m5mbjjO1UnEg+jjEoFQDgHGUEAHCOMgIAOEcZAQCco4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCci+pBqRH+RHRJdkNPo53NfrvZRwJfy1tvvWWckaS///3vxpmZM2caZ44ePWqcsR2U+vXXXxtnbD4223bAZV1dnS+Z8vJy48yIESOMM5KUn59vnLEZ9LllyxbjjGQ39NSGzTFhO/A0ko/JsfdIDADodCgjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCco4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgXFQPSjXh5+A/m+GqLS0txplwOGyckaSUlBTjzPDhw40zy5YtM85I0tSpU40zNoNcz507Z5xZsmSJcUaSBgwYYJxpaGgwzmRnZxtnJGndunXGGZufU3V1tXGmX79+xhlJWrVqlXHm+++/N84kJiYaZySpqanJOOPX0FPbxxbbx9lbwTMjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCco4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgHGUEAHCOMgIAOBczU7tt2Uyh9WvSt80EX0lqbm42zpw6dco484tf/MI4I0lHjx41zixfvtw4M3ToUOPM5MmTjTOStHfvXuOMzT4vLy83zkjSkCFDjDOlpaXGmaSkJONMa2urcUaSMjIyjDM2v+///ve/jTOSlJCQYJyxeWyxyURy+rYtnhkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHMxMyjVZligLZuhpzbru3jxonFGkkaPHm2cGTlypHHmH//4h3FGksaPH2+c2bBhg3EmPT3dOPO73/3OOCNJ1dXVxpkLFy4YZ3r06GGcsc11797dODN8+HDjTFVVlXFGkqZMmWKcWbt2rXEmMzPTOCNJlZWVxpmWlharbcUCnhkBAJyjjAAAzhmX0e7du/XII48oIyNDcXFx2rRpU7vve56n119/Xenp6erZs6fy8/N17NixjlovACAGGZdRQ0ODxo8fr5UrV17z+8uXL9c777yj1atX68CBA+rdu7emTZtm9e/jAICuwfgEhhkzZmjGjBnX/J7neVqxYoVeffVVzZw5U5L03nvvKTU1VZs2bdKTTz55e6sFAMSkDn3N6MSJE6qsrFR+fn7bdcFgULm5udq3b981M01NTQqFQu0uAICupUPL6PKpjKmpqe2uT01Nve5pjoWFhQoGg20X29MoAQCdl/Oz6ZYsWaLa2tq2y6lTp1wvCQDgsw4to7S0NElXv4mtqqqq7Xv/KxAIKDk5ud0FANC1dGgZZWVlKS0tTTt27Gi7LhQK6cCBA8rLy+vITQEAYojx2XT19fU6fvx429cnTpzQ4cOHlZKSosGDB2vRokV68803NWLECGVlZem1115TRkaGZs2a1ZHrBgDEEOMyOnjwoB5++OG2rxcvXixJmjt3rtauXauXX35ZDQ0Neu6551RTU6MHH3xQW7dutZ6pBQCIfXGenxNGb0EoFFIwGFROTo4SEhJuORcXFxfBVblh+6Pp3bu3cebxxx83ztx///3GGUnau3evcaa1tdU4U1paapxpbGw0zkjSV199ZZwJh8PGGZshvZKUmJhonLEZlDpv3jzjTFFRkXFGsltfXV2dccZ2YPHp06eNM1H2cHzbWltbVVJSotra2pueD+D8bDoAACgjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCco4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgnPHUbj9F6/BTv4YZ2gy3lKQ+ffoYZz766CPjzJkzZ4wzkrR69WrjzJtvvmmcsRl6ajPcUrIbpmkzyb5bN7tf2f79+xtnvvnmG+PMrl27jDOjRo0yzkjS4MGDjTNbtmwxztx5553GGUn6/vvvrXKmYmW4Ks+MAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBABwjjICADhHGQEAnKOMAADOUUYAAOcoIwCAc5QRAMA5yggA4FxUT+2O1mm08fHmHW5zX86fP2+ckaS+ffsaZyZNmmScqampMc5I0qlTp4wz77//vnEmISHBONPS0mKckewmrNscEzbTwSXp3Llzxpk5c+YYZ7Kzs40zY8aMMc5IUnp6unGmd+/expnvvvvOOCNJe/bsMc7YTGWP1k83MMUzIwCAc5QRAMA5yggA4BxlBABwjjICADhHGQEAnKOMAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBAB
wLqoHpUbrAMBwOOzLdmwHxZ45c8Y4M27cOOPMqFGjjDOS3aDUsrIy40wgEDDOhEIh44xkN4CzsbHROPP0008bZyRp27ZtxpnXXnvNOPP2228bZ3744QfjjGQ3wNRmiLDN75Mk9ejRwzhj89hi8zhh+9gSycdknhkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHNRPSjVD7YDA/2QmJholRs+fLhx5v333zfOBINB44yt++67zzhjM4h04sSJxhnJbihrQkKCccZ2UOX9999vnLn33nuNM/X19caZ0tJS44wkVVdXG2fGjh1rnPniiy+MM1J0Dz2NxiHUPDMCADhHGQEAnDMuo927d+uRRx5RRkaG4uLitGnTpnbfnzdvnuLi4tpdpk+f3lHrBQDEIOMyamho0Pjx47Vy5crr3mb69OmqqKhou3zwwQe3tUgAQGwzPoFhxowZmjFjxg1vEwgElJaWZr0oAEDXEpHXjIqKijRw4ECNGjVKCxYsuOFZL01NTQqFQu0uAICupcPLaPr06Xrvvfe0Y8cO/elPf1JxcbFmzJih1tbWa96+sLBQwWCw7ZKZmdnRSwIARLkOf5/Rk08+2fb/d999t8aNG6fs7GwVFRVpypQpV91+yZIlWrx4cdvXoVCIQgKALibip3YPGzZM/fv31/Hjx6/5/UAgoOTk5HYXAEDXEvEyOn36tKqrq5Wenh7pTQEAOinjf6arr69v9yznxIkTOnz4sFJSUpSSkqJly5Zp9uzZSktLU1lZmV5++WUNHz5c06ZN69CFAwBih3EZHTx4UA8//HDb15df75k7d65WrVqlI0eO6G9/+5tqamqUkZGhqVOn6g9/+IPV7C4AQNdgXEaTJ0++4WC+f/3rX7e1IABA19Plp3b7Nb3WZrJut252P55Dhw4ZZ2wmff/kJz8xzkhSQUGBcSY7O9s4YzM12ZbNiTdDhgwxzpw9e9Y4I0kTJkwwzmzcuNE4M3LkSONMSkqKcUaSxo8fb5zZvHmzccZ2sr/N1HibSf1+Tn+P5KccMCgVAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyL8yI5+c5CKBRSMBhUTk6O0aBQP++GX8NVu3fv7st2JCk1NdU4U1FRYbWtlpYW40x5eblxJi0tzTgzZswY44wkffXVV8aZOXPmGGd69eplnJGkoUOHGmfuvPNO40xTU5NxZt26dcYZye6YKCsrM87YDCKV7B4n/Hoc8+sxrKWlRSUlJaqtrb3pMGGeGQEAnKOMAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBABwjjICADhHGQEAnKOMAADOUUYAAOcoIwCAc7c+idQBP4YG2gwM9GsAYn19vXFGkhITE40zjY2Nxpn4eLu/y9xzzz3GmZ49expnHn/8cePM9u3bjTOSNHPmTOPMxIkTjTNvvvmmcUaSVqxYYZz59NNPjTNHjx41ztgce5L0/fffG2f8HD5sMxDYhslA6ctsH1sj+ZjMMyMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcC6qB6VGq3A47Mt2evToYZWzWd8PP/xgnLEdmmgzrNIm8/bbbxtn7r33XuOMJPXr1884c+jQIeOMzUBWSSovLzfO2BwTx44dM84kJCQYZyS7QaStra3GGduBwDb3y+Z3yub33WbY8+3kbgXPjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcU7ujWGJiolXu4sWLxhmbCeEXLlwwzkhSWVmZcSYQCBhnnnnmGeOMLZt9PmbMGOOM7T7/+OOPjTO1tbVW2zJlM0nbVrdu5g95ttPpbXKRnIr932zvUyTxzAgA4BxlBABwzqiMCgsLNWHCBCUlJWngwIGaNWuWSktL293mwoULKigo0B133KE+ffpo9uzZqqqq6tBFAwBii1EZFRcXq6CgQPv379e2bdvU3NysqVOnqqGhoe02L774oj799FNt2LBBxcXFKi8v12OPPdbhCwcAxA6jV/O2bt3a7uu1a9dq4MCBKikp0aRJk1RbW6u//OUvWr9+vX76059KktasWaMxY8Zo//79uu+++676M5uamtTU1NT2dSgUsrkfAIBO7LZeM7p8tk1KSookqaSkRM3NzcrPz2+7zejRozV48GDt27fvmn9GYWGhgsFg2yUzM/N2lgQA6ISsyygcDmvRokV64IEHNHbsWElSZWWlEhMT1bdv33a3TU1NVWVl5TX/nCVLlqi2trbtcurUKdslAQA6Kev3GRUUFOjLL7/Unj17bmsBgUDA6j0kAIDYYfXMaOHChdqyZYt27dqlQYMGtV2flpamixcvqqampt3tq6qqlJaWdlsLBQDELqMy8jxPCxcu1MaNG7Vz505lZWW1+35OTo66d++uHTt2tF1XWlqqkydPKi8vr2NWDACIOUb/TFdQUKD169dr8+bNSkpKansdKBgMqmfPngoGg3r22We1ePFipaSkKDk5WS+88ILy8vKueSYdAACSYRmtWrVKkjR58uR2169Zs0bz5s2TJL399tuKj4/X7Nmz1dTUpGnTpundd9/tkMUCAGJTnBdlE/NCoZCCwaBycnKMhhpG2d24il8DECW7wZM2AzhzcnKMM5L085//3DhjM1z1oYceMs58++23xhlJ6tOnj3HG5uSfr7/+2jgjSdXV1VY5P/j5uxvNw0tt+XmfTLfV2tqqkpIS1dbWKjk5+Ya3ZTYdAMA5yggA4BxlBABwjjICADhHGQEAnKOMAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBABwjjICADhn/UmvfvBjgKLNwEC/BjvaDjOMjzf/O8aIESOMM7NmzTLOSNLevXuNM5c/2t7EzQYzXku/fv2MM5JUUlJinElKSjLOnD171jgjSYmJicaZ8+fPG2cSEhKMM+Fw2Dhjuy2bIcI225H8G7Dq52NYJO8Tz4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCco4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgXFRP7fZr6q2paJ70Ldmt7+TJk8aZ9957zzgjSc3NzcaZmpoa40z37t2NMxUVFcYZSfr888+NMy0tLcYZm4nsknTx4kXjTLdu5g8PNse57X2y4dd9smWzLZvf92h8bOWZEQDAOcoIAOAcZQQAcI4yAgA4RxkBAJyjjAAAzlFGAADnKCMAgHOUEQDAOcoIAOAcZQQAcI4yAgA4F9WDUk2GBtoO/ovmwYStra3GGclu8GQgEDDOnD171jgjSXfddZdxxuY+FRcXG2dsh2La7D/boaw2EhMTjTM2w1UTEhKMM7bHuV/bstmOn/wc5BpJPDMCADhHGQEAnKOMAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBA
BwjjICADhHGQEAnKOMAADOUUYAAOeielCqCT+HBfq1LdsBjTbrC4fDxhnbAZclJSXGmT59+hhn6uvrjTO2bH5Wfg4V9WtAqM2x162b3cOQzTFrM3DXTzYDlf0a9my7rVsV3T8ZAECXQBkBAJwzKqPCwkJNmDBBSUlJGjhwoGbNmqXS0tJ2t5k8ebLi4uLaXZ5//vkOXTQAILYYlVFxcbEKCgq0f/9+bdu2Tc3NzZo6daoaGhra3W7+/PmqqKhouyxfvrxDFw0AiC1Grxxu3bq13ddr167VwIEDVVJSokmTJrVd36tXL6WlpXXMCgEAMe+2XjOqra2VJKWkpLS7ft26derfv7/Gjh2rJUuWqLGx8bp/RlNTk0KhULsLAKBrsT61OxwOa9GiRXrggQc0duzYtuuffvppDRkyRBkZGTpy5IheeeUVlZaW6pNPPrnmn1NYWKhly5bZLgMAEAPiPMsTxxcsWKB//vOf2rNnjwYNGnTd2+3cuVNTpkzR8ePHlZ2dfdX3m5qa1NTU1PZ1KBRSZmamcnJyrN9nEyv8fC+AzfsvbN9zYPOel1h8n5EN2/cZ+bU+v449ye59RjZsfw/9Es3vM2ptbVVJSYlqa2uVnJx8w9taPTNauHChtmzZot27d9+wiCQpNzdXkq5bRoFAQIFAwGYZAIAYYVRGnufphRde0MaNG1VUVKSsrKybZg4fPixJSk9Pt1ogACD2GZVRQUGB1q9fr82bNyspKUmVlZWSpGAwqJ49e6qsrEzr16/Xz372M91xxx06cuSIXnzxRU2aNEnjxo2LyB0AAHR+RmW0atUqSZfe2Prf1qxZo3nz5ikxMVHbt2/XihUr1NDQoMzMTM2ePVuvvvpqhy0YABB7jP+Z7kYyMzNVXFx8WwuyFY2D//5bLL4IastmMOb58+eNM34N+pTU7iScW2XzWqmfQ0VtjlmbjF8nIkix+XsY7ffpVjGbDgDgHGUEAHCOMgIAOEcZAQCco4wAAM5RRgAA5ygjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCco4wAAM7ZjQD2QTgcNppGa/vRxX7xcxpvtE/x9XNCuCnbfefXpxXb7rtoPiaieW23w6/feT9/nyL5s4ruR3AAQJdAGQEAnKOMAADOUUYAAOcoIwCAc5QRAMA5yggA4BxlBABwjjICADhHGQEAnKOMAADORd1sustzllpbW61yAIDocPlx/FYen6OujOrq6iRJR44ccbwSAEBHqKurUzAYvOFt4rwoe0oRDodVXl6upKSkqybEhkIhZWZm6tSpU0pOTna0QvfYD1ewLy5hP1zBvrgkGvaD53mqq6tTRkbGTT9ZIeqeGcXHx2vQoEE3vE1ycnKXPsguYz9cwb64hP1wBfviEtf74WbPiC7jBAYAgHOUEQDAuU5VRoFAQEuXLvXtUzWjFfvhCvbFJeyHK9gXl3S2/RB1JzAAALqeTvXMCAAQmygjAIBzlBEAwDnKCADgHGUEAHCOMgIAOEcZAQCco4wAAM79P7V+LCjrsZ6HAAAAAElFTkSuQmCC" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 81 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 0995bf7..b0165ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "torchflows" -version = "1.0.4" +version = "1.0.5" authors = [ { name = "David Nabergoj", email = "david.nabergoj@fri.uni-lj.si" }, ] diff --git a/test/constants.py b/test/constants.py index fbe27a5..32064be 100644 --- a/test/constants.py +++ b/test/constants.py @@ -1,7 +1,7 @@ __test_constants = { 'batch_shape': [(1,), (2,), (5,), (5, 2, 3)], 'event_shape': [(2,), (3,), (3, 5, 2)], - 'image_shape': [(4, 4, 3), (20, 20, 3), (10, 20, 3), (20, 20, 1), (10, 20, 1)], + 'image_shape': [(3, 4, 4), (3, 20, 20), (3, 10, 20), (1, 20, 20), (1, 10, 20)], 'context_shape': [None, (2,), (3,), (3, 5, 2)], 'input_event_shape': [(2,), (3,), (3, 5, 2)], 'output_event_shape': [(2,), (3,), (3, 5, 2)], diff --git a/test/test_autograd_bijections.py b/test/test_autograd_bijections.py index e609a6e..6b4069f 100644 --- a/test/test_autograd_bijections.py +++ b/test/test_autograd_bijections.py @@ -10,8 +10,10 @@ from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ LRSCoupling, LinearRQSCoupling from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR +from torchflows.bijections.finite.residual.architectures import InvertibleResNet, ResFlow, ProximalResFlow from 
torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock from torchflows.bijections.finite.residual.planar import Planar +from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock from torchflows.bijections.finite.residual.radial import Radial from torchflows.bijections.finite.residual.sylvester import Sylvester from torchflows.utils import get_batch_shape @@ -93,13 +95,10 @@ def test_masked_autoregressive(bijection_class: Bijection, batch_shape: Tuple, e assert_valid_log_probability_gradient(bijection, x, context) -@pytest.mark.skip(reason="Computation takes too long") @pytest.mark.parametrize('bijection_class', [ - InvertibleResNetBlock, - ResFlowBlock, - Planar, - Radial, - Sylvester + InvertibleResNet, + ResFlow, + ProximalResFlow, ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) diff --git a/test/test_conditioner_transforms.py b/test/test_conditioner_transforms.py index 4d3e677..e6255a6 100644 --- a/test/test_conditioner_transforms.py +++ b/test/test_conditioner_transforms.py @@ -26,7 +26,7 @@ def test_autoregressive(transform_class, x = torch.randn(size=(*batch_shape, *input_event_shape)) transform: ConditionerTransform = transform_class( input_event_shape=input_event_shape, - output_event_shape=output_event_shape, + transformed_event_shape=output_event_shape, parameter_shape_per_element=parameter_shape_per_element, context_shape=context_shape, ) diff --git a/test/test_factored_bijection.py b/test/test_factored_bijection.py deleted file mode 100644 index b4d42ac..0000000 --- a/test/test_factored_bijection.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -from torchflows.bijections.finite.multiscale.base import FactoredBijection -from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine -from test.constants import __test_constants - - -def test_basic(): - torch.manual_seed(0) - - bijection = FactoredBijection( - event_shape=(6, 6), - small_bijection_event_shape=(3, 3), - small_bijection_mask=torch.tensor([ - [True, True, True, False, False, False], - [True, True, True, False, False, False], - [True, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ]), - small_bijection=ElementwiseAffine(event_shape=(3, 3)) - ) - - x = torch.randn(100, *bijection.event_shape) - z, log_det_forward = bijection.forward(x) - - assert torch.allclose( - x[..., ~bijection.transformed_event_mask], - z[..., ~bijection.transformed_event_mask], - __test_constants['data_atol_easy'] - ) - - assert ~torch.allclose( - x[..., bijection.transformed_event_mask], - z[..., bijection.transformed_event_mask] - ) - - xr, log_det_inverse = bijection.inverse(z) - assert torch.allclose(x, xr, __test_constants['data_atol_easy']) - assert torch.allclose(log_det_forward, -log_det_inverse, __test_constants['log_det_atol_easy']) diff --git a/test/test_globally_learned_conditioner_outputs.py b/test/test_globally_learned_conditioner_outputs.py new file mode 100644 index 0000000..0a2dd0f --- /dev/null +++ b/test/test_globally_learned_conditioner_outputs.py @@ -0,0 +1,29 @@ +import torch + +from torchflows.bijections.finite.autoregressive.conditioning.transforms import FeedForward + + +def test_standard(): + torch.manual_seed(0) + + input_event_shape = torch.Size((10, 10)) + parameter_shape = torch.Size((20, 3)) + test_inputs = 
torch.randn(100, *input_event_shape) + + t = FeedForward(input_event_shape, parameter_shape) + output = t(test_inputs) + + assert output.shape == (100, *parameter_shape) + + +def test_eighty_pct_global(): + torch.manual_seed(0) + + input_event_shape = torch.Size((10, 10)) + parameter_shape = torch.Size((20, 3)) + test_inputs = torch.randn(100, *input_event_shape) + + t = FeedForward(input_event_shape, parameter_shape, percentage_global_parameters=0.8) + output = t(test_inputs) + + assert output.shape == (100, *parameter_shape) diff --git a/test/test_identity_bijections.py b/test/test_identity_bijections.py index 6862d8b..673356b 100644 --- a/test/test_identity_bijections.py +++ b/test/test_identity_bijections.py @@ -2,7 +2,7 @@ from torchflows.bijections.finite.autoregressive.layers import ( AffineCoupling, - DSCoupling, + DeepSigmoidalCoupling, RQSCoupling, InverseAffineCoupling, LRSCoupling, @@ -32,7 +32,7 @@ 'layer_class', [ AffineCoupling, - DSCoupling, + DeepSigmoidalCoupling, RQSCoupling, InverseAffineCoupling, LRSCoupling, diff --git a/test/test_invertible_convolution.py b/test/test_invertible_convolution.py new file mode 100644 index 0000000..037f9b7 --- /dev/null +++ b/test/test_invertible_convolution.py @@ -0,0 +1,13 @@ +import torch +from torchflows.bijections.finite.multiscale.base import Invertible1x1ConvolutionalCoupling + +def test_basic(): + torch.manual_seed(0) + event_shape = 3, 20, 20 + x = torch.randn(size=(4, *event_shape)) + layer = Invertible1x1ConvolutionalCoupling(event_shape) + z, log_det = layer.forward(x) + xr, log_det_inv = layer.inverse(z) + + assert torch.allclose(x, xr) + assert torch.allclose(log_det_inv, -log_det) \ No newline at end of file diff --git a/test/test_layer_gradients.py b/test/test_layer_gradients.py new file mode 100644 index 0000000..5b916ef --- /dev/null +++ b/test/test_layer_gradients.py @@ -0,0 +1,96 @@ +import torch +from torch import nn + +from torchflows import Flow +from torchflows.bijections.base import BijectiveComposition +from torchflows.bijections.finite.autoregressive.layers import AffineCoupling, ElementwiseAffine +from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine +from torchflows.bijections.finite.multiscale.architectures import MultiscaleNICE +from torchflows.bijections.finite.multiscale.base import CheckerboardCoupling, NormalizedCheckerboardCoupling + + +def test_elementwise_affine(): + torch.manual_seed(0) + event_shape = torch.Size((3, 20, 20)) + x = torch.randn(size=(4, *event_shape)) + flow = Flow(ElementwiseAffine(event_shape)) + log_prob = flow.log_prob(x) + loss = -log_prob.mean() + loss.backward() + assert loss.grad_fn is not None + + +def test_affine_coupling(): + torch.manual_seed(0) + event_shape = torch.Size((3, 20, 20)) + x = torch.randn(size=(4, *event_shape)) + flow = Flow(AffineCoupling(event_shape)) + log_prob = flow.log_prob(x) + loss = -log_prob.mean() + loss.backward() + assert loss.grad_fn is not None + + +def test_checkerboard(): + torch.manual_seed(0) + event_shape = torch.Size((3, 20, 20)) + x = torch.randn(size=(4, *event_shape)) + flow = Flow(CheckerboardCoupling(event_shape, Affine)) + log_prob = flow.log_prob(x) + loss = -log_prob.mean() + loss.backward() + assert loss.grad_fn is not None + + +def test_normalized_checkerboard(): + torch.manual_seed(0) + event_shape = torch.Size((3, 20, 20)) + x = torch.randn(size=(4, *event_shape)) + flow = Flow(NormalizedCheckerboardCoupling(event_shape, transformer_class=Affine)) + log_prob = flow.log_prob(x) + loss = 
-log_prob.mean() + loss.backward() + assert loss.grad_fn is not None + + +def test_checkerboard_composition(): + torch.manual_seed(0) + event_shape = torch.Size((3, 20, 20)) + x = torch.randn(size=(4, *event_shape)) + flow = Flow(BijectiveComposition(event_shape, [ + NormalizedCheckerboardCoupling( + event_shape, + transformer_class=Affine, + alternate=i % 2 == 1, + conditioner='convnet' + ) + for i in range(4) + ])) + log_prob = flow.log_prob(x) + loss = -log_prob.mean() + loss.backward() + assert loss.grad_fn is not None + + +def test_multiscale_nice_small(): + torch.manual_seed(0) + event_shape = torch.Size((3, 20, 20)) + x = torch.randn(size=(4, *event_shape)) + flow = Flow(MultiscaleNICE(event_shape, n_layers=1)) + assert isinstance(flow.bijection.checkerboard_layers, nn.ModuleList) + log_prob = flow.log_prob(x) + loss = -log_prob.mean() + loss.backward() + assert loss.grad_fn is not None + + +def test_multiscale_nice(): + torch.manual_seed(0) + event_shape = torch.Size((3, 20, 20)) + x = torch.randn(size=(4, *event_shape)) + flow = Flow(MultiscaleNICE(event_shape, n_layers=2)) + assert isinstance(flow.bijection.checkerboard_layers, nn.ModuleList) + log_prob = flow.log_prob(x) + loss = -log_prob.mean() + loss.backward() + assert loss.grad_fn is not None diff --git a/test/test_multiscale_bijections.py b/test/test_multiscale_bijections.py index b1504a7..cfd4493 100644 --- a/test/test_multiscale_bijections.py +++ b/test/test_multiscale_bijections.py @@ -16,10 +16,10 @@ MultiscaleRealNVP ]) @pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) -def test_non_factored(architecture_class, image_shape): +def test_basic(architecture_class, image_shape): torch.manual_seed(0) x = torch.randn(size=(5, *image_shape)) - bijection = architecture_class(image_shape, n_layers=2, factored=False) + bijection = architecture_class(image_shape, n_layers=2) z, ldf = bijection.forward(x) xr, ldi = bijection.inverse(z) assert torch.allclose(x, xr, atol=__test_constants['data_atol']) @@ -33,78 +33,16 @@ def test_non_factored(architecture_class, image_shape): MultiscaleRealNVP ]) @pytest.mark.parametrize('image_shape', [(1, 28, 28), (3, 28, 28)]) -def test_non_factored_too_small_image(architecture_class, image_shape): +def test_too_small_image(architecture_class, image_shape): torch.manual_seed(0) with pytest.raises(ValueError): - bijection = architecture_class(image_shape, n_layers=3, factored=False) + bijection = architecture_class(image_shape, n_layers=3) -@pytest.mark.parametrize('architecture_class', [ - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleNICE, - MultiscaleRealNVP -]) -@pytest.mark.parametrize('image_shape', [(1, 32, 32), (3, 32, 32)]) -def test_factored(architecture_class, image_shape): - torch.manual_seed(0) - x = torch.randn(size=(5, *image_shape)) - bijection = architecture_class(image_shape, n_layers=2, factored=True) - z, ldf = bijection.forward(x) - xr, ldi = bijection.inverse(z) - assert torch.allclose(x, xr, atol=__test_constants['data_atol']) - assert torch.allclose(ldf, -ldi, atol=__test_constants['log_det_atol_easy']) # 1e-2 -@pytest.mark.parametrize('architecture_class', [ - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleNICE, - MultiscaleRealNVP -]) -@pytest.mark.parametrize('image_shape', [(1, 15, 32), (3, 15, 32)]) -def test_factored_wrong_shape(architecture_class, image_shape): - torch.manual_seed(0) - x = torch.randn(size=(5, *image_shape)) - with pytest.raises(ValueError): - bijection = architecture_class(image_shape, n_layers=2, factored=True) - 
-@pytest.mark.parametrize('architecture_class', [ - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleNICE, - MultiscaleRealNVP -]) -@pytest.mark.parametrize('image_shape', [(1, 8, 8), (3, 8, 8)]) -def test_factored_too_small_image(architecture_class, image_shape): - torch.manual_seed(0) - x = torch.randn(size=(5, *image_shape)) - with pytest.raises(ValueError): - bijection = architecture_class(image_shape, n_layers=8, factored=True) -@pytest.mark.parametrize('architecture_class', [ - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleNICE, - MultiscaleRealNVP -]) -@pytest.mark.parametrize('image_shape', [(1, 4, 4), (3, 4, 4)]) -def test_non_factored_automatic_n_layers(architecture_class, image_shape): - torch.manual_seed(0) - x = torch.randn(size=(5, *image_shape)) - bijection = architecture_class(image_shape, factored=False) -@pytest.mark.parametrize('architecture_class', [ - MultiscaleRQNSF, - MultiscaleLRSNSF, - MultiscaleNICE, - MultiscaleRealNVP -]) -@pytest.mark.parametrize('image_shape', [(1, 4, 8), (3, 4, 4)]) -def test_factored_automatic_n_layers(architecture_class, image_shape): - torch.manual_seed(0) - x = torch.randn(size=(5, *image_shape)) - bijection = architecture_class(image_shape, factored=True) diff --git a/test/test_reconstruction_bijections.py b/test/test_reconstruction_bijections.py index dbb2b8f..2aaaffd 100644 --- a/test/test_reconstruction_bijections.py +++ b/test/test_reconstruction_bijections.py @@ -3,7 +3,6 @@ import pytest import torch - from torchflows.bijections.continuous.base import ContinuousBijection, ExactODEFunction from torchflows.bijections.base import Bijection from torchflows.bijections.continuous.ffjord import FFJORD @@ -12,7 +11,7 @@ from torchflows.bijections.finite.autoregressive.architectures import NICE, RealNVP, CouplingRQNSF, MAF, IAF, \ InverseAutoregressiveRQNSF, MaskedAutoregressiveRQNSF from torchflows.bijections.finite.autoregressive.layers import ElementwiseScale, ElementwiseAffine, ElementwiseShift, \ - LRSCoupling, LinearRQSCoupling + LRSCoupling, LinearRQSCoupling, ActNorm, DenseSigmoidalCoupling, DeepDenseSigmoidalCoupling, DeepSigmoidalCoupling from torchflows.bijections.finite.linear import LU, ReversePermutation, LowerTriangular, Orthogonal, QR from torchflows.bijections.finite.residual.architectures import ResFlow, InvertibleResNet, ProximalResFlow from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock @@ -32,6 +31,9 @@ def setup_data(bijection_class, batch_shape, event_shape, context_shape): else: context = None bijection = bijection_class(event_shape) + if isinstance(bijection, (FFJORD, RNODE, OTFlow)): + # "Fix" bijection object + bijection = bijection_class(event_shape, solver='dopri5') # use dopri5 for accurate reconstructions return bijection, x, context @@ -127,7 +129,8 @@ def assert_valid_reconstruction_continuous(bijection: ContinuousBijection, Orthogonal, QR, ElementwiseAffine, - ElementwiseShift + ElementwiseShift, + ActNorm ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) @@ -142,7 +145,10 @@ def test_linear(bijection_class: Bijection, batch_shape: Tuple, event_shape: Tup RealNVP, CouplingRQNSF, LRSCoupling, - LinearRQSCoupling + LinearRQSCoupling, + DenseSigmoidalCoupling, + DeepDenseSigmoidalCoupling, + DeepSigmoidalCoupling, ]) @pytest.mark.parametrize('batch_shape', __test_constants['batch_shape']) @pytest.mark.parametrize('event_shape', __test_constants['event_shape']) diff 
--git a/test/test_reconstruction_transformers.py b/test/test_reconstruction_transformers.py index d7e54ec..7ca4c54 100644 --- a/test/test_reconstruction_transformers.py +++ b/test/test_reconstruction_transformers.py @@ -9,7 +9,7 @@ LinearRational as LinearRationalSpline from torchflows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import \ RationalQuadratic as RationalQuadraticSpline -from torchflows.bijections.finite.autoregressive.transformers.linear.convolution import Invertible1x1Convolution +from torchflows.bijections.finite.autoregressive.transformers.linear.convolution import Invertible1x1ConvolutionTransformer from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Scale, Shift from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import Sigmoid, DeepSigmoid, \ DenseSigmoid, DeepDenseSigmoid @@ -118,12 +118,10 @@ def test_combination_vector_to_vector(transformer_class: ScalarTransformer, batc @pytest.mark.parametrize('image_shape', __test_constants['image_shape']) def test_convolution(batch_size: int, image_shape: Tuple): torch.manual_seed(0) - transformer = Invertible1x1Convolution(image_shape) - - *image_dimensions, n_channels = image_shape + transformer = Invertible1x1ConvolutionTransformer(image_shape) images = torch.randn(size=(batch_size, *image_shape)) - parameters = torch.randn(size=(batch_size, *image_dimensions, *transformer.parameter_shape)) + parameters = torch.randn(size=(batch_size, *transformer.parameter_shape)) latent_images, log_det_forward = transformer.forward(images, parameters) reconstructed_images, log_det_inverse = transformer.inverse(latent_images, parameters) diff --git a/test/test_sampling.py b/test/test_sampling.py new file mode 100644 index 0000000..2d78772 --- /dev/null +++ b/test/test_sampling.py @@ -0,0 +1,11 @@ +import pytest + +from torchflows import Flow +from torchflows.architectures import PlanarFlow, SylvesterFlow, RadialFlow + + +@pytest.mark.parametrize('arch_cls', [PlanarFlow, SylvesterFlow, RadialFlow]) +def test_basic(arch_cls): + event_shape = (1, 2, 3, 4) + f = Flow(arch_cls(event_shape=event_shape)) + assert f.sample((10,)).shape == (10, *event_shape) diff --git a/test/test_sigmoid_transformer.py b/test/test_sigmoid_transformer.py index d7e7826..3309be0 100644 --- a/test/test_sigmoid_transformer.py +++ b/test/test_sigmoid_transformer.py @@ -2,8 +2,8 @@ import torch from torchflows import Flow -from torchflows.bijections.finite.autoregressive.architectures import CouplingDSF -from torchflows.bijections.finite.autoregressive.layers import DSCoupling +from torchflows.bijections.finite.autoregressive.architectures import CouplingDeepSF +from torchflows.bijections.finite.autoregressive.layers import DeepSigmoidalCoupling from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import Sigmoid, DeepSigmoid from torchflows.bijections.base import invert from test.constants import __test_constants @@ -77,8 +77,8 @@ def test_deep_sigmoid_transformer(event_shape, batch_shape, hidden_dim): def test_deep_sigmoid_coupling(event_shape, batch_shape): torch.manual_seed(0) - forward_layer = DSCoupling(torch.Size(event_shape)) - inverse_layer = invert(DSCoupling(torch.Size(event_shape))) + forward_layer = DeepSigmoidalCoupling(torch.Size(event_shape)) + inverse_layer = invert(DeepSigmoidalCoupling(torch.Size(event_shape))) x = torch.randn(size=(*batch_shape, *event_shape)) # Reduce magnitude for stability y, log_det_forward = 
forward_layer.forward(x) @@ -109,7 +109,7 @@ def test_deep_sigmoid_coupling_flow(event_shape, batch_shape): n_dim = int(torch.prod(torch.tensor(event_shape))) event_shape = (n_dim,) # Overwrite - forward_flow = Flow(CouplingDSF(event_shape)) + forward_flow = Flow(CouplingDeepSF(event_shape)) x = torch.randn(size=(*batch_shape, n_dim)) log_prob = forward_flow.log_prob(x) @@ -117,7 +117,7 @@ def test_deep_sigmoid_coupling_flow(event_shape, batch_shape): assert torch.all(~torch.isnan(log_prob)) assert torch.all(~torch.isinf(log_prob)) - inverse_flow = Flow(invert(CouplingDSF(event_shape))) + inverse_flow = Flow(invert(CouplingDeepSF(event_shape))) x_new = inverse_flow.sample(len(x)) assert x_new.shape == (len(x), *inverse_flow.bijection.event_shape) diff --git a/torchflows/architectures.py b/torchflows/architectures.py index 265c842..b8710d8 100644 --- a/torchflows/architectures.py +++ b/torchflows/architectures.py @@ -9,7 +9,16 @@ InverseAutoregressiveRQNSF, CouplingLRS, MaskedAutoregressiveLRS, - CouplingDSF, + InverseAutoregressiveLRS, + CouplingDeepSF, + MaskedAutoregressiveDeepSF, + InverseAutoregressiveDeepSF, + CouplingDenseSF, + MaskedAutoregressiveDenseSF, + InverseAutoregressiveDenseSF, + CouplingDeepDenseSF, + MaskedAutoregressiveDeepDenseSF, + InverseAutoregressiveDeepDenseSF, UMNNMAF ) @@ -32,7 +41,8 @@ MultiscaleRQNSF, MultiscaleLRSNSF, MultiscaleNICE, - # MultiscaleDeepSigmoid, # TODO stabler initialization - # MultiscaleDenseSigmoid, # TODO stabler initialization - # MultiscaleDeepDenseSigmoid # TODO stabler initialization + MultiscaleDeepSigmoid, + MultiscaleDenseSigmoid, + MultiscaleDeepDenseSigmoid, + AffineGlow ) diff --git a/torchflows/base_distributions/mixture.py b/torchflows/base_distributions/mixture.py index 17f3a0f..a6e7217 100644 --- a/torchflows/base_distributions/mixture.py +++ b/torchflows/base_distributions/mixture.py @@ -38,8 +38,8 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor: # We are assuming all components are normalized value = value.to(self.log_weights) batch_shape = get_batch_shape(value, self.event_shape) - log_probs = torch.zeros(*batch_shape, self.n_components).to(self.log_weights) - for i in range(self.n_components): + log_probs = torch.zeros(*batch_shape, len(self.components)).to(self.log_weights) + for i in range(len(self.components)): log_probs[..., i] = self.components[i].log_prob(value) sample_shape_mask = [None for _ in range(len(value.shape) - len(self.event_shape))] return torch.logsumexp(self.log_weights[sample_shape_mask] + log_probs, dim=-1) diff --git a/torchflows/bijections/base.py b/torchflows/bijections/base.py index 23012fb..57c0ca6 100644 --- a/torchflows/bijections/base.py +++ b/torchflows/bijections/base.py @@ -11,6 +11,7 @@ class Bijection(nn.Module): """Bijection class. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], context_shape: Union[torch.Size, Tuple[int, ...]] = None, @@ -25,7 +26,6 @@ def __init__(self, self.event_shape = event_shape self.n_dim = int(torch.prod(torch.as_tensor(event_shape))) self.context_shape = context_shape - self.transformed_shape = self.event_shape # Overwritten in multiscale flows TODO make into property def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: """Forward bijection map. @@ -51,9 +51,17 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
""" raise NotImplementedError - @staticmethod - def batch_apply(fn, batch_size, *args): - dataset = TensorDataset(*args) + def batch_apply(self, fn, batch_size, x, context=None): + batch_shape = x.shape[:-len(self.event_shape)] + + if context is None: + x_flat = torch.flatten(x, start_dim=0, end_dim=len(batch_shape) - 1) + dataset = TensorDataset(x_flat) + else: + x_flat = torch.flatten(x, start_dim=0, end_dim=len(batch_shape) - 1) + context_flat = torch.flatten(context, start_dim=0, end_dim=len(batch_shape) - 1) + dataset = TensorDataset(x_flat, context_flat) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False) outputs = [] log_dets = [] @@ -61,8 +69,8 @@ def batch_apply(fn, batch_size, *args): batch_out, batch_log_det = fn(*batch) outputs.append(batch_out) log_dets.append(batch_log_det) - outputs = torch.cat(outputs, dim=0) - log_dets = torch.cat(log_dets, dim=0) + outputs = torch.cat(outputs, dim=0).view_as(x) + log_dets = torch.cat(log_dets, dim=0).view(*batch_shape) return outputs, log_dets def batch_forward(self, x: torch.Tensor, batch_size: int, context: torch.Tensor = None): @@ -90,6 +98,7 @@ def batch_inverse(self, x: torch.Tensor, batch_size: int, context: torch.Tensor def regularization(self): return 0.0 + def invert(bijection: Bijection) -> Bijection: """Swap the forward and inverse methods of the input bijection. @@ -105,6 +114,7 @@ class BijectiveComposition(Bijection): """ Composition of bijections. Inherits from Bijection. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], layers: List[Bijection], diff --git a/torchflows/bijections/continuous/base.py b/torchflows/bijections/continuous/base.py index 365606e..212873c 100644 --- a/torchflows/bijections/continuous/base.py +++ b/torchflows/bijections/continuous/base.py @@ -1,3 +1,4 @@ +import math from typing import Union, Tuple, List, Optional, Dict import torch @@ -68,7 +69,10 @@ def create_nn_time_independent(event_size: int, hidden_size: int = 30, n_hidden_ return TimeDerivativeDNN(layers) -def create_nn(event_size: int, hidden_size: int = 30, n_hidden_layers: int = 2): +def create_nn(event_size: int, hidden_size: int = None, n_hidden_layers: int = 2): + if hidden_size is None: + hidden_size = max(4, int(3 * math.log(event_size))) + assert n_hidden_layers >= 0 if n_hidden_layers == 0: layers = [diff_eq_layers.ConcatLinear(event_size, event_size)] @@ -268,12 +272,13 @@ class ContinuousBijection(Bijection): Reference: Chen et al. "Neural Ordinary Differential Equations" (2019); https://arxiv.org/abs/1806.07366. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], f: ODEFunction, context_shape: Union[torch.Size, Tuple[int, ...]] = None, end_time: float = 1.0, - solver: str = 'dopri5', + solver: str = 'euler', # Use euler (fastest solver) atol: float = 1e-5, rtol: float = 1e-5, **kwargs): diff --git a/torchflows/bijections/continuous/ddnf.py b/torchflows/bijections/continuous/ddnf.py index 5f253e5..4cbff88 100644 --- a/torchflows/bijections/continuous/ddnf.py +++ b/torchflows/bijections/continuous/ddnf.py @@ -17,7 +17,7 @@ class DeepDiffeomorphicBijection(ApproximateContinuousBijection): Reference: Salman et al. "Deep diffeomorphic normalizing flows" (2018); https://arxiv.org/abs/1810.03256. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int = 150, **kwargs): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int = 150, solver="euler", **kwargs): """ Constructor. 
@@ -27,4 +27,4 @@ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int n_dim = int(torch.prod(torch.as_tensor(event_shape))) diff_eq = RegularizedApproximateODEFunction(create_nn_time_independent(n_dim)) self.n_steps = n_steps - super().__init__(event_shape, diff_eq, solver="euler", **kwargs) # USE DOPRI5 for stability + super().__init__(event_shape, diff_eq, solver=solver, **kwargs) diff --git a/torchflows/bijections/continuous/otflow.py b/torchflows/bijections/continuous/otflow.py index 8571458..fed00f5 100644 --- a/torchflows/bijections/continuous/otflow.py +++ b/torchflows/bijections/continuous/otflow.py @@ -146,7 +146,7 @@ def __init__(self, event_size: int, hidden_size: int = None, **kwargs): # hidden_size = m if hidden_size is None: - hidden_size = max(int(math.log(event_size)), 4) + hidden_size = max(3 * int(math.log(event_size)), 4) r = min(10, event_size) @@ -207,7 +207,8 @@ class OTFlow(ExactContinuousBijection): Reference: Onken et al. "OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport" (2021); https://arxiv.org/abs/2006.00104. """ - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], solver='dopri8', **kwargs): n_dim = int(torch.prod(torch.as_tensor(event_shape))) - diff_eq = OTFlowODEFunction(n_dim) - super().__init__(event_shape, diff_eq, **kwargs) + diff_eq = OTFlowODEFunction(n_dim, hidden_size=50) + super().__init__(event_shape, diff_eq, solver=solver, **kwargs) diff --git a/torchflows/bijections/continuous/rnode.py b/torchflows/bijections/continuous/rnode.py index 66d5aa5..3ba06ba 100644 --- a/torchflows/bijections/continuous/rnode.py +++ b/torchflows/bijections/continuous/rnode.py @@ -14,5 +14,5 @@ class RNODE(ApproximateContinuousBijection): """ def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): n_dim = int(torch.prod(torch.as_tensor(event_shape))) - diff_eq = RegularizedApproximateODEFunction(create_nn(n_dim), regularization="sq_jac_norm") + diff_eq = RegularizedApproximateODEFunction(create_nn(n_dim, hidden_size=100, n_hidden_layers=1), regularization="sq_jac_norm") super().__init__(event_shape, diff_eq, **kwargs) diff --git a/torchflows/bijections/finite/autoregressive/architectures.py b/torchflows/bijections/finite/autoregressive/architectures.py index 3cc3ce9..1aba1e7 100644 --- a/torchflows/bijections/finite/autoregressive/architectures.py +++ b/torchflows/bijections/finite/autoregressive/architectures.py @@ -9,11 +9,16 @@ RQSForwardMaskedAutoregressive, RQSInverseMaskedAutoregressive, InverseAffineCoupling, - DSCoupling, + DeepSigmoidalCoupling, ElementwiseAffine, UMNNMaskedAutoregressive, LRSCoupling, - LRSForwardMaskedAutoregressive + LRSForwardMaskedAutoregressive, + LRSInverseMaskedAutoregressive, + DenseSigmoidalCoupling, + DeepDenseSigmoidalCoupling, DeepSigmoidalInverseMaskedAutoregressive, DeepSigmoidalForwardMaskedAutoregressive, + DenseSigmoidalInverseMaskedAutoregressive, DenseSigmoidalForwardMaskedAutoregressive, + DeepDenseSigmoidalInverseMaskedAutoregressive, DeepDenseSigmoidalForwardMaskedAutoregressive, ActNorm ) from torchflows.bijections.base import BijectiveComposition from torchflows.bijections.finite.autoregressive.layers_base import CouplingBijection, \ @@ -25,7 +30,8 @@ def make_basic_layers(base_bijection: Type[ Union[CouplingBijection, MaskedAutoregressiveBijection, InverseMaskedAutoregressiveBijection]], event_shape, n_layers: int = 2, - edge_list: 
List[Tuple[int, int]] = None): + edge_list: List[Tuple[int, int]] = None, + **kwargs): """ Returns a list of bijections for transformations of vectors. """ @@ -33,8 +39,10 @@ for _ in range(n_layers): if edge_list is None: bijections.append(ReversePermutation(event_shape=event_shape)) - bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list)) + bijections.append(base_bijection(event_shape=event_shape, edge_list=edge_list, **kwargs)) + bijections.append(ActNorm(event_shape=event_shape)) bijections.append(ElementwiseAffine(event_shape=event_shape)) + bijections.append(ActNorm(event_shape=event_shape)) return bijections @@ -43,6 +51,7 @@ class NICE(BijectiveComposition): """ Reference: Dinh et al. "NICE: Non-linear Independent Components Estimation" (2015); https://arxiv.org/abs/1410.8516. """ + def __init__(self, event_shape, n_layers: int = 2, @@ -59,6 +68,7 @@ class RealNVP(BijectiveComposition): """ Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ + def __init__(self, event_shape, n_layers: int = 2, @@ -75,6 +85,7 @@ class InverseRealNVP(BijectiveComposition): """ Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ + def __init__(self, event_shape, n_layers: int = 2, @@ -117,6 +128,7 @@ class CouplingRQNSF(BijectiveComposition): """ Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. """ + def __init__(self, event_shape, n_layers: int = 2, @@ -146,6 +158,7 @@ class CouplingLRS(BijectiveComposition): """ Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ + def __init__(self, event_shape, n_layers: int = 2, @@ -162,6 +175,7 @@ class MaskedAutoregressiveLRS(BijectiveComposition): """ Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. """ + def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) @@ -174,6 +188,7 @@ class InverseAutoregressiveRQNSF(BijectiveComposition): """ Reference: Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. """ + def __init__(self, event_shape, n_layers: int = 2, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) @@ -181,11 +196,59 @@ def __init__(self, event_shape, n_layers: int = 2, **kwargs): super().__init__(event_shape, bijections, **kwargs) -class CouplingDSF(BijectiveComposition): - """Coupling deep sigmoidal flow (C-DSF) architecture. +class InverseAutoregressiveLRS(BijectiveComposition): + """Inverse autoregressive linear rational spline (IA-LRS) architecture. + + Reference: Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. + """ + + def __init__(self, event_shape, n_layers: int = 2, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers(LRSInverseMaskedAutoregressive, event_shape, n_layers) + super().__init__(event_shape, bijections, **kwargs) + + +class CouplingDeepSF(BijectiveComposition): + """Coupling deep sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. 
+ """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers(DeepSigmoidalCoupling, event_shape, n_layers, edge_list) + super().__init__(event_shape, bijections, **kwargs) + + +class InverseAutoregressiveDeepSF(BijectiveComposition): + """Inverse autoregressive deep sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers(DeepSigmoidalInverseMaskedAutoregressive, event_shape, n_layers, edge_list) + super().__init__(event_shape, bijections, **kwargs) + + +class MaskedAutoregressiveDeepSF(BijectiveComposition): + """Masked autoregressive deep sigmoidal flow architecture. Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. """ + def __init__(self, event_shape, n_layers: int = 2, @@ -193,7 +256,151 @@ def __init__(self, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_basic_layers(DSCoupling, event_shape, n_layers, edge_list) + bijections = make_basic_layers(DeepSigmoidalForwardMaskedAutoregressive, event_shape, n_layers, edge_list) + super().__init__(event_shape, bijections, **kwargs) + + +class CouplingDenseSF(BijectiveComposition): + """Coupling dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers( + DenseSigmoidalCoupling, + event_shape, + n_layers, + edge_list, + percentage_global_parameters=percentage_global_parameters + ) + super().__init__(event_shape, bijections, **kwargs) + + +class InverseAutoregressiveDenseSF(BijectiveComposition): + """Inverse autoregressive dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers( + DenseSigmoidalInverseMaskedAutoregressive, + event_shape, + n_layers, + edge_list, + percentage_global_parameters=percentage_global_parameters + ) + super().__init__(event_shape, bijections, **kwargs) + + +class MaskedAutoregressiveDenseSF(BijectiveComposition): + """Masked autoregressive dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. 
+ """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers( + DenseSigmoidalForwardMaskedAutoregressive, + event_shape, + n_layers, + edge_list, + percentage_global_parameters=percentage_global_parameters + ) + super().__init__(event_shape, bijections, **kwargs) + + +class CouplingDeepDenseSF(BijectiveComposition): + """Coupling deep-dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers( + DeepDenseSigmoidalCoupling, + event_shape, + n_layers, + edge_list, + percentage_global_parameters=percentage_global_parameters + ) + super().__init__(event_shape, bijections, **kwargs) + + +class InverseAutoregressiveDeepDenseSF(BijectiveComposition): + """Inverse autoregressive deep-dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers( + DeepDenseSigmoidalInverseMaskedAutoregressive, + event_shape, + n_layers, + edge_list, + percentage_global_parameters=percentage_global_parameters + ) + super().__init__(event_shape, bijections, **kwargs) + + +class MaskedAutoregressiveDeepDenseSF(BijectiveComposition): + """Masked autoregressive deep-dense sigmoidal flow architecture. + + Reference: Huang et al. "Neural Autoregressive Flows" (2018); https://arxiv.org/abs/1804.00779. + """ + + def __init__(self, + event_shape, + n_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + bijections = make_basic_layers( + DeepDenseSigmoidalForwardMaskedAutoregressive, + event_shape, + n_layers, + edge_list, + percentage_global_parameters=percentage_global_parameters + ) super().__init__(event_shape, bijections, **kwargs) @@ -202,6 +409,7 @@ class UMNNMAF(BijectiveComposition): Reference: Wehenkel and Louppe "Unconstrained Monotonic Neural Networks" (2021); https://arxiv.org/abs/1908.05164. 
""" + def __init__(self, event_shape, n_layers: int = 1, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) diff --git a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py index 61f4cc5..7ce5d28 100644 --- a/torchflows/bijections/finite/autoregressive/conditioning/transforms.py +++ b/torchflows/bijections/finite/autoregressive/conditioning/transforms.py @@ -1,5 +1,5 @@ import math -from typing import Tuple, Union, Type +from typing import Tuple, Union, Type, Optional import torch import torch.nn as nn @@ -21,12 +21,14 @@ class ConditionerTransform(nn.Module): """ def __init__(self, - input_event_shape, - context_shape, + input_event_shape: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]], parameter_shape: Union[torch.Size, Tuple[int, ...]], context_combiner: ContextCombiner = None, - global_parameter_mask: torch.Tensor = None, + global_parameter_mask: Optional[torch.Tensor] = None, initial_global_parameter_value: float = None, + output_lower_bound: float = -torch.inf, + output_upper_bound: float = torch.inf, **kwargs): """ :param input_event_shape: shape of conditioner input tensor x. @@ -46,6 +48,9 @@ def __init__(self, f"but found {global_parameter_mask.shape}" ) + self.output_upper_bound = output_upper_bound + self.output_lower_bound = output_lower_bound + if context_shape is None: context_combiner = Bypass(input_event_shape) elif context_shape is not None and context_combiner is None: @@ -61,7 +66,10 @@ def __init__(self, self.parameter_shape = parameter_shape self.global_parameter_mask = global_parameter_mask self.n_transformer_parameters = int(torch.prod(torch.as_tensor(self.parameter_shape))) - self.n_global_parameters = 0 if global_parameter_mask is None else int(torch.sum(self.global_parameter_mask)) + if global_parameter_mask is None: + self.n_global_parameters = 0 + else: + self.n_global_parameters = int(torch.sum(global_parameter_mask)) self.n_predicted_parameters = self.n_transformer_parameters - self.n_global_parameters if initial_global_parameter_value is None: @@ -80,19 +88,27 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None): batch_shape = get_batch_shape(x, self.input_event_shape) if self.n_global_parameters == 0: # All parameters are predicted - return self.predict_theta_flat(x, context).view(*batch_shape, *self.parameter_shape) + output = self.predict_theta_flat(x, context).view(*batch_shape, *self.parameter_shape) else: if self.n_global_parameters == self.n_transformer_parameters: # All transformer parameters are learned globally - output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device) + output = torch.zeros(*batch_shape, *self.parameter_shape).to(x) output[..., self.global_parameter_mask] = self.global_theta_flat - return output else: # Some transformer parameters are learned globally, some are predicted - output = torch.zeros(*batch_shape, *self.parameter_shape, device=x.device) + output = torch.zeros(*batch_shape, *self.parameter_shape).to(x) output[..., self.global_parameter_mask] = self.global_theta_flat output[..., ~self.global_parameter_mask] = self.predict_theta_flat(x, context) - return output + + if self.output_lower_bound > -torch.inf and self.output_upper_bound < torch.inf: + output = torch.sigmoid(output) + output = output * (self.output_upper_bound - self.output_lower_bound) + self.output_lower_bound + elif self.output_lower_bound > -torch.inf and 
self.output_upper_bound == torch.inf: + output = torch.exp(output) + self.output_lower_bound + elif self.output_lower_bound == -torch.inf and self.output_upper_bound < torch.inf: + output = -torch.exp(output) + self.output_upper_bound + + return output def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): raise NotImplementedError @@ -101,18 +117,70 @@ def regularization(self): return sum([torch.sum(torch.square(p)) for p in self.parameters()]) -class Constant(ConditionerTransform): +class ElementwiseConditionerTransform(ConditionerTransform): + """ + Conditioner transform that predicts a set of parameters for every element of the transformed tensor. + """ + + def __init__(self, + input_event_shape: Union[torch.Size, Tuple[int, ...]], + transformed_event_shape: Union[torch.Size, Tuple[int, ...]], + parameter_shape_per_element: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + **kwargs): + super().__init__( + input_event_shape=input_event_shape, + parameter_shape=(*transformed_event_shape, *parameter_shape_per_element), + context_shape=context_shape, + **kwargs + ) + + +class TensorConditionerTransform(ConditionerTransform): + """ + Conditioner transform that predicts a set of parameters for the entire transformed tensor. + """ + + def __init__(self, + input_event_shape: Union[torch.Size, Tuple[int, ...]], + parameter_shape: Union[torch.Size, Tuple[int, ...]], + context_shape: Union[torch.Size, Tuple[int, ...]] = None, + percentage_global_parameters: float = 0.0, + **kwargs): + if 0.0 < percentage_global_parameters <= 1.0: + n_parameters = int(torch.prod(torch.as_tensor(parameter_shape))) + parameter_permutation = torch.randperm(n_parameters) + global_param_indices = parameter_permutation[:int(n_parameters * percentage_global_parameters)] + global_mask = torch.zeros(size=(n_parameters,), dtype=torch.bool) + global_mask[global_param_indices] = True + global_mask = global_mask.view(*parameter_shape) + else: + global_mask = None + + super().__init__( + input_event_shape=input_event_shape, + parameter_shape=parameter_shape, + context_shape=context_shape, + **{ + **kwargs, + **dict( + global_parameter_mask=global_mask + ) + } + ) + + +class Constant(TensorConditionerTransform): def __init__(self, event_shape, parameter_shape, fill_value: float = None): super().__init__( input_event_shape=event_shape, - context_shape=None, parameter_shape=parameter_shape, initial_global_parameter_value=fill_value, global_parameter_mask=torch.ones(parameter_shape, dtype=torch.bool) ) -class MADE(ConditionerTransform): +class MADE(ElementwiseConditionerTransform): """ Masked autoencoder for distribution estimation (MADE). 
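Two `ConditionerTransform` additions above deserve a plain restatement: predicted parameters can now be squashed into an interval (the multiscale conditioners later in this diff pass bounds of ±2), and `TensorConditionerTransform` can dedicate a random fraction of parameters to global learning via `percentage_global_parameters`. A self-contained sketch of both rules:

```python
import torch

def bound_output(output: torch.Tensor, lower: float = -torch.inf, upper: float = torch.inf) -> torch.Tensor:
    # Mirrors the new ConditionerTransform.forward logic: sigmoid-rescale for
    # two-sided bounds, exp-shift for one-sided bounds, identity otherwise.
    if lower > -torch.inf and upper < torch.inf:
        return torch.sigmoid(output) * (upper - lower) + lower
    elif lower > -torch.inf:
        return torch.exp(output) + lower
    elif upper < torch.inf:
        return -torch.exp(output) + upper
    return output

def make_global_mask(parameter_shape, percentage_global_parameters: float) -> torch.Tensor:
    # Mirrors TensorConditionerTransform: a random boolean mask marking which
    # parameters are learned globally instead of predicted per input.
    n_parameters = int(torch.prod(torch.as_tensor(parameter_shape)))
    mask = torch.zeros(n_parameters, dtype=torch.bool)
    mask[torch.randperm(n_parameters)[:int(n_parameters * percentage_global_parameters)]] = True
    return mask.view(*parameter_shape)

print(bound_output(torch.tensor([-10.0, 0.0, 10.0]), lower=-2.0, upper=2.0))  # ~[-2, 0, 2]
print(make_global_mask((4, 3), 0.5).sum())  # tensor(6): 6 of 12 parameters are global
```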
@@ -130,7 +198,7 @@ def forward(self, x): def __init__(self, input_event_shape: Union[torch.Size, Tuple[int, ...]], - output_event_shape: Union[torch.Size, Tuple[int, ...]], + transformed_event_shape: Union[torch.Size, Tuple[int, ...]], parameter_shape_per_element: Union[torch.Size, Tuple[int, ...]], context_shape: Union[torch.Size, Tuple[int, ...]] = None, n_hidden: int = None, @@ -138,12 +206,13 @@ def __init__(self, **kwargs): super().__init__( input_event_shape=input_event_shape, + transformed_event_shape=transformed_event_shape, + parameter_shape_per_element=parameter_shape_per_element, context_shape=context_shape, - parameter_shape=(*output_event_shape, *parameter_shape_per_element), **kwargs ) n_predicted_parameters_per_element = int(torch.prod(torch.as_tensor(parameter_shape_per_element))) - n_output_event_dims = int(torch.prod(torch.as_tensor(output_event_shape))) + n_output_event_dims = int(torch.prod(torch.as_tensor(transformed_event_shape))) if n_hidden is None: n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) @@ -201,7 +270,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, n_layers=1, **kwargs) -class FeedForward(ConditionerTransform): +class FeedForward(TensorConditionerTransform): def __init__(self, input_event_shape: torch.Size, parameter_shape: torch.Size, @@ -218,7 +287,7 @@ def __init__(self, ) if n_hidden is None: - n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) + n_hidden = max(int(5 * math.log10(max(self.n_input_event_dims, self.n_predicted_parameters))), 4) layers = [] if n_layers == 1: @@ -230,7 +299,7 @@ def __init__(self, layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) else: raise ValueError - layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) + layers.append(nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,))) self.sequential = nn.Sequential(*layers) def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): @@ -242,7 +311,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, n_layers=1) -class ResidualFeedForward(ConditionerTransform): +class ResidualFeedForward(TensorConditionerTransform): class ResidualBlock(nn.Module): def __init__(self, event_size: int, hidden_size: int, block_size: int, nonlinearity: Type[nn.Module]): super().__init__() @@ -276,7 +345,7 @@ def __init__(self, ) if n_hidden is None: - n_hidden = max(int(3 * math.log10(self.n_input_event_dims)), 4) + n_hidden = max(int(5 * math.log10(max(self.n_input_event_dims, self.n_predicted_parameters))), 4) if n_layers <= 2: raise ValueError(f"Number of layers in ResidualFeedForward must be at least 3, but found {n_layers}") @@ -285,7 +354,7 @@ def __init__(self, for _ in range(n_layers - 2): layers.append(self.ResidualBlock(n_hidden, n_hidden, block_size, nonlinearity=nonlinearity)) layers.append(nn.Linear(n_hidden, self.n_predicted_parameters)) - layers.append(nn.Unflatten(dim=-1, unflattened_size=self.parameter_shape)) + layers.append(nn.Unflatten(dim=-1, unflattened_size=(self.n_predicted_parameters,))) self.sequential = nn.Sequential(*layers) def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None): diff --git a/torchflows/bijections/finite/autoregressive/layers.py b/torchflows/bijections/finite/autoregressive/layers.py index b70d717..727ca5d 100644 --- a/torchflows/bijections/finite/autoregressive/layers.py +++ b/torchflows/bijections/finite/autoregressive/layers.py @@ -6,14 +6,16 @@ from 
torchflows.bijections.finite.autoregressive.conditioning.coupling_masks import make_coupling from torchflows.bijections.finite.autoregressive.layers_base import MaskedAutoregressiveBijection, \ InverseMaskedAutoregressiveBijection, ElementwiseBijection, CouplingBijection -from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift +from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Scale, Affine, Shift, InverseAffine from torchflows.bijections.finite.autoregressive.transformers.base import ScalarTransformer from torchflows.bijections.finite.autoregressive.transformers.integration.unconstrained_monotonic_neural_network import \ UnconstrainedMonotonicNeuralNetwork from torchflows.bijections.finite.autoregressive.transformers.spline.linear_rational import LinearRational from torchflows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import ( - DeepSigmoid + DeepSigmoid, + DenseSigmoid, + DeepDenseSigmoid ) from torchflows.bijections.base import invert @@ -24,6 +26,39 @@ def __init__(self, event_shape, **kwargs): super().__init__(transformer) +class ElementwiseInverseAffine(ElementwiseBijection): + def __init__(self, event_shape, **kwargs): + transformer = InverseAffine(event_shape, **kwargs) + super().__init__(transformer) + + +class ActNorm(ElementwiseInverseAffine): + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, **kwargs) + self.first_training_batch_pass: bool = True + self.value.requires_grad_(False) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + + :param x: x.shape = (*batch_shape, *event_shape) + :param context: + :return: + """ + if self.training and self.first_training_batch_pass: + batch_shape = x.shape[:-len(self.event_shape)] + n_batch_dims = len(batch_shape) + self.first_training_batch_pass = False + shift = torch.mean(x, dim=list(range(n_batch_dims)))[..., None].to(self.value) + if torch.prod(torch.as_tensor(batch_shape)) == 1: + scale = torch.ones_like(shift) # unit scale if unable to estimate + else: + scale = torch.std(x, dim=list(range(n_batch_dims)))[..., None].to(self.value) + unconstrained_scale = self.transformer.unconstrain_scale(scale) + self.value.data = torch.concatenate([unconstrained_scale, shift], dim=-1).data + return super().forward(x, context) + + class ElementwiseScale(ElementwiseBijection): def __init__(self, event_shape, **kwargs): transformer = Scale(event_shape, **kwargs) @@ -149,7 +184,7 @@ def __init__(self, super().__init__(transformer, coupling, conditioner_transform) -class DSCoupling(CouplingBijection): +class DeepSigmoidalCoupling(CouplingBijection): def __init__(self, event_shape: torch.Size, context_shape: torch.Size = None, @@ -175,6 +210,180 @@ def __init__(self, super().__init__(transformer, coupling, conditioner_transform) +class DeepSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_hidden_layers: int = 2, + **kwargs): + transformer: ScalarTransformer = DeepSigmoid( + event_shape=torch.Size(event_shape), + n_hidden_layers=n_hidden_layers + ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + + +class DeepSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): + def __init__(self, + 
event_shape: torch.Size, + context_shape: torch.Size = None, + n_hidden_layers: int = 2, + **kwargs): + transformer: ScalarTransformer = DeepSigmoid( + event_shape=torch.Size(event_shape), + n_hidden_layers=n_hidden_layers + ) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + + +class DenseSigmoidalCoupling(CouplingBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_dense_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + transformer = DenseSigmoid( + event_shape=torch.Size((coupling.target_event_size,)), + n_dense_layers=n_dense_layers + ) + # Parameter order: [c1, c2, c3, c4, ..., ck] for all components + # Each component has parameter order [a_unc, b, w_unc] + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + context_shape=context_shape, + **{ + **kwargs, + **dict(percentage_global_parameters=percentage_global_parameters) + } + ) + super().__init__(transformer, coupling, conditioner_transform) + + +class DenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_dense_layers: int = 2, + percentage_global_parameters: float = 0.8, + **kwargs): + transformer: ScalarTransformer = DenseSigmoid( + event_shape=torch.Size(event_shape), + n_dense_layers=n_dense_layers + ) + super().__init__( + event_shape, + context_shape, + transformer=transformer, + **{ + **kwargs, + **dict(percentage_global_parameters=percentage_global_parameters) + } + ) + + +class DenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_dense_layers: int = 2, + percentage_global_parameters: float = 0.8, + **kwargs): + transformer: ScalarTransformer = DenseSigmoid( + event_shape=torch.Size(event_shape), + n_dense_layers=n_dense_layers + ) + super().__init__( + event_shape, + context_shape, + transformer=transformer, + **{ + **kwargs, + **dict(percentage_global_parameters=percentage_global_parameters) + } + ) + + +class DeepDenseSigmoidalCoupling(CouplingBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_hidden_layers: int = 2, + edge_list: List[Tuple[int, int]] = None, + coupling_kwargs: dict = None, + percentage_global_parameters: float = 0.8, + **kwargs): + if coupling_kwargs is None: + coupling_kwargs = dict() + coupling = make_coupling(event_shape, edge_list, **coupling_kwargs) + transformer = DeepDenseSigmoid( + event_shape=torch.Size((coupling.target_event_size,)), + n_hidden_layers=n_hidden_layers + ) + # Parameter order: [c1, c2, c3, c4, ..., ck] for all components + # Each component has parameter order [a_unc, b, w_unc] + conditioner_transform = FeedForward( + input_event_shape=torch.Size((coupling.source_event_size,)), + parameter_shape=torch.Size(transformer.parameter_shape), + context_shape=context_shape, + **{ + **kwargs, + **dict(percentage_global_parameters=percentage_global_parameters) + } + ) + super().__init__(transformer, coupling, conditioner_transform) + + +class 
DeepDenseSigmoidalInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_hidden_layers: int = 2, + percentage_global_parameters: float = 0.8, + **kwargs): + transformer: ScalarTransformer = DeepDenseSigmoid( + event_shape=torch.Size(event_shape), + n_hidden_layers=n_hidden_layers + ) + super().__init__( + event_shape, + context_shape, + transformer=transformer, + **{ + **kwargs, + **dict(percentage_global_parameters=percentage_global_parameters) + } + ) + + +class DeepDenseSigmoidalForwardMaskedAutoregressive(MaskedAutoregressiveBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_hidden_layers: int = 2, + percentage_global_parameters: float = 0.8, + **kwargs): + transformer: ScalarTransformer = DeepDenseSigmoid( + event_shape=torch.Size(event_shape), + n_hidden_layers=n_hidden_layers + ) + super().__init__( + event_shape, + context_shape, + transformer=transformer, + **{ + **kwargs, + **dict(percentage_global_parameters=percentage_global_parameters) + } + ) + + class LinearAffineCoupling(AffineCoupling): def __init__(self, event_shape: torch.Size, **kwargs): super().__init__(event_shape, **kwargs, n_layers=1) @@ -244,6 +453,16 @@ def __init__(self, super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) +class LRSInverseMaskedAutoregressive(InverseMaskedAutoregressiveBijection): + def __init__(self, + event_shape: torch.Size, + context_shape: torch.Size = None, + n_bins: int = 8, + **kwargs): + transformer: ScalarTransformer = LinearRational(event_shape=event_shape, n_bins=n_bins) + super().__init__(event_shape, context_shape, transformer=transformer, **kwargs) + + class UMNNMaskedAutoregressive(MaskedAutoregressiveBijection): def __init__(self, event_shape: torch.Size, diff --git a/torchflows/bijections/finite/autoregressive/layers_base.py b/torchflows/bijections/finite/autoregressive/layers_base.py index c75e0f6..fe07d4c 100644 --- a/torchflows/bijections/finite/autoregressive/layers_base.py +++ b/torchflows/bijections/finite/autoregressive/layers_base.py @@ -122,7 +122,7 @@ def __init__(self, **kwargs): conditioner_transform = MADE( input_event_shape=event_shape, - output_event_shape=event_shape, + transformed_event_shape=event_shape, parameter_shape_per_element=transformer.parameter_shape_per_element, context_shape=context_shape, **kwargs diff --git a/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py b/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py index ff006dc..5c0e9a4 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py +++ b/torchflows/bijections/finite/autoregressive/transformers/linear/affine.py @@ -1,5 +1,5 @@ import math -from typing import Tuple +from typing import Tuple, Union import torch @@ -30,9 +30,15 @@ def parameter_shape_per_element(self): def default_parameters(self) -> torch.Tensor: return torch.zeros(self.parameter_shape) + def constrain_scale(self, unconstrained_scale: torch.Tensor) -> torch.Tensor: + return torch.exp(self.identity_unconstrained_alpha + unconstrained_scale / self.const) + self.m + + def unconstrain_scale(self, scale: torch.Tensor) -> torch.Tensor: + return (torch.log(scale - self.m) - self.identity_unconstrained_alpha) * self.const + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] - alpha = torch.exp(self.identity_unconstrained_alpha + 
u_alpha / self.const) + self.m + alpha = self.constrain_scale(u_alpha) log_alpha = torch.log(alpha) u_beta = h[..., 1] @@ -43,7 +49,7 @@ def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: u_alpha = h[..., 0] - alpha = torch.exp(self.identity_unconstrained_alpha + u_alpha / self.const) + self.m + alpha = self.constrain_scale(u_alpha) log_alpha = torch.log(alpha) u_beta = h[..., 1] @@ -53,6 +59,26 @@ def inverse(self, z: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch return (z - beta) / alpha, log_det +class InverseAffine(Affine): + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs): + super().__init__(event_shape, **kwargs) + + def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return super().inverse(x, h) + + def inverse(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return super().forward(x, h) + + +class SafeAffine(Affine): + """ + Affine transformer with minimum scale 0.1 for numerical stability. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, min_scale=0.1) + + class Affine2(ScalarTransformer): """ Affine transformer with near-identity initialization. diff --git a/torchflows/bijections/finite/autoregressive/transformers/linear/convolution.py b/torchflows/bijections/finite/autoregressive/transformers/linear/convolution.py index 741327a..eca8832 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/linear/convolution.py +++ b/torchflows/bijections/finite/autoregressive/transformers/linear/convolution.py @@ -2,24 +2,24 @@ import torch from torchflows.bijections.finite.autoregressive.transformers.base import TensorTransformer from torchflows.bijections.finite.autoregressive.transformers.linear.matrix import LUTransformer -from torchflows.utils import sum_except_batch, get_batch_shape +from torchflows.utils import get_batch_shape -class Invertible1x1Convolution(TensorTransformer): +class Invertible1x1ConvolutionTransformer(TensorTransformer): """ Invertible 1x1 convolution. - This transformer receives as input a batch of images x with x.shape (*batch_shape, *image_dimensions, channels) and - parameters h for an invertible linear transform of the channels - with h.shape = (*batch_shape, *image_dimensions, *parameter_shape). - Note that image_dimensions can be a shape with arbitrarily ordered dimensions (height, width). - In fact, it is not required that the image is two-dimensional. Voxels with shape (height, width, depth, channels) - are also supported, as well as tensors with more general shapes. + This transformer receives as input a batch of images x with x.shape `(*batch_shape, channels, *image_dimensions)` + and parameters h for an invertible linear transform of the channels + with h.shape = `(*batch_shape, *parameter_shape)`. + Note that `image_dimensions` can be a shape with arbitrarily ordered dimensions. + In fact, it is not required that the image is two-dimensional. Voxels with shape `(channels, height, width, depth)` + are also supported, as well as tensors with more general shapes. 
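The docstring change above is the user-visible half of a layout change: images are now channels-first, and `apply_linear` below moves the channel axis to the end before applying the LU-parameterized linear map, then permutes it back. A standalone shape check of the two permutations used there:

```python
import torch

# (*batch_shape, channels, *image_dims) -> (*image_dims, *batch_shape, channels)
batch_shape, channels, image_dims = (7,), 3, (8, 8)
x = torch.randn(*batch_shape, channels, *image_dims)
n_b, n_i = len(batch_shape), len(image_dims)

forward_perm = (*range(n_b + 1, n_b + 1 + n_i), *range(n_b), n_b)
y = torch.permute(x, forward_perm)
assert y.shape == (*image_dims, *batch_shape, channels)

# ... and back again after the linear map is applied along the channel axis
inverse_perm = (*range(n_i, n_i + n_b), n_i + n_b, *range(n_i))
assert torch.permute(y, inverse_perm).shape == x.shape
```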
""" def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]]): super().__init__(event_shape) - *self.image_dimensions, self.n_channels = event_shape + self.n_channels, *self.image_dimensions = event_shape self.invertible_linear: TensorTransformer = LUTransformer(event_shape=(self.n_channels,)) @property @@ -33,14 +33,33 @@ def default_parameters(self) -> torch.Tensor: def apply_linear(self, inputs: torch.Tensor, h: torch.Tensor, forward: bool): batch_shape = get_batch_shape(inputs, self.event_shape) + n_batch_dims = len(batch_shape) + n_image_dims = len(self.image_dimensions) + + # (*batch_shape, n_channels, *image_dims) -> (*image_dims, *batch_shape, n_channels) + inputs = torch.permute( + inputs, + ( + *list(range(n_batch_dims + 1, n_batch_dims + 1 + n_image_dims)), # image_dims is moved to the start + *list(range(n_batch_dims)), # batch_shape is moved to the middle + n_batch_dims # n_channels is moved to the end + ) + ) + # Apply linear transformation along channel dimension if forward: outputs, log_det = self.invertible_linear.forward(inputs, h) else: outputs, log_det = self.invertible_linear.inverse(inputs, h) - log_det = sum_except_batch( - log_det.view(*batch_shape, *self.image_dimensions), - event_shape=self.image_dimensions + # outputs and log_det need to be permuted now. + + outputs = torch.permute( + outputs, + ( + *list(range(n_image_dims, n_image_dims + n_batch_dims)), # batch_shape is moved to the start + n_image_dims + n_batch_dims, # n_channels is moved to the middle, + *list(range(n_image_dims)), # image_dims is moved to the end + ) ) return outputs, log_det diff --git a/torchflows/bijections/finite/autoregressive/transformers/linear/matrix.py b/torchflows/bijections/finite/autoregressive/transformers/linear/matrix.py index 96448c1..55e2a26 100644 --- a/torchflows/bijections/finite/autoregressive/transformers/linear/matrix.py +++ b/torchflows/bijections/finite/autoregressive/transformers/linear/matrix.py @@ -37,12 +37,12 @@ def extract_matrices(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, batch_shape = h.shape[:-len(self.parameter_shape)] - upper = torch.zeros(size=(*batch_shape, event_size, event_size)) + upper = torch.zeros(size=(*batch_shape, event_size, event_size)).to(h) upper_row_index, upper_col_index = torch.triu_indices(row=event_size, col=event_size, offset=1) upper[..., upper_row_index, upper_col_index] = u_off_diagonal_elements upper[..., range(event_size), range(event_size)] = u_diag - lower = torch.zeros(size=(*batch_shape, event_size, event_size)) + lower = torch.zeros(size=(*batch_shape, event_size, event_size)).to(h) lower_row_index, lower_col_index = torch.tril_indices(row=event_size, col=event_size, offset=-1) lower[..., lower_row_index, lower_col_index] = l_off_diagonal_elements lower[..., range(event_size), range(event_size)] = 1 # Unit diagonal diff --git a/torchflows/bijections/finite/linear.py b/torchflows/bijections/finite/linear.py index 1635765..26b8fcd 100644 --- a/torchflows/bijections/finite/linear.py +++ b/torchflows/bijections/finite/linear.py @@ -39,7 +39,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
z = self.matrix.project(x) z = unflatten_batch(unflatten_event(z, self.event_shape), batch_shape) - log_det = self.matrix.log_det() + torch.zeros(size=batch_shape, device=x.device) + log_det = self.matrix.log_det() + torch.zeros(size=batch_shape, device=x.device).to(x) return z, log_det def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -49,7 +49,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. x = self.matrix.solve(z) x = unflatten_batch(unflatten_event(x, self.event_shape), batch_shape) - log_det = -self.matrix.log_det() + torch.zeros(size=batch_shape, device=z.device) + log_det = -self.matrix.log_det() + torch.zeros(size=batch_shape, device=z.device).to(z) return x, log_det diff --git a/torchflows/bijections/finite/multiscale/architectures.py b/torchflows/bijections/finite/multiscale/architectures.py index 2a6dea4..bee149c 100644 --- a/torchflows/bijections/finite/multiscale/architectures.py +++ b/torchflows/bijections/finite/multiscale/architectures.py @@ -1,8 +1,8 @@ +from typing import Union, Tuple + import torch -from torchflows.bijections.base import BijectiveComposition -from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine -from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine, Shift +from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Shift, Affine from torchflows.bijections.finite.autoregressive.transformers.spline.rational_quadratic import RationalQuadratic from torchflows.bijections.finite.autoregressive.transformers.spline.linear import Linear as LinearRational from torchflows.bijections.finite.autoregressive.transformers.combination.sigmoid import ( @@ -10,7 +10,8 @@ DeepDenseSigmoid, DenseSigmoid ) -from torchflows.bijections.finite.multiscale.base import MultiscaleBijection, FactoredBijection +from torchflows.bijections.finite.multiscale.base import MultiscaleBijection, GlowCheckerboardCoupling, \ + GlowChannelWiseCoupling def check_image_shape_for_multiscale_flow(event_shape, n_layers): @@ -43,220 +44,152 @@ def automatically_determine_n_layers(event_shape): return n_layers -def make_factored_image_layers(event_shape, - transformer_class, - n_layers: int = None, - **kwargs): - """ - Creates a list of image transformations consisting of coupling layers and squeeze layers. - After each coupling, squeeze, coupling mapping, half of the channels are kept as is (not transformed anymore). - - :param event_shape: (c, 2^n, 2^m). 
- :param transformer_class: - :param n_layers: - :return: - """ - check_image_shape_for_multiscale_flow(event_shape, n_layers) - if n_layers is None: - n_layers = automatically_determine_n_layers(event_shape) - check_image_shape_for_multiscale_flow(event_shape, n_layers) - - def recursive_layer_builder(event_shape_, n_layers_): - msb = MultiscaleBijection( - input_event_shape=event_shape_, - transformer_class=transformer_class, - **kwargs - ) - if n_layers_ == 1: - return msb - - c, h, w = msb.transformed_shape # c is a multiple of 4 after squeezing - - small_bijection_shape = (c // 2, h, w) - small_bijection_mask = (torch.arange(c) >= c // 2)[:, None, None].repeat(1, h, w) - fb = FactoredBijection( - event_shape=(c, h, w), - small_bijection=recursive_layer_builder( - event_shape_=small_bijection_shape, - n_layers_=n_layers_ - 1 - ), - small_bijection_mask=small_bijection_mask - ) - composition = BijectiveComposition( - event_shape=msb.event_shape, - layers=[msb, fb] - ) - composition.transformed_shape = fb.transformed_shape - return composition - - bijections = [ElementwiseAffine(event_shape=event_shape)] - bijections.append(recursive_layer_builder(bijections[-1].transformed_shape, n_layers)) - bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) - return bijections - - -def make_image_layers_non_factored(event_shape, - transformer_class, - n_layers: int = None, - **kwargs): - """ - Returns a list of bijections for transformations of images with multiple channels. - - Let n be the number of layers. This sequence of bijections takes as input an image with shape (c, h, w) and outputs - an image with shape (4 ** n * c, h / 2 ** n, w / 2 ** n). We require h and w to be divisible by 2 ** n. - """ - check_image_shape_for_multiscale_flow(event_shape, n_layers) - if n_layers is None: - n_layers = automatically_determine_n_layers(event_shape) - check_image_shape_for_multiscale_flow(event_shape, n_layers) - - bijections = [ElementwiseAffine(event_shape=event_shape)] - for _ in range(n_layers - 1): - bijections.append( - MultiscaleBijection( - input_event_shape=bijections[-1].transformed_shape, - transformer_class=transformer_class, - **kwargs - ) - ) - bijections.append( - MultiscaleBijection( - input_event_shape=bijections[-1].transformed_shape, - transformer_class=transformer_class, - n_checkerboard_layers=0, - squeeze_layer=False, - n_channel_wise_layers=2, - **kwargs - ) - ) - bijections.append(ElementwiseAffine(event_shape=bijections[-1].transformed_shape)) - return bijections - - -def make_image_layers(*args, factored: bool = False, **kwargs): - if factored: - return make_factored_image_layers(*args, **kwargs) - else: - return make_image_layers_non_factored(*args, **kwargs) - - -class MultiscaleRealNVP(BijectiveComposition): +class MultiscaleRealNVP(MultiscaleBijection): """Multiscale version of Real NVP. Reference: Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. 
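With the factored-layer helpers removed, each architecture below is a thin subclass of the rewritten `MultiscaleBijection`. A minimal construction sketch on hypothetical synthetic images; the shape constraint (height and width divisible by `2**n_layers`) is enforced by `check_image_shape_for_multiscale_flow`:

```python
import torch
from torchflows.flows import Flow
from torchflows.bijections.finite.multiscale.architectures import MultiscaleRealNVP

torch.manual_seed(0)
images = torch.randn(64, 3, 32, 32)  # hypothetical (batch, channels, height, width)

flow = Flow(MultiscaleRealNVP(event_shape=(3, 32, 32), n_layers=2))
flow.fit(x_train=images, show_progress=True)
samples = flow.sample((8,))
```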
""" + def __init__(self, - event_shape, + event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, - factored: bool = False, - use_resnet: bool = False, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, Affine, n_layers, factored=factored, use_resnet=use_resnet) - super().__init__(event_shape, bijections, **kwargs) - self.transformed_shape = bijections[-1].transformed_shape + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=Affine, + n_blocks=n_layers, + **kwargs + ) -class MultiscaleNICE(BijectiveComposition): +class MultiscaleNICE(MultiscaleBijection): """Multiscale version of NICE. References: - Dinh et al. "NICE: Non-linear Independent Components Estimation" (2015); https://arxiv.org/abs/1410.8516. - Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ - def __init__(self, - event_shape, - n_layers: int = None, - factored: bool = False, - use_resnet: bool = False, - **kwargs): + + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, Shift, n_layers, factored=factored, use_resnet=use_resnet) - super().__init__(event_shape, bijections, **kwargs) - self.transformed_shape = bijections[-1].transformed_shape + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=Shift, + n_blocks=n_layers, + **kwargs + ) -class MultiscaleRQNSF(BijectiveComposition): +class MultiscaleRQNSF(MultiscaleBijection): """Multiscale version of C-RQNSF. References: - Durkan et al. "Neural Spline Flows" (2019); https://arxiv.org/abs/1906.04032. - Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. """ - def __init__(self, - event_shape, - n_layers: int = None, - factored: bool = False, - use_resnet: bool = False, - **kwargs): + + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, RationalQuadratic, n_layers, factored=factored, - use_resnet=use_resnet) - super().__init__(event_shape, bijections, **kwargs) - self.transformed_shape = bijections[-1].transformed_shape + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=RationalQuadratic, + n_blocks=n_layers, + **kwargs + ) -class MultiscaleLRSNSF(BijectiveComposition): +class MultiscaleLRSNSF(MultiscaleBijection): """Multiscale version of C-LRS. References: - Dolatabadi et al. "Invertible Generative Modeling using Linear Rational Splines" (2020); https://arxiv.org/abs/2001.05168. - Dinh et al. "Density estimation using Real NVP" (2017); https://arxiv.org/abs/1605.08803. 
""" - def __init__(self, - event_shape, - n_layers: int = None, - factored: bool = False, - use_resnet: bool = False, - **kwargs): + + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, LinearRational, n_layers, factored=factored, use_resnet=use_resnet) - super().__init__(event_shape, bijections, **kwargs) - self.transformed_shape = bijections[-1].transformed_shape + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=LinearRational, + n_blocks=n_layers, + **kwargs + ) -class MultiscaleDeepSigmoid(BijectiveComposition): - def __init__(self, - event_shape, - n_layers: int = None, - factored: bool = False, - use_resnet: bool = False, - **kwargs): +class MultiscaleDeepSigmoid(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): + if isinstance(event_shape, int): + event_shape = (event_shape,) + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=DeepSigmoid, + n_blocks=n_layers, + **kwargs + ) + + +class MultiscaleDeepDenseSigmoid(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, DeepSigmoid, n_layers, factored=factored, use_resnet=use_resnet) - super().__init__(event_shape, bijections, **kwargs) - self.transformed_shape = bijections[-1].transformed_shape + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=DeepDenseSigmoid, + n_blocks=n_layers, + **kwargs + ) -class MultiscaleDeepDenseSigmoid(BijectiveComposition): - def __init__(self, - event_shape, - n_layers: int = None, - factored: bool = False, - use_resnet: bool = False, - **kwargs): +class MultiscaleDenseSigmoid(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, DeepDenseSigmoid, n_layers, factored=factored, - use_resnet=use_resnet) - super().__init__(event_shape, bijections, **kwargs) - self.transformed_shape = bijections[-1].transformed_shape + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=DenseSigmoid, + n_blocks=n_layers, + **kwargs + ) -class MultiscaleDenseSigmoid(BijectiveComposition): - def __init__(self, - event_shape, - n_layers: int = None, - factored: bool = False, - use_resnet: bool = False, - **kwargs): +class AffineGlow(MultiscaleBijection): + def __init__(self, event_shape: Union[int, torch.Size, Tuple[int, ...]], n_layers: int = None, **kwargs): if isinstance(event_shape, int): event_shape = (event_shape,) - bijections = make_image_layers(event_shape, DenseSigmoid, n_layers, 
factored=factored, use_resnet=use_resnet) - super().__init__(event_shape, bijections, **kwargs) - self.transformed_shape = bijections[-1].transformed_shape + if n_layers is None: + n_layers = automatically_determine_n_layers(event_shape) + check_image_shape_for_multiscale_flow(event_shape, n_layers) + super().__init__( + event_shape=event_shape, + transformer_class=Affine, + checkerboard_class=GlowCheckerboardCoupling, + channel_wise_class=GlowChannelWiseCoupling, + n_blocks=n_layers, + **kwargs + ) diff --git a/torchflows/bijections/finite/multiscale/base.py b/torchflows/bijections/finite/multiscale/base.py index 365c5d6..b88a9d1 100644 --- a/torchflows/bijections/finite/multiscale/base.py +++ b/torchflows/bijections/finite/multiscale/base.py @@ -1,114 +1,21 @@ -from typing import Type, Union, Tuple +from typing import Type, Union, Tuple, List import torch +import torch.nn as nn -from torchflows.bijections.finite.autoregressive.conditioning.transforms import ConditionerTransform from torchflows.bijections.base import Bijection, BijectiveComposition +from torchflows.bijections.finite.autoregressive.layers import ActNorm from torchflows.bijections.finite.autoregressive.layers_base import CouplingBijection from torchflows.bijections.finite.autoregressive.transformers.base import TensorTransformer +from torchflows.bijections.finite.autoregressive.transformers.linear.convolution import \ + Invertible1x1ConvolutionTransformer +from torchflows.bijections.finite.multiscale.conditioning.classic import ConvNetConditioner +from torchflows.bijections.finite.multiscale.conditioning.resnet import ResNetConditioner from torchflows.bijections.finite.multiscale.coupling import make_image_coupling, Checkerboard, \ ChannelWiseHalfSplit -from torchflows.neural_networks.convnet import ConvNet -from torchflows.neural_networks.resnet import make_resnet18 from torchflows.utils import get_batch_shape -class FactoredBijection(Bijection): - """ - Factored bijection class. - - Partitions the input tensor x into parts x_A and x_B, then applies a bijection to x_A independently of x_B while - keeping x_B identical. - """ - - def __init__(self, - event_shape: Union[torch.Size, Tuple[int, ...]], - small_bijection: Bijection, - small_bijection_mask: torch.Tensor, - **kwargs): - """ - - :param event_shape: shape of input event x. - :param small_bijection: bijection applied to transformed event x_A. - :param small_bijection_mask: boolean mask that selects which elements of event x correspond to the transformed - event x_A. 
- :param kwargs: - """ - super().__init__(event_shape, **kwargs) - - # Check that shapes are correct - event_size = torch.prod(torch.as_tensor(event_shape)) - transformed_event_size = torch.prod(torch.as_tensor(small_bijection.event_shape)) - assert event_size >= transformed_event_size - - assert small_bijection_mask.shape == event_shape - - self.transformed_event_mask = small_bijection_mask - self.small_bijection = small_bijection - - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - batch_shape = get_batch_shape(x, self.event_shape) - transformed, log_det = self.small_bijection.forward( - x[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.event_shape), - context - ) - out = x.clone() - out[..., self.transformed_event_mask] = transformed.view(*batch_shape, -1) - return out, log_det - - def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: - batch_shape = get_batch_shape(z, self.event_shape) - transformed, log_det = self.small_bijection.inverse( - z[..., self.transformed_event_mask].view(*batch_shape, *self.small_bijection.transformed_shape), - context - ) - out = z.clone() - out[..., self.transformed_event_mask] = transformed.view(*batch_shape, -1) - return out, log_det - - -class ConvNetConditioner(ConditionerTransform): - def __init__(self, - input_event_shape: torch.Size, - parameter_shape: torch.Size, - kernels: Tuple[int, ...] = None, - **kwargs): - super().__init__( - input_event_shape=input_event_shape, - context_shape=None, - parameter_shape=parameter_shape, - **kwargs - ) - self.network = ConvNet( - input_shape=input_event_shape, - n_outputs=self.n_transformer_parameters, - kernels=kernels - ) - - def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: - return self.network(x) - - -class ResNetConditioner(ConditionerTransform): - def __init__(self, - input_event_shape: torch.Size, - parameter_shape: torch.Size, - **kwargs): - super().__init__( - input_event_shape=input_event_shape, - context_shape=None, - parameter_shape=parameter_shape, - **kwargs - ) - self.network = make_resnet18( - image_shape=input_event_shape, - n_outputs=self.n_transformer_parameters - ) - - def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: - return self.network(x) - - class ConvolutionalCouplingBijection(CouplingBijection): def __init__(self, transformer: TensorTransformer, @@ -157,13 +64,13 @@ def set_transformed_part(self, x: torch.Tensor, x_transformed: torch.Tensor): :param x_transformed: tensor with shape (*b, transformed_channels, transformed_height, transformed_width). 
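The `set_transformed_part` fix that follows is subtle: the old body read the target elements (apparently a copy-paste of the getter), while the new body writes `x_transformed` back through the coupling's boolean target mask. A standalone illustration of that masked assignment, with a hypothetical mask selecting two of three channels:

```python
import torch

x = torch.zeros(5, 3, 4, 4)  # (*batch_shape, channels, height, width)
target_mask = torch.zeros(3, 4, 4, dtype=torch.bool)
target_mask[1:] = True       # hypothetical target: the last two channels
x_transformed = torch.ones(5, 2, 4, 4)

# Boolean indexing flattens the masked event dimensions, hence the reshape.
x[..., target_mask] = x_transformed.reshape(5, -1)
assert x[:, 1:].eq(1).all() and x[:, 0].eq(0).all()
```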
""" batch_shape = get_batch_shape(x, self.event_shape) - return x[..., self.coupling.target_mask].view(*batch_shape, *self.coupling.transformed_shape) + x[..., self.coupling.target_mask] = x_transformed.reshape(*batch_shape, -1) + return x def partition_and_predict_parameters(self, x: torch.Tensor, context: torch.Tensor): batch_shape = get_batch_shape(x, self.event_shape) super_out = super().partition_and_predict_parameters(x, context) - return super_out.view(*batch_shape, *self.coupling.transformed_shape, - *self.transformer.parameter_shape_per_element) + return super_out.view(*batch_shape, *self.transformer.parameter_shape) class CheckerboardCoupling(ConvolutionalCouplingBijection): @@ -174,12 +81,44 @@ def __init__(self, **kwargs): coupling = make_image_coupling( event_shape, - coupling_type='checkerboard' if not alternate else 'checkerboard_inverted' + coupling_type='checkerboard' if not alternate else 'checkerboard_inverted', ) transformer = transformer_class(event_shape=coupling.transformed_shape) super().__init__(transformer, coupling, **kwargs) +class NormalizedCheckerboardCoupling(BijectiveComposition): + def __init__(self, event_shape, **kwargs): + layers = [ + ActNorm(event_shape), + CheckerboardCoupling(event_shape, **kwargs), + ] + super().__init__(event_shape, layers) + + +class Invertible1x1ConvolutionalCoupling(ConvolutionalCouplingBijection): + def __init__(self, + event_shape, + alternate: bool = False, + **kwargs): + coupling = make_image_coupling( + event_shape, + coupling_type='channel_wise' if not alternate else 'channel_wise_inverted', + ) + transformer = Invertible1x1ConvolutionTransformer(event_shape=coupling.transformed_shape) + super().__init__(transformer, coupling, **kwargs) + + +class GlowCheckerboardCoupling(BijectiveComposition): + def __init__(self, event_shape, **kwargs): + layers = [ + ActNorm(event_shape), + Invertible1x1ConvolutionalCoupling(event_shape, **kwargs), + CheckerboardCoupling(event_shape, **kwargs) + ] + super().__init__(event_shape, layers) + + class ChannelWiseCoupling(ConvolutionalCouplingBijection): def __init__(self, event_shape, @@ -194,6 +133,25 @@ def __init__(self, super().__init__(transformer, coupling, **kwargs) +class NormalizedChannelWiseCoupling(BijectiveComposition): + def __init__(self, event_shape, **kwargs): + layers = [ + ActNorm(event_shape), + ChannelWiseCoupling(event_shape, **kwargs), + ] + super().__init__(event_shape, layers) + + +class GlowChannelWiseCoupling(BijectiveComposition): + def __init__(self, event_shape, **kwargs): + layers = [ + ActNorm(event_shape), + Invertible1x1ConvolutionalCoupling(event_shape), + ChannelWiseCoupling(event_shape, **kwargs) + ] + super().__init__(event_shape, layers) + + class Squeeze(Bijection): """ Squeeze a batch of tensors with shape (*batch_shape, channels, height, width) into shape @@ -219,7 +177,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. (*batch_shape, 4 * channels, height // 2, width // 2). """ batch_shape = get_batch_shape(x, self.event_shape) - log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype) + log_det = torch.zeros(*batch_shape, device=x.device, dtype=x.dtype).to(x) channels, height, width = x.shape[-3:] assert height % 2 == 0 @@ -239,7 +197,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. (*batch_shape, channels, height, width). 
""" batch_shape = get_batch_shape(z, self.transformed_event_shape) - log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype) + log_det = torch.zeros(*batch_shape, device=z.device, dtype=z.dtype).to(z) four_channels, half_height, half_width = z.shape[-3:] assert four_channels % 4 == 0 @@ -247,7 +205,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. height = 2 * half_height channels = four_channels // 4 - out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype) + out = torch.empty(size=(*batch_shape, channels, height, width), device=z.device, dtype=z.dtype).to(z) out[..., ::2, ::2] = z[..., 0:channels, :, :] out[..., ::2, 1::2] = z[..., channels:2 * channels, :, :] out[..., 1::2, ::2] = z[..., 2 * channels:3 * channels, :, :] @@ -255,51 +213,122 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. return out, log_det -class MultiscaleBijection(BijectiveComposition): - """ - Multiscale bijection class. Used for efficient image modeling. Inherits from BijectiveComposition. - """ +class MultiscaleBijection(Bijection): def __init__(self, - input_event_shape, + event_shape: Union[torch.Size, Tuple[int, ...]], transformer_class: Type[TensorTransformer], - n_checkerboard_layers: int = 2, - n_channel_wise_layers: int = 2, - use_squeeze_layer: bool = True, + n_blocks: int, + n_checkerboard_layers: int = 3, + n_channel_wise_layers: int = 3, use_resnet: bool = False, + checkerboard_class: Union[ + Type[CheckerboardCoupling], + Type[NormalizedCheckerboardCoupling], + Type[GlowCheckerboardCoupling] + ] = NormalizedCheckerboardCoupling, + channel_wise_class: Union[ + Type[ChannelWiseCoupling], + Type[NormalizedChannelWiseCoupling], + Type[GlowChannelWiseCoupling] + ] = NormalizedChannelWiseCoupling, + first_layer: bool = True, **kwargs): - """ - MultiscaleBijection constructor. - - :param input_event_shape: shape of event tensor. - :param TensorTransformer transformer_class: type of transformer. - :param int n_checkerboard_layers: number of checkerboard coupling layers. - :param int n_channel_wise_layers: number of channel wise coupling layers. - :param bool use_squeeze_layer: if True, use a squeeze layer. - :param bool use_resnet: if True, use ResNet as the conditioner network. - :param kwargs: keyword arguments for BijectiveComposition superclass constructor. 
- """ - checkerboard_layers = [ - CheckerboardCoupling( - input_event_shape, - transformer_class, + if n_blocks < 1: + raise ValueError + super().__init__(event_shape, **kwargs) + + self.n_blocks = n_blocks + + if first_layer and checkerboard_class == GlowCheckerboardCoupling: + layer_checkerboard_class = NormalizedCheckerboardCoupling # Compatibility with single channel images + else: + layer_checkerboard_class = checkerboard_class + self.checkerboard_layers = nn.ModuleList([ + layer_checkerboard_class( + event_shape, + transformer_class=transformer_class, alternate=i % 2 == 1, conditioner='resnet' if use_resnet else 'convnet' ) - for i in range(n_checkerboard_layers) - ] - squeeze_layer = Squeeze(input_event_shape) - channel_wise_layers = [ - ChannelWiseCoupling( - squeeze_layer.transformed_event_shape, - transformer_class, - alternate=i % 2 == 1, - conditioner='resnet' if use_resnet else 'convnet' + for i in range(n_checkerboard_layers + (0 if n_blocks > 1 else 1)) + ]) + + if self.n_blocks > 1: + self.squeeze = Squeeze(event_shape) + self.channel_wise_layers = nn.ModuleList([ + channel_wise_class( + self.squeeze.transformed_event_shape, + transformer_class=transformer_class, + alternate=i % 2 == 1, + conditioner='resnet' if use_resnet else 'convnet' + ) + for i in range(n_channel_wise_layers) + ]) + + self.alt_squeeze = Squeeze(event_shape, alternate=True) + + small_event_shape = ( + self.alt_squeeze.transformed_event_shape[0] // 2, + *self.alt_squeeze.transformed_event_shape[1:] ) - for i in range(n_channel_wise_layers) - ] - if use_squeeze_layer: - layers = [*checkerboard_layers, squeeze_layer, *channel_wise_layers] - else: - layers = [*checkerboard_layers, *channel_wise_layers] - super().__init__(input_event_shape, layers, **kwargs) - self.transformed_shape = squeeze_layer.transformed_event_shape if use_squeeze_layer else input_event_shape + self.small_bijection = MultiscaleBijection( + event_shape=small_event_shape, + transformer_class=transformer_class, + n_blocks=self.n_blocks - 1, + first_layer=False, + **kwargs + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + log_det = torch.zeros(size=get_batch_shape(x, event_shape=self.event_shape)).to(x) + + # Propagate through checkerboard layers + for layer in self.checkerboard_layers: + x, log_det_layer = layer.forward(x, context=context) + log_det += log_det_layer + + if self.n_blocks > 1: + # Propagate through channel-wise layers + x, _ = self.squeeze.forward(x, context=context) + for layer in self.channel_wise_layers: + x, log_det_layer = layer.forward(x, context=context) + log_det += log_det_layer + x, _ = self.squeeze.inverse(x, context=context) + + # Chunk and apply small bijection + x, _ = self.alt_squeeze.forward(x, context=context) + x_const, x_rest = torch.chunk(x, 2, dim=-3) # channel dimension split (..., c, h, w) + x_rest, log_det_layer = self.small_bijection.forward(x_rest, context=context) + log_det += log_det_layer + x = torch.cat((x_const, x_rest), dim=-3) # channel dimension concatenation + x, _ = self.alt_squeeze.inverse(x, context=context) + + z = x + return z, log_det + + def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + log_det = torch.zeros(size=get_batch_shape(z, event_shape=self.event_shape)).to(z) + + if self.n_blocks > 1: + # Chunk and apply small bijection + z, _ = self.alt_squeeze.forward(z, context=context) + z_const, z_rest = torch.chunk(z, 2, dim=-3) # channel dimension split (..., c, h, w) 
+ z_rest, log_det_layer = self.small_bijection.inverse(z_rest, context=context) + log_det += log_det_layer + z = torch.cat((z_const, z_rest), dim=-3) # channel dimension concatenation + z, _ = self.alt_squeeze.inverse(z, context=context) + + # Propagate through channel-wise layers + z, _ = self.squeeze.forward(z, context=context) + for layer in self.channel_wise_layers[::-1]: + z, log_det_layer = layer.inverse(z, context=context) + log_det += log_det_layer + z, _ = self.squeeze.inverse(z, context=context) + + # Propagate through checkerboard layers + for layer in self.checkerboard_layers[::-1]: + z, log_det_layer = layer.inverse(z, context=context) + log_det += log_det_layer + + x = z + return x, log_det diff --git a/torchflows/neural_networks/convnet.py b/torchflows/bijections/finite/multiscale/conditioning/classic.py similarity index 78% rename from torchflows/neural_networks/convnet.py rename to torchflows/bijections/finite/multiscale/conditioning/classic.py index 58e5596..e5da0f6 100644 --- a/torchflows/neural_networks/convnet.py +++ b/torchflows/bijections/finite/multiscale/conditioning/classic.py @@ -1,9 +1,10 @@ -import math from typing import Tuple import torch import torch.nn as nn +from torchflows.bijections.finite.autoregressive.conditioning.transforms import TensorConditionerTransform + class ConvModifier(nn.Module): """ @@ -57,7 +58,10 @@ def __init__(self, in_channels, out_channels, input_height, input_width, use_poo def forward(self, x): return self.bn(self.pool(torch.relu(self.conv(x)))) - def __init__(self, input_shape, n_outputs: int, kernels: Tuple[int, ...] = None): + def __init__(self, + input_shape, + n_outputs: int, + kernels: Tuple[int, ...] = None): """ :param input_shape: (channels, height, width) @@ -118,10 +122,25 @@ def forward(self, x): return x -if __name__ == '__main__': - torch.manual_seed(0) - im_shape = (1, 36, 29) - images = torch.randn(size=(11, *im_shape)) - net = ConvNet(input_shape=im_shape, n_outputs=77) - out = net(images) - print(f'{out.shape = }') +class ConvNetConditioner(TensorConditionerTransform): + def __init__(self, + input_event_shape: torch.Size, + parameter_shape: torch.Size, + kernels: Tuple[int, ...] = None, + **kwargs): + super().__init__( + input_event_shape=input_event_shape, + context_shape=None, + parameter_shape=parameter_shape, + output_lower_bound=-2.0, + output_upper_bound=2.0, + **kwargs + ) + self.network = ConvNet( + input_shape=input_event_shape, + n_outputs=self.n_transformer_parameters, + kernels=kernels + ) + + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + return self.network(x) diff --git a/torchflows/bijections/finite/multiscale/conditioning/resnet.py b/torchflows/bijections/finite/multiscale/conditioning/resnet.py new file mode 100644 index 0000000..d2e735b --- /dev/null +++ b/torchflows/bijections/finite/multiscale/conditioning/resnet.py @@ -0,0 +1,147 @@ +from typing import Tuple + +from torchflows.bijections.finite.autoregressive.conditioning.transforms import TensorConditionerTransform +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchflows.bijections.finite.multiscale.conditioning.classic import ConvModifier + + +class BasicResidualBlock(nn.Module): + """ + Basic residual block. Keeps image height and width the same. 
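The new ResNet conditioner is assembled from the blocks defined below: basic residual blocks preserve the feature map size, while bottleneck blocks double the channels and halve height and width. A shape check under the module path introduced by this diff:

```python
import torch
from torchflows.bijections.finite.multiscale.conditioning.resnet import (
    BasicResidualBlock, BottleneckBlock
)

x = torch.randn(7, 4, 32, 32)

block = BasicResidualBlock(in_channels=4, hidden_channels=16)
assert block(x).shape == x.shape              # height and width preserved

down = BottleneckBlock(in_channels=4)
assert down(x).shape == (7, 8, 16, 16)        # channels doubled, spatial halved
```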
+ """ + + def __init__(self, in_channels, hidden_channels): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(hidden_channels) + self.conv2 = nn.Conv2d(hidden_channels, in_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(in_channels) + + def h(self, x): + y = F.relu(self.bn1(self.conv1(x))) + z = F.relu(self.bn2(self.conv2(y))) + return z + + def forward(self, x): + return x + self.h(x) + + +class BasicResidualBlockGroup(nn.Module): + def __init__(self, in_channels: int, n_blocks: int, hidden_channels: int = 16): + super().__init__() + self.blocks = nn.ModuleList([BasicResidualBlock(in_channels, hidden_channels) for _ in range(n_blocks)]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +class BottleneckBlock(nn.Module): + """ + Doubles the number of channels, halves height and width. + """ + + def __init__(self, in_channels: int): + super().__init__() + self.conv2d = nn.Conv2d( + in_channels, + in_channels * 2, + kernel_size=3, + stride=2, + padding=1 + ) + + def forward(self, x): + return self.conv2d(x) + + +class ResNet(nn.Module): + """ + ResNet class. + """ + + def __init__(self, + c, + h, + w, + hidden_size: int = 100, + n_outputs: int = 10, + n_blocks: Tuple[int, int, int] = (1, 1, 1)): + """ + + :param c: number of input image channels. + :param h: input image height. + :param w: input image width. + :param hidden_size: number of hidden units at the last linear layer. + :param n_outputs: number of outputs. + """ + super(ResNet, self).__init__() + self.modifier = ConvModifier((c, h, w), c_target=4, h_target=32, w_target=32) # to `(4, 32, 32)` + + self.stage1 = BasicResidualBlockGroup(4, n_blocks=n_blocks[0]) # (4, 32, 32) + self.down1 = BottleneckBlock(4) # (8, 16, 16) + + self.stage2 = BasicResidualBlockGroup(8, n_blocks=n_blocks[1]) # (8, 16, 16) + self.down2 = BottleneckBlock(8) # (16, 8, 8) + + self.stage3 = BasicResidualBlockGroup(16, n_blocks=n_blocks[2]) # (16, 8, 8) + self.down3 = BottleneckBlock(16) # (32, 4, 4), note: 32 * 4 * 4 = 512 (for linear layer) + + self.linear1 = nn.Linear(512, hidden_size) # 32 * 4 * 4 = 512 + self.linear2 = nn.Linear(hidden_size, n_outputs) + + def forward(self, x): + """ + :param x: tensor with shape (*b, channels, height, width). 
+ :return: + """ + batch_shape = x.shape[:-3] + + out = self.modifier(x) + + out = self.stage1(out) + out = self.down1(out) + + out = self.stage2(out) + out = self.down2(out) + + out = self.stage3(out) + out = self.down3(out) + + out = self.linear1(out.view(*batch_shape, 512)) + out = F.leaky_relu(out) + out = self.linear2(out) + return out + + +class ResNetConditioner(TensorConditionerTransform): + def __init__(self, + input_event_shape: torch.Size, + parameter_shape: torch.Size, + **kwargs): + super().__init__( + input_event_shape=input_event_shape, + context_shape=None, + parameter_shape=parameter_shape, + output_lower_bound=-2.0, + output_upper_bound=2.0, + **kwargs + ) + self.network = ResNet( + *input_event_shape, + n_outputs=self.n_transformer_parameters + ) + + def predict_theta_flat(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + return self.network(x) + +if __name__ == '__main__': + torch.manual_seed(0) + x = torch.randn((15, 3, 77, 13)) + rn = ResNet(3, 77, 13, n_outputs=7) + y = rn(x) + print(y.shape) \ No newline at end of file diff --git a/torchflows/bijections/finite/multiscale/coupling.py b/torchflows/bijections/finite/multiscale/coupling.py index 7087b0f..785ff2e 100644 --- a/torchflows/bijections/finite/multiscale/coupling.py +++ b/torchflows/bijections/finite/multiscale/coupling.py @@ -8,7 +8,7 @@ class Checkerboard(Coupling): Checkerboard coupling for image data. """ - def __init__(self, event_shape, invert: bool = False): + def __init__(self, event_shape, invert: bool = False, **kwargs): """ :param event_shape: image shape with the form (n_channels, height, width). Note: width and height must be equal and a power of two. @@ -43,6 +43,9 @@ def __init__(self, event_shape, invert: bool = False): :param invert: invert the checkerboard mask. 
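The channel-wise half split built just below keeps the first half of the channels as the source and the second half as the target (flipped by `invert=True`), which is why the new guard requires at least two channels. The mask in miniature:

```python
import torch

n_channels, height, width = 4, 8, 8
mask = (torch.arange(n_channels) < n_channels // 2)[:, None, None].repeat(1, height, width)
print(mask.shape)     # torch.Size([4, 8, 8])
print(mask[:, 0, 0])  # tensor([ True,  True, False, False])
```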
""" n_channels, height, width = event_shape + if n_channels <= 1: + raise ValueError("Number of channels must be at least 2") + mask = torch.as_tensor(torch.arange(start=0, end=n_channels) < (n_channels // 2)) mask = mask[:, None, None].repeat(1, height, width) # (channels, height, width) if invert: diff --git a/torchflows/bijections/finite/multiscale/layers.py b/torchflows/bijections/finite/multiscale/layers.py index 99c170f..3c2fcb6 100644 --- a/torchflows/bijections/finite/multiscale/layers.py +++ b/torchflows/bijections/finite/multiscale/layers.py @@ -3,5 +3,5 @@ class MultiscaleAffineCoupling(MultiscaleBijection): - def __init__(self, input_event_shape, **kwargs): - super().__init__(input_event_shape, transformer_class=Affine, **kwargs) + def __init__(self, event_shape, **kwargs): + super().__init__(event_shape, transformer_class=Affine, **kwargs) diff --git a/torchflows/bijections/finite/residual/architectures.py b/torchflows/bijections/finite/residual/architectures.py index d44194d..d14e392 100644 --- a/torchflows/bijections/finite/residual/architectures.py +++ b/torchflows/bijections/finite/residual/architectures.py @@ -3,11 +3,11 @@ import torch from torchflows.bijections.base import BijectiveComposition -from torchflows.bijections.finite.autoregressive.transformers.linear.affine import Affine +from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine from torchflows.bijections.finite.residual.base import ResidualComposition from torchflows.bijections.finite.residual.iterative import InvertibleResNetBlock, ResFlowBlock from torchflows.bijections.finite.residual.proximal import ProximalResFlowBlock -from torchflows.bijections.finite.residual.planar import Planar +from torchflows.bijections.finite.residual.planar import Planar, InversePlanar from torchflows.bijections.finite.residual.radial import Radial from torchflows.bijections.finite.residual.sylvester import Sylvester @@ -17,7 +17,8 @@ class InvertibleResNet(ResidualComposition): Reference: Behrmann et al. "Invertible Residual Networks" (2019); https://arxiv.org/abs/1811.00995. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs): + + def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): blocks = [ InvertibleResNetBlock(event_shape=event_shape, context_shape=context_shape, **kwargs) for _ in range(n_layers) @@ -30,7 +31,8 @@ class ResFlow(ResidualComposition): Reference: Chen et al. "Residual Flows for Invertible Generative Modeling" (2020); https://arxiv.org/abs/1906.02735. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs): + + def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): blocks = [ ResFlowBlock(event_shape=event_shape, context_shape=context_shape, **kwargs) for _ in range(n_layers) @@ -43,7 +45,8 @@ class ProximalResFlow(ResidualComposition): Reference: Hertrich "Proximal Residual Flows for Bayesian Inverse Problems" (2022); https://arxiv.org/abs/2211.17158. """ - def __init__(self, event_shape, context_shape=None, n_layers: int = 16, **kwargs): + + def __init__(self, event_shape, context_shape=None, n_layers: int = 2, **kwargs): blocks = [ ProximalResFlowBlock(event_shape=event_shape, context_shape=context_shape, gamma=0.01, **kwargs) for _ in range(n_layers) @@ -58,13 +61,17 @@ class PlanarFlow(BijectiveComposition): Reference: Rezende and Mohamed "Variational Inference with Normalizing Flows" (2016); https://arxiv.org/abs/1505.05770. 
""" - def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2): + + def __init__(self, + event_shape: Union[torch.Size, Tuple[int, ...]], + n_layers: int = 2, + inverse: bool = True): if n_layers < 1: raise ValueError(f"Flow needs at least one layer, but got {n_layers}") super().__init__(event_shape, [ - Affine(event_shape), - *[Planar(event_shape) for _ in range(n_layers)], - Affine(event_shape) + ElementwiseAffine(event_shape), + *[(InversePlanar if inverse else Planar)(event_shape) for _ in range(n_layers)], + ElementwiseAffine(event_shape) ]) @@ -75,13 +82,14 @@ class RadialFlow(BijectiveComposition): Reference: Rezende and Mohamed "Variational Inference with Normalizing Flows" (2016); https://arxiv.org/abs/1505.05770. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2): if n_layers < 1: raise ValueError(f"Flow needs at least one layer, but got {n_layers}") super().__init__(event_shape, [ - Affine(event_shape), + ElementwiseAffine(event_shape), *[Radial(event_shape) for _ in range(n_layers)], - Affine(event_shape) + ElementwiseAffine(event_shape) ]) @@ -92,11 +100,12 @@ class SylvesterFlow(BijectiveComposition): Reference: Van den Berg et al. "Sylvester Normalizing Flows for Variational Inference" (2019); https://arxiv.org/abs/1803.05649. """ + def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_layers: int = 2, **kwargs): if n_layers < 1: raise ValueError(f"Flow needs at least one layer, but got {n_layers}") super().__init__(event_shape, [ - Affine(event_shape), + ElementwiseAffine(event_shape), *[Sylvester(event_shape, **kwargs) for _ in range(n_layers)], - Affine(event_shape) + ElementwiseAffine(event_shape) ]) diff --git a/torchflows/bijections/finite/residual/base.py b/torchflows/bijections/finite/residual/base.py index 2598a62..55b0e61 100644 --- a/torchflows/bijections/finite/residual/base.py +++ b/torchflows/bijections/finite/residual/base.py @@ -4,7 +4,7 @@ from torchflows.bijections.finite.autoregressive.layers import ElementwiseAffine from torchflows.bijections.base import Bijection, BijectiveComposition -from torchflows.utils import get_batch_shape, unflatten_event, flatten_event +from torchflows.utils import get_batch_shape, unflatten_event, flatten_event, flatten_batch, unflatten_batch class ResidualBijection(Bijection): @@ -26,14 +26,18 @@ def forward(self, context: torch.Tensor = None, skip_log_det: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: batch_shape = get_batch_shape(x, self.event_shape) - z = x + unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape) + x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape) + g_flat = self.g(x_flat) + g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape) + + z = x + g if skip_log_det: log_det = torch.full(size=batch_shape, fill_value=torch.nan) else: - x_flat = flatten_event(x, self.event_shape).clone() + x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape) x_flat.requires_grad_(True) - log_det = -self.log_det(x_flat, training=self.training) + log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) return z, log_det @@ -45,14 +49,18 @@ def inverse(self, batch_shape = get_batch_shape(z, self.event_shape) x = z for _ in range(n_iterations): - x = z - unflatten_event(self.g(flatten_event(x, self.event_shape)), self.event_shape) + x_flat = flatten_batch(flatten_event(x, self.event_shape), batch_shape) + g_flat = 
self.g(x_flat) + g = unflatten_event(unflatten_batch(g_flat, batch_shape), self.event_shape) + + x = z - g if skip_log_det: log_det = torch.full(size=batch_shape, fill_value=torch.nan) else: - x_flat = flatten_event(x, self.event_shape).clone() + x_flat = flatten_batch(flatten_event(x, self.event_shape).clone(), batch_shape) x_flat.requires_grad_(True) - log_det = -self.log_det(x_flat, training=self.training) + log_det = -unflatten_batch(self.log_det(x_flat, training=self.training), batch_shape) return x, log_det diff --git a/torchflows/bijections/finite/residual/planar.py b/torchflows/bijections/finite/residual/planar.py index b6d59d3..1edc957 100644 --- a/torchflows/bijections/finite/residual/planar.py +++ b/torchflows/bijections/finite/residual/planar.py @@ -41,10 +41,10 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. z = z.view(*batch_shape, self.n_dim) # x = z + u * self.h(w.T @ z + self.b) - x = z + u * self.h(torch.einsum('...i,...i', w, z) + self.b) + x = z + u * self.h(torch.einsum('...i,...i', w, z) + self.b)[..., None] # phi = self.h_deriv(w.T @ z + self.b) * w - phi = self.h_deriv(torch.einsum('...i,...i', w, z) + self.b) * w + phi = w * self.h_deriv(torch.einsum('...i,...i', w, z) + self.b)[..., None] # log_det = torch.log(torch.abs(1 + u.T @ phi)) log_det = torch.log(torch.abs(1 + torch.einsum('...i,...i', u, phi))) diff --git a/torchflows/bijections/finite/residual/proximal.py b/torchflows/bijections/finite/residual/proximal.py index 916081a..7b71d44 100644 --- a/torchflows/bijections/finite/residual/proximal.py +++ b/torchflows/bijections/finite/residual/proximal.py @@ -70,8 +70,8 @@ def __init__(self, event_size: int, hidden_size: int, act: ProximityOperator): # Initialize t_tilde close to identity divisor = max(self.event_size ** 2, 100) - self.b = nn.Parameter(torch.randn(self.hidden_size) / divisor) - self.delta_t_tilde = nn.Parameter(torch.randn(self.hidden_size, self.event_size) / divisor) + self.b = nn.Parameter(torch.randn(size=(self.hidden_size,)) / divisor) + self.delta_t_tilde = nn.Parameter(torch.randn(size=(self.hidden_size, self.event_size)) / divisor) self.act = act @property @@ -109,7 +109,7 @@ def __init__(self, event_size: int, n_layers: int = 1, hidden_size: int = None, if act is None: act = TanH() if hidden_size is None: - hidden_size = max(math.log(event_size), 4) + hidden_size = int(max(math.log(event_size), 4)) super().__init__(*[PNNBlock(event_size, hidden_size, act) for _ in range(n_layers)]) self.n_layers = n_layers self.act = act diff --git a/torchflows/bijections/finite/residual/radial.py b/torchflows/bijections/finite/residual/radial.py index 736b320..385e291 100644 --- a/torchflows/bijections/finite/residual/radial.py +++ b/torchflows/bijections/finite/residual/radial.py @@ -30,7 +30,7 @@ def h(self, z): def h_deriv(self, z): batch_shape = z.shape[:-1] z0 = self.z0.view(*([1] * len(batch_shape)), *self.z0.shape) - sign = (-1.0) ** torch.where(z - z0 < 0)[0] + sign = (-1.0) ** torch.less(z, z0).float() return -(self.h(z) ** 2) * sign * z def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: @@ -54,7 +54,7 @@ def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch. 
         log_det = torch.abs(torch.add(
             (self.n_dim - 1) * torch.log1p(beta_times_h_val),
             torch.log(1 + beta_times_h_val + self.h_deriv(z) * r)
-        ))
+        )).sum(dim=-1)
         x = x.view(*batch_shape, *self.event_shape)
         return x, log_det
diff --git a/torchflows/bijections/finite/residual/sylvester.py b/torchflows/bijections/finite/residual/sylvester.py
index 24ba10e..5208448 100644
--- a/torchflows/bijections/finite/residual/sylvester.py
+++ b/torchflows/bijections/finite/residual/sylvester.py
@@ -18,6 +18,8 @@ def __init__(self,
 
         if m is None:
             m = self.n_dim // 2
+        if m > self.n_dim:
+            raise ValueError(f"Parameter m must be at most the event size {self.n_dim}, but got {m}")
         self.m = m
 
         self.b = nn.Parameter(torch.randn(m))
@@ -29,13 +31,13 @@
     @property
     def w(self):
         r_tilde = self.r_tilde.mat()
-        q = self.q.mat()
+        q = self.q.mat()[:, :self.m]
         return torch.einsum('...ij,...kj->...ik', r_tilde, q)
 
     @property
     def u(self):
         r = self.r.mat()
-        q = self.q.mat()
+        q = self.q.mat()[:, :self.m]
         return torch.einsum('...ij,...jk->...ik', q, r)
 
     def h(self, x):
@@ -49,21 +51,21 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.
 
     def inverse(self, z: torch.Tensor, context: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_shape = get_batch_shape(z, self.event_shape)
+        z_flat = torch.flatten(z, start_dim=len(batch_shape))
 
         u = self.u.view(*([1] * len(batch_shape)), *self.u.shape)
         w = self.w.view(*([1] * len(batch_shape)), *self.w.shape)
         b = self.b.view(*([1] * len(batch_shape)), *self.b.shape)
 
-        wzpb = torch.einsum('...ij,...j->...i', w, z) + b  # (..., m)
+        wzpb = torch.einsum('...ij,...j->...i', w, z_flat) + b  # (..., m)
 
-        z = z.view(*batch_shape, self.n_dim)
-        x = z + torch.einsum(
+        x = z_flat + torch.einsum(
             '...ij,...j->...i',
             u,
             self.h(wzpb)
         )
 
         wu = torch.einsum('...ij,...jk->...ik', w, u)  # (..., m, m)
-        diag = torch.zeros(size=(batch_shape, self.m, self.m))
+        diag = torch.zeros(size=(*batch_shape, self.m, self.m))
         diag[..., range(self.m), range(self.m)] = self.h_deriv(wzpb)  # (..., m, m)
 
         _, log_det = torch.linalg.slogdet(torch.eye(self.m) + torch.einsum('...ij,...jk->...ik', diag, wu))
diff --git a/torchflows/flows.py b/torchflows/flows.py
index 8cc8362..22bde3c 100644
--- a/torchflows/flows.py
+++ b/torchflows/flows.py
@@ -112,8 +112,8 @@ def fit(self,
         if batch_size is None:
             batch_size = len(x_train)
         elif isinstance(batch_size, str) and batch_size == "adaptive":
-            min_batch_size = 32
-            max_batch_size = 4096
+            min_batch_size = max(32, min(1024, len(x_train) // 100))
+            max_batch_size = min(4096, len(x_train) // 10)
             batch_size_adaptation_interval = 10  # double the batch size every 10 epochs
             adaptive_batch_size = True
             batch_size = min_batch_size
@@ -156,11 +156,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
 
             return batch_loss
 
-        iterator = tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress)
         optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
         val_loss = None
 
-        for epoch in iterator:
+        for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting NF', disable=not show_progress)):
             if (
                     adaptive_batch_size
                     and epoch % batch_size_adaptation_interval == batch_size_adaptation_interval - 1
@@ -195,20 +194,24 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
             for train_batch in train_loader:
                 optimizer.zero_grad()
                 train_loss = compute_batch_loss(train_batch, reduction=torch.mean)
+                if not torch.isfinite(train_loss):
+                    raise ValueError("Flow training diverged")
                 train_loss += self.regularization()
+                if not torch.isfinite(train_loss):
+                    raise ValueError("Flow training diverged")
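+                # The loss is checked for finiteness twice: right after the NLL
+                # computation and again after adding the regularization term, so a
+                # divergence can be traced to either source.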
train_loss.backward() optimizer.step() if show_progress: if val_loss is None: - iterator.set_postfix_str(f'Training loss (batch): {train_loss:.4f}') + pbar.set_postfix_str(f'Training loss (batch): {train_loss:.4f}') elif early_stopping: - iterator.set_postfix_str( + pbar.set_postfix_str( f'Training loss (batch): {train_loss:.4f}, ' f'Validation loss: {val_loss:.4f} [best: {best_val_loss:.4f} @ {best_epoch}]' ) else: - iterator.set_postfix_str( + pbar.set_postfix_str( f'Training loss (batch): {train_loss:.4f}, ' f'Validation loss: {val_loss:.4f}' ) @@ -252,7 +255,8 @@ def variational_fit(self, early_stopping: bool = False, early_stopping_threshold: int = 50, keep_best_weights: bool = True, - show_progress: bool = False): + show_progress: bool = False, + check_for_divergences: bool = False): """Train the normalizing flow to fit a target log probability. Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. @@ -272,31 +276,63 @@ def variational_fit(self, self.train() + flow_training_diverged = False optimizer = torch.optim.AdamW(self.parameters(), lr=lr) best_loss = torch.inf best_epoch = 0 + initial_weights = deepcopy(self.state_dict()) best_weights = deepcopy(self.state_dict()) + n_divergences = 0 for epoch in (pbar := tqdm(range(n_epochs), desc='Fitting with SVI', disable=not show_progress)): + if check_for_divergences and not all([torch.isfinite(p).all() for p in self.parameters()]): + flow_training_diverged = True + print('Flow training diverged') + print('Reverting to initial weights') + break + optimizer.zero_grad() flow_x, flow_log_prob = self.sample(n_samples, return_log_prob=True) - loss = -torch.mean(target_log_prob(flow_x) + flow_log_prob) + target_log_prob_value = target_log_prob(flow_x) + loss = -torch.mean(target_log_prob_value + flow_log_prob) loss += self.regularization() - loss.backward() - optimizer.step() - if loss < best_loss: - best_loss = loss - best_epoch = epoch - if keep_best_weights: - best_weights = deepcopy(self.state_dict()) + epoch_diverged = False + if check_for_divergences: + if not torch.isfinite(loss): + epoch_diverged = True + if torch.max(torch.abs(flow_x)) > 1e8: + epoch_diverged = True + elif torch.max(torch.abs(flow_log_prob)) > 1e6: + epoch_diverged = True + elif torch.any(~torch.isfinite(flow_x)): + epoch_diverged = True + elif torch.any(~torch.isfinite(flow_log_prob)): + epoch_diverged = True + n_divergences += epoch_diverged + + if not epoch_diverged: + loss.backward() + optimizer.step() + if loss < best_loss: + best_loss = loss + best_epoch = epoch + if keep_best_weights: + best_weights = deepcopy(self.state_dict()) + else: + loss = torch.nan - pbar.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}]') + pbar.set_postfix_str(f'Loss: {loss:.4f} [best: {best_loss:.4f} @ {best_epoch}], ' + f'divergences: {n_divergences}, ' + f'flow log_prob: {flow_log_prob.mean():.2f}, ' + f'target log_prob: {target_log_prob_value.mean():.2f}') if epoch - best_epoch > early_stopping_threshold and early_stopping: break - if keep_best_weights: + if flow_training_diverged: + self.load_state_dict(initial_weights) + elif keep_best_weights: self.load_state_dict(best_weights) self.eval() @@ -376,10 +412,10 @@ def sample(self, if no_grad: z = z.detach() with torch.no_grad(): - x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.event_shape), context=context) 
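+            # The branch below runs the same inverse pass, but with gradient tracking enabled.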
else: - x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.transformed_shape), + x, log_det = self.bijection.inverse(z.view(*sample_shape, *self.bijection.event_shape), context=context) x = x.to(self.get_device()) diff --git a/torchflows/neural_networks/resnet.py b/torchflows/neural_networks/resnet.py deleted file mode 100644 index 435c3a9..0000000 --- a/torchflows/neural_networks/resnet.py +++ /dev/null @@ -1,167 +0,0 @@ -# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d( - in_planes, - planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False - ) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d( - planes, - planes, - kernel_size=3, - stride=1, - padding=1, - bias=False - ) - self.bn2 = nn.BatchNorm2d(planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion * planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion * planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_planes, planes, stride=1): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, - stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, self.expansion * - planes, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(self.expansion * planes) - - self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, self.expansion * planes, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(self.expansion * planes) - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = F.relu(self.bn2(self.conv2(out))) - out = self.bn3(self.conv3(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class ResNet(nn.Module): - """ - ResNet class. - """ - - def __init__(self, c, h, w, block, num_blocks, n_hidden=100, n_outputs=10): - """ - - :param c: number of input image channels. - :param h: input image height. - :param w: input image width. - :param block: block class for ResNet. - :param num_blocks: List of block numbers for each of the four layers. - :param n_hidden: number of hidden units at the last linear layer. - :param n_outputs: number of outputs. 
- """ - if h % 4 != 0: - raise ValueError('Image height must be divisible by 4.') - if w % 4 != 0: - raise ValueError('Image width must be divisible by 4.') - - super(ResNet, self).__init__() - self.in_planes = 64 - - self.conv1 = nn.Conv2d(c, 64, kernel_size=3, - stride=1, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) - self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1) - self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=1) - self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=1) - self.linear1 = nn.Linear(512 * h * w // 16, n_hidden) - self.linear2 = nn.Linear(n_hidden, n_outputs) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return nn.Sequential(*layers) - - def forward(self, x): - """ - :param x: tensor with shape (*b, channels, height, width). Height and width must be equal. - :return: - """ - out = F.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = self.layer3(out) - out = self.layer4(out) - out = F.avg_pool2d(out, 4) - out = out.flatten(start_dim=1, end_dim=-1) - out = F.relu(self.linear1(out)) - out = self.linear2(out) - return out - - -def make_resnet18(image_shape, n_outputs): - return ResNet(*image_shape, BasicBlock, num_blocks=[2, 2, 2, 2], n_outputs=n_outputs) - - -def make_resnet34(image_shape, n_outputs): - return ResNet(*image_shape, BasicBlock, num_blocks=[3, 4, 6, 3], n_outputs=n_outputs) - - -def make_resnet50(image_shape, n_outputs): - # TODO fix error regarding image shape - return ResNet(*image_shape, Bottleneck, num_blocks=[3, 4, 6, 3], n_outputs=n_outputs) - - -def make_resnet101(image_shape, n_outputs): - # TODO fix error regarding image shape - return ResNet(*image_shape, Bottleneck, num_blocks=[3, 4, 23, 3], n_outputs=n_outputs) - - -def make_resnet152(image_shape, n_outputs): - # TODO fix error regarding image shape - return ResNet(*image_shape, Bottleneck, num_blocks=[3, 8, 36, 3], n_outputs=n_outputs) - - -if __name__ == '__main__': - n_images = 2 - event_shape = (5, 8 * 7, 4 * 7) - - net = make_resnet18(event_shape, 15) - y = net(torch.randn(n_images, *event_shape)) - print(y.size())