diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 768e729c7..7ddf86478 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: jobs: lint_and_typecheck: - if: ${{ github.event.name == 'push' || github.event.label.name == 'run-ci' }} + if: ${{ github.event_name == 'push' || github.event.label.name == 'run-ci' }} runs-on: ubuntu-latest steps: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 000000000..6b8d160ac --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,20 @@ +name: Deploy docs to GitHub Pages + +on: + push: + branches: + - main + +jobs: + deploy: + name: Deploy docs + runs-on: ubuntu-latest + steps: + - name: Checkout main + uses: actions/checkout@v2 + + - name: Deploy MkDocs + uses: mhausenblas/mkdocs-deploy-gh-pages@master + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + REQUIREMENTS: ./requirements.docs.txt diff --git a/README.md b/README.md index 620a37d72..dad2e4cf7 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ from PIL import Image from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter -from refiners.fluxion.utils import manual_seed, image_to_tensor, load_from_safetensors +from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors # Load inputs init_image = Image.open("dropy_logo.png") @@ -122,22 +122,13 @@ t2i_adapter.set_scale(0.8) sdxl.set_num_inference_steps(50) sdxl.set_self_attention_guidance(enable=True, scale=0.75) -with torch.no_grad(): +with no_grad(): # Note: default text prompts for IP-Adapter clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality" ) clip_image_embedding = 
ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt)) - - negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) - negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) - - clip_text_embedding = torch.cat( - ( - torch.cat([negative_text_embedding, negative_image_embedding], dim=1), - torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), - ) - ) + ip_adapter.set_clip_image_embedding(clip_image_embedding) time_ids = sdxl.default_time_ids condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype) diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 000000000..9a1ce1c92 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,3 @@ +# Refiners - Docs + +WIP diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000..ced975126 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,4 @@ +site_name: Refiners + +theme: + name: material diff --git a/notebooks/basics.ipynb b/notebooks/basics.ipynb new file mode 100644 index 000000000..04aca2524 --- /dev/null +++ b/notebooks/basics.ipynb @@ -0,0 +1,1574 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Refiners Demo\n", + "\n", + "This notebook aims to demonstrate the basics of using the [Refiners](https://github.com/finegrain-ai/refiners) micro-framework.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# to run you need to have `Refiners` installed (uncomment the line below)\n", + "# %pip install git+https://github.com/finegrain-ai/refiners.git" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from refiners.fluxion import layers as fl, manual_seed\n", + "from torch import nn\n", + "\n", + "torch.set_grad_enabled(mode=False)\n", + "manual_seed(82570858)" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "### Basics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The core idea of Refiners is to improve on the `Sequential` API of PyTorch.\n", + "\n", + "A `Sequential` is defined by:\n", + "\n", + "`Sequential([layer1, layer2, layer3])(x) = layer3(layer2(layer1(x)))`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Sequential(\n", + " (0): Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + " (1): ReLU()\n", + " (2): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + " (3): ReLU()\n", + " (4): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + " (5): ReLU()\n", + " (6): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Native PyTorch sequential\n", + "sequential = nn.Sequential(\n", + " fl.Conv2d(3, 32, 3, padding=1),\n", + " nn.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " nn.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " nn.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + ")\n", + "\n", + "sequential" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(CHAIN)\n", + " ├── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #1\n", + " ├── ReLU() #1\n", + " ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #2\n", + " ├── ReLU() #2\n", + " ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #3\n", + " ├── ReLU() #3\n", + " └── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #4" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], 
+ "source": [ + "# Same as above, but with a Fluxion Chain\n", + "chain = fl.Chain(\n", + " fl.Conv2d(3, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + ")\n", + "\n", + "chain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note here that the keys of the Chain are the names of the layers, whereas in PyTorch Sequential API, the keys are the indices of the layers.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sequential keys:\n", + "0\n", + "1\n", + "2\n", + "3\n", + "4\n", + "5\n", + "6\n", + "\n", + "Chain keys:\n", + "Conv2d_1\n", + "ReLU_1\n", + "Conv2d_2\n", + "ReLU_2\n", + "Conv2d_3\n", + "ReLU_3\n", + "Conv2d_4\n" + ] + } + ], + "source": [ + "print(\"Sequential keys:\")\n", + "for key, _ in sequential.named_children():\n", + " print(key)\n", + "\n", + "print(\"\\nChain keys:\")\n", + "for key, _ in chain.named_children():\n", + " print(key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This choice is made because when a model is simple, it is easy to remember the indices of the layers, but when a model is complex, it is hard to remember the indices of the layers.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We also improved on the Errors to showcase exactly where the error is coming from.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 4, 32, 32)\n", + "# uncomment to run\n", + "# sequential(x)" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image.png](attachment:image.png)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# uncomment to run\n", + "# chain(x)" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![image.png](attachment:image.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sequential is excellent for building basic and straightforward models, but most models don't have such a simple linear structure. \n", + "\n", + "Let's say you want to add a skip connection to the ConvNet." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 32, 32, 32])\n" + ] + }, + { + "data": { + "text/plain": [ + "ConvNet(\n", + " (sequential): Sequential(\n", + " (0): Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + " (1): ReLU()\n", + " (2): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + " (3): ReLU()\n", + " (4): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + " (5): ReLU()\n", + " (6): Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + " )\n", + " (skip): Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ConvNet with a residual connection in PyTorch\n", + "\n", + "\n", + "class ConvNet(nn.Module):\n", + " def __init__(self) -> None:\n", + " super().__init__()\n", + " self.sequential = nn.Sequential(\n", + " fl.Conv2d(3, 32, 3, padding=1),\n", + " nn.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " nn.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " nn.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " )\n", + " self.skip = fl.Conv2d(3, 32, 3, padding=1)\n", + "\n", + " def 
forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " return self.sequential(x) + self.skip(x)\n", + "\n", + "\n", + "convnet = ConvNet()\n", + "x = torch.randn(1, 3, 32, 32)\n", + "print(convnet(x).shape)\n", + "convnet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `repr` of this PyTorch model is no longer declarative: you cannot tell how the model works from it.\n", + "\n", + "You can use Refiners' predefined `Chain` subclasses to handle such cases and build more complex models. \n", + "\n", + "Let's start with the `Sum` class.\n", + "\n", + "`fl.Sum([layer1, layer2, layer3])(x) = layer1(x) + layer2(x) + layer3(x)`" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 32, 32, 32])\n" + ] + }, + { + "data": { + "text/plain": [ + "(SUM)\n", + " ├── (CHAIN)\n", + " │ ├── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #1\n", + " │ ├── ReLU() #1\n", + " │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #2\n", + " │ ├── ReLU() #2\n", + " │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #3\n", + " │ ├── ReLU() #3\n", + " │ └── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #4\n", + " └── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ConvNet with a residual connection in Refiners\n", + "convnet = fl.Sum(\n", + " fl.Chain(\n", + " fl.Conv2d(3, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " ),\n", + " fl.Conv2d(3, 32, 3, padding=1),\n", + ")\n", + "\n", + "x = torch.randn(1, 3, 32, 32)\n", +
"print(convnet(x).shape)\n", + "convnet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can subclass the basic `Chain` classes to give them a name, which improves declarativity. The `repr` will still tell you which kind of `Chain` it is and the name of the `Chain`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(SUM) ResidualNet()\n", + " ├── (CHAIN) ConvNet()\n", + " │ ├── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #1\n", + " │ ├── ReLU() #1\n", + " │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #2\n", + " │ ├── ReLU() #2\n", + " │ ├── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #3\n", + " │ ├── ReLU() #3\n", + " │ └── Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=(1, 1)) #4\n", + " └── Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3), padding=(1, 1))" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class ConvNet(fl.Chain):\n", + " def __init__(self) -> None:\n", + " super().__init__(\n", + " fl.Conv2d(3, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " fl.ReLU(),\n", + " fl.Conv2d(32, 32, 3, padding=1),\n", + " )\n", + "\n", + "\n", + "class ResidualNet(fl.Sum):\n", + " def __init__(self) -> None:\n", + " super().__init__(\n", + " ConvNet(),\n", + " fl.Conv2d(3, 32, 3, padding=1), # skip branch receives the same 3-channel input as ConvNet\n", + " )\n", + "\n", + "\n", + "ResidualNet()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here are some examples of `Chain` subclasses:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[-0.4723, -0.2809]]), tensor([[-0.5384, 0.6123, 0.2659, 0.0916]]))" + ] + }, +
"execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run layers in parallel to output a tuple\n", + "par = fl.Parallel(\n", + " fl.Linear(2, 2),\n", + " fl.Linear(2, 4),\n", + ")\n", + "\n", + "x = torch.randn(1, 2)\n", + "par(x) # (Linear_1(x), Linear_2(x))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-1.0949, 0.0749, 0.2607, 0.4013]])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run layers in parallel and then concatenate the outputs\n", + "cat = fl.Concatenate(\n", + " fl.Linear(2, 2),\n", + " fl.Linear(2, 2),\n", + " dim=-1,\n", + ")\n", + "\n", + "x = torch.randn(1, 2)\n", + "cat(x) # Concatenate((Linear_1(x), Linear_2(x)), dim=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-0.1487, -1.1180]])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run sequentially layers and then add the input\n", + "residual = fl.Residual(\n", + " fl.Linear(2, 2),\n", + " fl.Linear(2, 2),\n", + ")\n", + "\n", + "x = torch.randn(1, 2)\n", + "residual(x) # Linear_2(Linear_1(x)) + x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's now build something more complex such as a Vision Transformer. \n", + "\n", + "Let's start with the heart of a transformer layer: the Multi-Head Attention." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 197, 128])\n" + ] + }, + { + "data": { + "text/plain": [ + "(RES) Attention()\n", + " ├── (PAR)\n", + " │ └── Linear(in_features=128, out_features=128) (x3)\n", + " ├── ScaledDotProductAttention(num_heads=8)\n", + " └── Linear(in_features=128, out_features=128)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from refiners.fluxion.layers.attentions import ScaledDotProductAttention\n", + "\n", + "\n", + "class Attention(fl.Residual):\n", + " def __init__(self, dim: int = 128, num_heads: int = 8) -> None:\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " super().__init__(\n", + " fl.Parallel(\n", + " fl.Linear(dim, dim),\n", + " fl.Linear(dim, dim),\n", + " fl.Linear(dim, dim),\n", + " ),\n", + " ScaledDotProductAttention(num_heads=num_heads),\n", + " fl.Linear(dim, dim),\n", + " )\n", + "\n", + "\n", + "x = torch.randn(1, 197, 128)\n", + "attention = Attention()\n", + "print(attention(x).shape)\n", + "attention" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(RES) FeedForward()\n", + " ├── Linear(in_features=128, out_features=512) #1\n", + " ├── SiLU()\n", + " └── Linear(in_features=512, out_features=128) #2" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class FeedForward(fl.Residual):\n", + " def __init__(self, dim: int = 128, inner_dim: int = 512) -> None:\n", + " self.dim = dim\n", + " self.inner_dim = inner_dim\n", + " super().__init__(\n", + " fl.Linear(dim, inner_dim),\n", + " fl.SiLU(),\n", + " fl.Linear(inner_dim, dim),\n", + " )\n", + "\n", + "\n", + "FeedForward()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + 
{ + "data": { + "text/plain": [ + "(CHAIN) TranformerLayer()\n", + " ├── LayerNorm(normalized_shape=(128,)) #1\n", + " ├── (RES) Attention()\n", + " │ ├── (PAR)\n", + " │ │ └── Linear(in_features=128, out_features=128) (x3)\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── Linear(in_features=128, out_features=128)\n", + " ├── LayerNorm(normalized_shape=(128,)) #2\n", + " └── (RES) FeedForward()\n", + " ├── Linear(in_features=128, out_features=512) #1\n", + " ├── SiLU()\n", + " └── Linear(in_features=512, out_features=128) #2" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class TranformerLayer(fl.Chain):\n", + " def __init__(\n", + " self, dim: int = 128, num_heads: int = 8, inner_dim: int = 512\n", + " ) -> None:\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " self.inner_dim = inner_dim\n", + " super().__init__(\n", + " fl.LayerNorm(dim),\n", + " Attention(dim, num_heads),\n", + " fl.LayerNorm(dim),\n", + " FeedForward(dim, inner_dim),\n", + " )\n", + "\n", + "\n", + "TranformerLayer()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 196, 128])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class PatchEncoder(fl.Chain):\n", + " def __init__(\n", + " self, in_channels: int = 3, dim: int = 128, patch_size: int = 16\n", + " ) -> None:\n", + " self.in_channels = in_channels\n", + " self.dim = dim\n", + " self.patch_size = patch_size\n", + " super().__init__(\n", + " fl.Conv2d(\n", + " in_channels=in_channels,\n", + " out_channels=dim,\n", + " kernel_size=patch_size,\n", + " stride=patch_size,\n", + " ),\n", + " fl.Reshape(-1, dim), # Reshape always preserves the batch dimension\n", + " )\n", + "\n", + "\n", + "x = torch.randn(1, 3, 224, 224)\n", + "PatchEncoder()(x).shape" + ] + }, + { + "cell_type": "code", 
+ "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "class PositionalToken(fl.Residual):\n", + " def __init__(self, num_patches: int = 196) -> None:\n", + " self.num_patches = num_patches\n", + " super().__init__(fl.Parameter(num_patches, 128))\n", + "\n", + "\n", + "class ClassToken(fl.Chain):\n", + " def __init__(self, dim: int = 128) -> None:\n", + " self.dim = dim\n", + " super().__init__(fl.Parameter(1, dim))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we have every bit to build a full Vision Transformer." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(CHAIN) ViT()\n", + " ├── (CAT)\n", + " │ ├── (CHAIN) PatchEncoder()\n", + " │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n", + " │ │ └── Reshape(shape=(-1, 128))\n", + " │ └── (CHAIN) ClassToken()\n", + " │ └── Parameter(dims=(1, 128))\n", + " ├── (RES) PositionalToken(num_patches=197)\n", + " │ └── Parameter(dims=(197, 128))\n", + " └── (CHAIN) Transformer()\n", + " └── (CHAIN) TranformerLayer() (x4)\n", + " ├── LayerNorm(normalized_shape=(128,)) #1\n", + " ├── (RES) Attention()\n", + " │ ├── (PAR)\n", + " │ │ └── Linear(in_features=128, out_features=128) (x3)\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── Linear(in_features=128, out_features=128)\n", + " ├── LayerNorm(normalized_shape=(128,)) #2\n", + " └── (RES) FeedForward()\n", + " ├── Linear(in_features=128, out_features=512) #1\n", + " ├── SiLU()\n", + " └── Linear(in_features=512, out_features=128) #2\n", + "torch.Size([1, 197, 128])\n" + ] + } + ], + "source": [ + "class Transformer(fl.Chain):\n", + " pass\n", + "\n", + "\n", + "class ViT(fl.Chain):\n", + " def __init__(\n", + " self,\n", + " dim: int = 128,\n", + " patch_size: int = 16,\n", + " image_size: int = 224,\n", + " num_layers: int = 4,\n", + " ) -> None:\n", + 
" self.dim = dim\n", + " self.patch_size = patch_size\n", + " self.image_size = image_size\n", + " self.num_layers = num_layers\n", + " self.num_patches = (image_size // patch_size) ** 2 + 1\n", + " super().__init__(\n", + " fl.Concatenate(\n", + " PatchEncoder(in_channels=3, dim=dim, patch_size=patch_size),\n", + " ClassToken(dim=dim),\n", + " dim=1,\n", + " ),\n", + " PositionalToken(num_patches=self.num_patches),\n", + " Transformer(TranformerLayer(dim=dim) for _ in range(num_layers)),\n", + " )\n", + "\n", + "\n", + "x = torch.randn(1, 3, 224, 224)\n", + "vit = ViT()\n", + "print(repr(vit))\n", + "print(vit(x).shape)" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Advanced - Context API\n", + "\n", + "This ViT is still rudimentary and linear: we have one input in and one output out. But often, you want to use multiple inputs/modalities in the flow of your model.\n", + "\n", + "Let's take, for example, the `MaskDecoder` of the Segment Anything model by Meta; it's a Transformer that takes as input an image and a prompt and outputs a segmentation mask. You can prompt it with points to guide the segmentation. [Here is a link](https://github.com/finegrain-ai/refiners/blob/main/src/refiners/foundationals/segment_anything/mask_decoder.py) to the complete implementation in Refiners\n", + "\n", + "So the inputs are:\n", + "\n", + " - an image of shape (3, 224, 224)\n", + " - several points of shape (N, 2) \n", + "\n", + "One way to consider the points is to add a `CrossAttention` layer that will attend to the points from the image. Cross attention is a standard `Attention` layer, but the key and value come from a source different from the query. 
In our case, the query is the image, and the key and value are the point embeddings.\n", + "\n", + "![image.png](attachment:image.png)\n", + "\n", + "\n", + "Let's start by building a point encoder (to keep things simple, we assume that all points have the same \"meaning\")." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 5, 128])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class PointEncoder(fl.Chain):\n", + " def __init__(self, dim: int = 128) -> None:\n", + " self.dim = dim\n", + " super().__init__(\n", + " fl.Linear(2, dim),\n", + " fl.SiLU(),\n", + " fl.Linear(dim, dim),\n", + " fl.SiLU(),\n", + " fl.Linear(dim, dim),\n", + " fl.Unsqueeze(0),\n", + " )\n", + "\n", + "\n", + "points = torch.randn(5, 2)\n", + "PointEncoder()(points).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 197, 128])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Chains can handle multiple inputs\n", + "class CrossAttention(fl.Chain):\n", + " def __init__(self, dim: int = 128, num_heads: int = 8) -> None:\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " super().__init__(\n", + " fl.Parallel(\n", + " fl.GetArg(0),\n", + " fl.GetArg(1),\n", + " fl.GetArg(1),\n", + " ),\n", + " fl.Distribute(\n", + " fl.Linear(dim, dim),\n", + " fl.Linear(dim, dim),\n", + " fl.Linear(dim, dim),\n", + " ),\n", + " ScaledDotProductAttention(num_heads=num_heads),\n", + " fl.Linear(dim, dim),\n", + " )\n", + "\n", + "\n", + "points_embedding = torch.randn(1, 5, 128)\n", + "patch_embedding = torch.randn(1, 197, 128)\n", + "CrossAttention()(patch_embedding, points_embedding).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now ideally, I would like to
insert this `CrossAttention` layer in the middle of the `Transformer` like this:\n", + "\n", + "```python\n", + "class TranformerLayer(fl.Chain):\n", + " def __init__(\n", + " self, dim: int = 128, num_heads: int = 8, inner_dim: int = 512\n", + " ) -> None:\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " self.inner_dim = inner_dim\n", + " super().__init__(\n", + " fl.LayerNorm(dim),\n", + " Attention(dim, num_heads),\n", + " fl.LayerNorm(dim),\n", + " CrossAttention(dim, num_heads),\n", + " fl.LayerNorm(dim),\n", + " FeedForward(dim, inner_dim),\n", + " )\n", + "\n", + "```\n", + "\n", + "But how does the `point_embedding` get into the `CrossAttention` layer? That's where the `Context` API comes into play." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain = fl.Chain(\n", + " fl.Linear(2, 2),\n", + " fl.Concatenate(\n", + " fl.Linear(2, 2),\n", + " fl.UseContext(\"embedding\", \"value\"),\n", + " dim=-1,\n", + " ),\n", + " fl.Linear(5, 2),\n", + ")\n", + "\n", + "chain.set_context(\"embedding\", {\"value\": torch.randn(1, 3)})\n", + "print(f\"Current embedding context: {chain.use_context('embedding')}\")\n", + "\n", + "x = torch.randn(1, 2)\n", + "chain(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the context is recursive, so you can access the context of an outer `Chain` from an inner `Chain`.\n", + "\n", + "We can rewrite the `CrossAttention` layer using context instead of passing the `point_embedding` as an argument."
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 197, 128])\n" + ] + }, + { + "data": { + "text/plain": [ + "(CHAIN) PointsCrossAttention()\n", + " ├── (PAR)\n", + " │ ├── Identity()\n", + " │ └── UseContext(context=vit, key=points_embedding) (x2)\n", + " ├── (DISTR)\n", + " │ └── Linear(in_features=128, out_features=128) (x3)\n", + " ├── ScaledDotProductAttention(num_heads=8)\n", + " └── Linear(in_features=128, out_features=128)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class PointsCrossAttention(fl.Chain):\n", + " def __init__(self, dim: int = 128, num_heads: int = 8) -> None:\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " super().__init__(\n", + " fl.Parallel(\n", + " fl.Identity(),\n", + " fl.UseContext(\"vit\", \"points_embedding\"),\n", + " fl.UseContext(\"vit\", \"points_embedding\"),\n", + " ),\n", + " fl.Distribute(\n", + " fl.Linear(dim, dim),\n", + " fl.Linear(dim, dim),\n", + " fl.Linear(dim, dim),\n", + " ),\n", + " ScaledDotProductAttention(num_heads=num_heads),\n", + " fl.Linear(dim, dim),\n", + " )\n", + "\n", + "\n", + "points_cross_attention = PointsCrossAttention()\n", + "\n", + "# If the context is not set, the layer will raise an error\n", + "points_embedding = torch.randn(1, 5, 128)\n", + "points_cross_attention.set_context(\"vit\", {\"points_embedding\": points_embedding})\n", + "\n", + "x = torch.randn(1, 197, 128)\n", + "\n", + "print(points_cross_attention(x).shape)\n", + "points_cross_attention" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's rewrite the `TransformerLayer` using the `PointsCrossAttention` layer.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 
197, 128])\n" + ] + } + ], + "source": [ + "class TranformerLayer(fl.Chain):\n", + " def __init__(\n", + " self, dim: int = 128, num_heads: int = 8, inner_dim: int = 512\n", + " ) -> None:\n", + " self.dim = dim\n", + " self.num_heads = num_heads\n", + " self.inner_dim = inner_dim\n", + " super().__init__(\n", + " fl.LayerNorm(dim),\n", + " Attention(dim, num_heads),\n", + " fl.LayerNorm(dim),\n", + " PointsCrossAttention(dim, num_heads),\n", + " fl.LayerNorm(dim),\n", + " FeedForward(dim, inner_dim),\n", + " )\n", + "\n", + "\n", + "layer = TranformerLayer()\n", + "x = torch.randn(1, 197, 128)\n", + "points_embedding = torch.randn(1, 5, 128)\n", + "layer.set_context(\"vit\", {\"points_embedding\": points_embedding})\n", + "print(layer(x).shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The ViT is still valid as is, but we might want to add the `PointEncoder` directly into the model to not have to deal with multiple models separately. To do that, we can wrap the `PointEncoder` into a `Passthrough` layer that will let the main arguments pass through, but will also add the `point_embedding` to the context using a `SetContext` layer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 197, 128])\n" + ] + }, + { + "data": { + "text/plain": [ + "(CHAIN) ViT()\n", + " ├── (PASS) PointEncoder()\n", + " │ ├── UseContext(context=vit, key=points_tensor)\n", + " │ ├── Linear(in_features=2, out_features=128) #1\n", + " │ ├── SiLU() #1\n", + " │ ├── Linear(in_features=128, out_features=128) #2\n", + " │ ├── SiLU() #2\n", + " │ ├── Linear(in_features=128, out_features=128) #3\n", + " │ ├── Unsqueeze(dim=0)\n", + " │ └── SetContext(context=vit, key=points_embedding)\n", + " ├── (CAT)\n", + " │ ├── (CHAIN) PatchEncoder()\n", + " │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n", + " │ │ └── Reshape(shape=(-1, 128))\n", + " │ └── (CHAIN) ClassToken()\n", + " │ └── Parameter(dims=(1, 128))\n", + " ├── (RES) PositionalToken(num_patches=197)\n", + " │ └── Parameter(dims=(197, 128))\n", + " └── (CHAIN) Transformer()\n", + " └── (CHAIN) TranformerLayer() (x4)\n", + " ├── LayerNorm(normalized_shape=(128,)) #1\n", + " ├── (RES) Attention()\n", + " │ ├── (PAR)\n", + " │ │ └── Linear(in_features=128, out_features=128) (x3)\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── Linear(in_features=128, out_features=128)\n", + " ├── LayerNorm(normalized_shape=(128,)) #2\n", + " ├── (CHAIN) PointsCrossAttention()\n", + " │ ├── (PAR)\n", + " │ │ ├── Identity()\n", + " │ │ └── UseContext(context=vit, key=points_embedding) (x2)\n", + " │ ├── (DISTR)\n", + " │ │ └── Linear(in_features=128, out_features=128) (x3)\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── Linear(in_features=128, out_features=128)\n", + " ├── LayerNorm(normalized_shape=(128,)) #3\n", + " └── (RES) FeedForward()\n", + " ├── Linear(in_features=128, out_features=512) #1\n", + " ├── SiLU()\n", + " └── Linear(in_features=512, out_features=128) #2" + ] + }, + 
"execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class PointEncoder(fl.Passthrough):\n", + " def __init__(self, dim: int = 128) -> None:\n", + " self.dim = dim\n", + " super().__init__(\n", + " fl.UseContext(\"vit\", \"points_tensor\"),\n", + " fl.Linear(2, dim),\n", + " fl.SiLU(),\n", + " fl.Linear(dim, dim),\n", + " fl.SiLU(),\n", + " fl.Linear(dim, dim),\n", + " fl.Unsqueeze(0),\n", + " fl.SetContext(\"vit\", \"points_embedding\"),\n", + " )\n", + "\n", + "\n", + "class ViT(fl.Chain):\n", + " def __init__(\n", + " self,\n", + " dim: int = 128,\n", + " patch_size: int = 16,\n", + " image_size: int = 224,\n", + " num_layers: int = 4,\n", + " ) -> None:\n", + " self.dim = dim\n", + " self.patch_size = patch_size\n", + " self.image_size = image_size\n", + " self.num_layers = num_layers\n", + " self.num_patches = (image_size // patch_size) ** 2 + 1\n", + " super().__init__(\n", + " PointEncoder(dim=dim),\n", + " fl.Concatenate(\n", + " PatchEncoder(in_channels=3, dim=dim, patch_size=patch_size),\n", + " ClassToken(dim=dim),\n", + " dim=1,\n", + " ),\n", + " PositionalToken(num_patches=self.num_patches),\n", + " Transformer(TranformerLayer(dim=dim) for _ in range(num_layers)),\n", + " )\n", + "\n", + "\n", + "vit = ViT()\n", + "x = torch.randn(1, 3, 224, 224)\n", + "points = torch.randn(5, 2)\n", + "vit.set_context(\"vit\", {\"points_tensor\": points})\n", + "print(vit(x).shape)\n", + "vit" + ] + }, + { + "attachments": { + "image.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Adaptation\n", + "\n", + "I think to have a very explicit and declarative model like we showcased here is an interesting property, but the place where it shines the most is when you want to adapt a model to a new task.\n", + "\n", + "Let's demonstrate on a simple example how to create a LoRA adaptation on the ViT without having to rewrite the whole model.\n", + "\n", + "The low-rank 
adaptation technique adds lighter new layers on top of the model. The rank is the inner dimension of the new layers. The outer layer is zero-initialized, so the model's output is the same before training.\n", + "![image.png](attachment:image.png)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 1, 128])\n" + ] + }, + { + "data": { + "text/plain": [ + "(CHAIN) Lora(in_features=128, out_features=128)\n", + " ├── Linear(in_features=128, out_features=16) #1\n", + " └── Linear(in_features=16, out_features=128) #2" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class Lora(fl.Chain):\n", + " def __init__(\n", + " self,\n", + " in_features: int,\n", + " out_features: int,\n", + " rank: int = 16,\n", + " ) -> None:\n", + " self.in_features = in_features\n", + " self.out_features = out_features\n", + " self.rank = rank\n", + " self.scale: float = 1.0\n", + "\n", + " super().__init__(\n", + " fl.Linear(in_features=in_features, out_features=rank, bias=False),\n", + " fl.Linear(in_features=rank, out_features=out_features),\n", + " )\n", + "\n", + " nn.init.normal_(tensor=self.Linear_1.weight, std=1 / self.rank)\n", + " nn.init.zeros_(tensor=self.Linear_2.weight)\n", + "\n", + "\n", + "lora = Lora(128, 128)\n", + "x = torch.randn(1, 1, 128)\n", + "print(lora(x).shape)\n", + "lora" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we want to be able to insert this into any `Linear` layer of the model. To do that, we can use the `Adapter` class."
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Adapter:\n", + "(SUM) LoraAdapter()\n", + " ├── Linear(in_features=128, out_features=128)\n", + " └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " ├── Linear(in_features=128, out_features=16) #1\n", + " └── Linear(in_features=16, out_features=128) #2\n", + "\n", + "Note that the original attention is not modified:\n", + "(RES) Attention()\n", + " ├── (PAR)\n", + " │ └── Linear(in_features=128, out_features=128) (x3)\n", + " ├── ScaledDotProductAttention(num_heads=8)\n", + " └── Linear(in_features=128, out_features=128) \n", + "\n" + ] + } + ], + "source": [ + "from refiners.fluxion.adapters import Adapter\n", + "\n", + "\n", + "class LoraAdapter(fl.Sum, Adapter[fl.Linear]):\n", + " def __init__(\n", + " self,\n", + " target: fl.Linear,\n", + " rank: int = 16,\n", + " ) -> None:\n", + " self.in_features = target.in_features\n", + " self.out_features = target.out_features\n", + " self.rank = rank\n", + " # the setup_adapter method is used to remove boilerplate code\n", + " with self.setup_adapter(target):\n", + " super().__init__(\n", + " target,\n", + " Lora(\n", + " in_features=target.in_features,\n", + " out_features=target.out_features,\n", + " rank=rank,\n", + " ),\n", + " )\n", + "\n", + "\n", + "attention = Attention()\n", + "linear = attention.ensure_find(fl.Linear)\n", + "adapter = LoraAdapter(linear)\n", + "print(\n", + " f\"\"\"\n", + "Adapter:\n", + "{repr(adapter)}\n", + "\n", + "Note that the original attention is not modified:\n", + "{repr(attention)} \n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's now `inject` the `Adapter` into the `Attention` layer. One subtlety is that the `Linear` layer is considered a `WeightedModule` and as such can belong to multiple `Chain`s at the same time.
So we need to specify which `Chain` we want to inject the `Adapter` into." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(RES) Attention()\n", + " ├── (PAR)\n", + " │ ├── (SUM) LoraAdapter()\n", + " │ │ ├── Linear(in_features=128, out_features=128)\n", + " │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " │ │ ├── Linear(in_features=128, out_features=16) #1\n", + " │ │ └── Linear(in_features=16, out_features=128) #2\n", + " │ └── Linear(in_features=128, out_features=128) (x2)\n", + " ├── ScaledDotProductAttention(num_heads=8)\n", + " └── Linear(in_features=128, out_features=128)\n", + "(RES) Attention()\n", + " ├── (PAR)\n", + " │ └── Linear(in_features=128, out_features=128) (x3)\n", + " ├── ScaledDotProductAttention(num_heads=8)\n", + " └── Linear(in_features=128, out_features=128)\n" + ] + } + ], + "source": [ + "adapter.inject(parent=attention.Parallel)\n", + "print(repr(attention))\n", + "\n", + "# we can also `eject` the adapter to get back to normal\n", + "adapter.eject()\n", + "print(repr(attention))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally let's write a top-level adapter that will inject the `Adapter` into all the `Linear` layers of the `ViT`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(CHAIN) ViT()\n", + " ├── (PASS) PointEncoder()\n", + " │ ├── UseContext(context=vit, key=points_tensor)\n", + " │ ├── (SUM) LoraAdapter() #1\n", + " │ │ ├── Linear(in_features=2, out_features=128)\n", + " │ │ └── (CHAIN) Lora(in_features=2, out_features=128)\n", + " │ │ ├── Linear(in_features=2, out_features=16) #1\n", + " │ │ └── Linear(in_features=16, out_features=128) #2\n", + " │ ├── SiLU() #1\n", + " │ ├── (SUM) LoraAdapter() #2\n", + " │ │ ├── Linear(in_features=128, out_features=128)\n", + " │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " │ │ ├── Linear(in_features=128, out_features=16) #1\n", + " │ │ └── Linear(in_features=16, out_features=128) #2\n", + " │ ├── SiLU() #2\n", + " │ ├── (SUM) LoraAdapter() #3\n", + " │ │ ├── Linear(in_features=128, out_features=128)\n", + " │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " │ │ ├── Linear(in_features=128, out_features=16) #1\n", + " │ │ └── Linear(in_features=16, out_features=128) #2\n", + " │ ├── Unsqueeze(dim=0)\n", + " │ └── SetContext(context=vit, key=points_embedding)\n", + " ├── (CAT)\n", + " │ ├── (CHAIN) PatchEncoder()\n", + " │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n", + " │ │ └── Reshape(shape=(-1, 128))\n", + " │ └── (CHAIN) ClassToken()\n", + " │ └── Parameter(dims=(1, 128))\n", + " ├── (RES) PositionalToken(num_patches=197)\n", + " │ └── Parameter(dims=(197, 128))\n", + " └── (CHAIN) Transformer()\n", + " └── (CHAIN) TranformerLayer() (x4)\n", + " ├── LayerNorm(normalized_shape=(128,)) #1\n", + " ├── (RES) Attention()\n", + " │ ├── (PAR)\n", + " │ │ └── (SUM) LoraAdapter() (x3)\n", + " │ │ ├── Linear(in_features=128, out_features=128)\n", + " │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " │ │ ├── Linear(in_features=128, out_features=16) #1\n", + " 
│ │ └── Linear(in_features=16, out_features=128) #2\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── (SUM) LoraAdapter()\n", + " │ ├── Linear(in_features=128, out_features=128)\n", + " │ └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " │ ├── Linear(in_features=128, out_features=16) #1\n", + " │ └── Linear(in_features=16, out_features=128) #2\n", + " ├── LayerNorm(normalized_shape=(128,)) #2\n", + " ├── (CHAIN) PointsCrossAttention()\n", + " │ ├── (PAR)\n", + " │ │ ├── Identity()\n", + " │ │ └── UseContext(context=vit, key=points_embedding) (x2)\n", + " │ ├── (DISTR)\n", + " │ │ └── (SUM) LoraAdapter() (x3)\n", + " │ │ ├── Linear(in_features=128, out_features=128)\n", + " │ │ └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " │ │ ├── Linear(in_features=128, out_features=16) #1\n", + " │ │ └── Linear(in_features=16, out_features=128) #2\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── (SUM) LoraAdapter()\n", + " │ ├── Linear(in_features=128, out_features=128)\n", + " │ └── (CHAIN) Lora(in_features=128, out_features=128)\n", + " │ ├── Linear(in_features=128, out_features=16) #1\n", + " │ └── Linear(in_features=16, out_features=128) #2\n", + " ├── LayerNorm(normalized_shape=(128,)) #3\n", + " └── (RES) FeedForward()\n", + " ├── (SUM) LoraAdapter() #1\n", + " │ ├── Linear(in_features=128, out_features=512)\n", + " │ └── (CHAIN) Lora(in_features=128, out_features=512)\n", + " │ ├── Linear(in_features=128, out_features=16) #1\n", + " │ └── Linear(in_features=16, out_features=512) #2\n", + " ├── SiLU()\n", + " └── (SUM) LoraAdapter() #2\n", + " ├── Linear(in_features=512, out_features=128)\n", + " └── (CHAIN) Lora(in_features=512, out_features=128)\n", + " ├── Linear(in_features=512, out_features=16) #1\n", + " └── Linear(in_features=16, out_features=128) #2\n", + "torch.Size([1, 197, 128])\n", + "(CHAIN) ViT()\n", + " ├── (PASS) PointEncoder()\n", + " │ ├── UseContext(context=vit, key=points_tensor)\n", + " │ ├── 
Linear(in_features=2, out_features=128) #1\n", + " │ ├── SiLU() #1\n", + " │ ├── Linear(in_features=128, out_features=128) #2\n", + " │ ├── SiLU() #2\n", + " │ ├── Linear(in_features=128, out_features=128) #3\n", + " │ ├── Unsqueeze(dim=0)\n", + " │ └── SetContext(context=vit, key=points_embedding)\n", + " ├── (CAT)\n", + " │ ├── (CHAIN) PatchEncoder()\n", + " │ │ ├── Conv2d(in_channels=3, out_channels=128, kernel_size=(16, 16), stride=(16, 16))\n", + " │ │ └── Reshape(shape=(-1, 128))\n", + " │ └── (CHAIN) ClassToken()\n", + " │ └── Parameter(dims=(1, 128))\n", + " ├── (RES) PositionalToken(num_patches=197)\n", + " │ └── Parameter(dims=(197, 128))\n", + " └── (CHAIN) Transformer()\n", + " └── (CHAIN) TranformerLayer() (x4)\n", + " ├── LayerNorm(normalized_shape=(128,)) #1\n", + " ├── (RES) Attention()\n", + " │ ├── (PAR)\n", + " │ │ └── Linear(in_features=128, out_features=128) (x3)\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── Linear(in_features=128, out_features=128)\n", + " ├── LayerNorm(normalized_shape=(128,)) #2\n", + " ├── (CHAIN) PointsCrossAttention()\n", + " │ ├── (PAR)\n", + " │ │ ├── Identity()\n", + " │ │ └── UseContext(context=vit, key=points_embedding) (x2)\n", + " │ ├── (DISTR)\n", + " │ │ └── Linear(in_features=128, out_features=128) (x3)\n", + " │ ├── ScaledDotProductAttention(num_heads=8)\n", + " │ └── Linear(in_features=128, out_features=128)\n", + " ├── LayerNorm(normalized_shape=(128,)) #3\n", + " └── (RES) FeedForward()\n", + " ├── Linear(in_features=128, out_features=512) #1\n", + " ├── SiLU()\n", + " └── Linear(in_features=512, out_features=128) #2\n" + ] + } + ], + "source": [ + "from typing import Self\n", + "\n", + "\n", + "class ViTLoraAdapter(fl.Chain, Adapter[ViT]):\n", + " def __init__(\n", + " self,\n", + " target: ViT,\n", + " rank: int = 16,\n", + " ) -> None:\n", + " self.rank = rank\n", + " with self.setup_adapter(target):\n", + " super().__init__(target)\n", + "\n", + " # Let's wrap all the Linear layers 
in the ViT model into LoraAdapters\n", + " self.sub_adapters: list[tuple[LoraAdapter, fl.Chain]] = []\n", + " for linear, parent in self.target.walk(fl.Linear):\n", + " self.sub_adapters.append((LoraAdapter(target=linear, rank=rank), parent))\n", + "\n", + " def inject(self, parent: fl.Chain | None = None) -> Self:\n", + " for adapter, adapter_parent in self.sub_adapters:\n", + " adapter.inject(adapter_parent)\n", + " return super().inject(parent)\n", + "\n", + " def eject(self) -> None:\n", + " for adapter, _ in self.sub_adapters:\n", + " adapter.eject()\n", + " super().eject()\n", + "\n", + "\n", + "vit = ViT()\n", + "x = torch.randn(1, 3, 224, 224)\n", + "points = torch.randn(5, 2)\n", + "vit.set_context(\"vit\", {\"points_tensor\": points})\n", + "adapter = ViTLoraAdapter(vit)\n", + "adapter.inject() # since `ViT` has no parent, no need to pass it to `inject`\n", + "print(repr(vit))\n", + "print(vit(x).shape)\n", + "\n", + "# we can also `eject` the adapter to get back to normal\n", + "adapter.eject()\n", + "print(repr(vit))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 17d2e93c5..0ca1f8b4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,10 +54,11 @@ build-backend = "hatchling.build" [tool.rye] managed = true dev-dependencies = [ - "pyright == 1.1.333", + "pyright == 1.1.342", "ruff>=0.0.292", "docformatter>=1.7.5", "pytest>=7.4.2", + "mkdocs-material>=9.5.3", ] @@ -66,6 +67,7 @@ allow-direct-references = true [tool.rye.scripts] lint = { chain = ["ruff format .", "ruff --fix ."] } +serve-docs = "mkdocs serve" [tool.black] 
line-length = 120 diff --git a/requirements.docs.txt b/requirements.docs.txt new file mode 100644 index 000000000..00707e027 --- /dev/null +++ b/requirements.docs.txt @@ -0,0 +1 @@ +mkdocs-material==9.5.3 diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index cacdfdd54..5185193db 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -7,7 +7,7 @@ from torch import nn from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import save_to_safetensors +from refiners.fluxion.utils import no_grad, save_to_safetensors from refiners.foundationals.latent_diffusion import ( DPMSolver, SD1ControlnetAdapter, @@ -20,7 +20,7 @@ class Args(argparse.Namespace): output_path: str | None -@torch.no_grad() +@no_grad() def convert(args: Args) -> dict[str, torch.Tensor]: # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate` controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore diff --git a/scripts/conversion/convert_diffusers_ip_adapter.py b/scripts/conversion/convert_diffusers_ip_adapter.py index 8fd33be03..8282db1af 100644 --- a/scripts/conversion/convert_diffusers_ip_adapter.py +++ b/scripts/conversion/convert_diffusers_ip_adapter.py @@ -133,24 +133,14 @@ def main() -> None: ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"] assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2 - for i, cross_attn in enumerate(ip_adapter.sub_adapters): + for i, _ in enumerate(ip_adapter.sub_adapters): cross_attn_index = cross_attn_mapping[i] k_ip = f"{cross_attn_index}.to_k_ip.weight" v_ip = f"{cross_attn_index}.to_v_ip.weight" - # Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights - - names = [k for k, _ in cross_attn.named_parameters()] - assert len(names) == 2 - - 
cross_attn_state_dict: dict[str, Any] = { - names[0]: ip_adapter_weights[k_ip], - names[1]: ip_adapter_weights[v_ip], - } - cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False) - - for k, v in cross_attn_state_dict.items(): - state_dict[f"ip_adapter.{i:03d}.{k}"] = v + # the name of the key is not checked at runtime, so we keep the original name + state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip] + state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip] if args.half: state_dict = {key: value.half() for key, value in state_dict.items()} diff --git a/scripts/conversion/convert_diffusers_lora.py b/scripts/conversion/convert_diffusers_lora.py index 9abffd872..1c37d8dbd 100644 --- a/scripts/conversion/convert_diffusers_lora.py +++ b/scripts/conversion/convert_diffusers_lora.py @@ -11,7 +11,7 @@ import refiners.fluxion.layers as fl from refiners.fluxion.adapters.lora import Lora, LoraAdapter from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import save_to_safetensors +from refiners.fluxion.utils import no_grad, save_to_safetensors from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets @@ -37,7 +37,7 @@ class Args(argparse.Namespace): verbose: bool -@torch.no_grad() +@no_grad() def process(args: Args) -> None: diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate` diff --git a/scripts/conversion/convert_segment_anything.py b/scripts/conversion/convert_segment_anything.py index 9057cbb33..14ba2ef2b 100644 --- a/scripts/conversion/convert_segment_anything.py +++ b/scripts/conversion/convert_segment_anything.py @@ -37,13 +37,36 @@ class Args(argparse.Namespace): def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: + 
manual_seed(seed=0) + refiners_mask_encoder = MaskEncoder() + + converter = ModelConverter( + source_model=prompt_encoder.mask_downscaling, + target_model=refiners_mask_encoder, + custom_layer_mapping=custom_layers, # type: ignore + ) + + x = torch.randn(1, 256, 256) + mapping = converter.map_state_dicts(source_args=(x,)) + assert mapping + + source_state_dict = prompt_encoder.mask_downscaling.state_dict() + target_state_dict = refiners_mask_encoder.state_dict() + + # Mapping handled manually (see below) because nn.Parameter is a special case + del target_state_dict["no_mask_embedding"] + + converted_source = converter._convert_state_dict( # pyright: ignore[reportPrivateUsage] + source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping + ) + state_dict: dict[str, Tensor] = { "no_mask_embedding": nn.Parameter(data=prompt_encoder.no_mask_embed.weight.clone()), # type: ignore } - refiners_mask_encoder = MaskEncoder() - # TODO: handle other weights - refiners_mask_encoder.load_state_dict(state_dict=state_dict, strict=False) + state_dict.update(converted_source) + + refiners_mask_encoder.load_state_dict(state_dict=state_dict) return state_dict diff --git a/src/refiners/fluxion/layers/sampling.py b/src/refiners/fluxion/layers/sampling.py index d6368e3b1..69d1412e3 100644 --- a/src/refiners/fluxion/layers/sampling.py +++ b/src/refiners/fluxion/layers/sampling.py @@ -1,3 +1,5 @@ +from typing import Callable + from torch import Size, Tensor, device as Device, dtype as DType from torch.nn.functional import pad @@ -40,7 +42,8 @@ def __init__( ), ) if padding == 0: - self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1)))) + zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1)) + self.insert(0, Lambda(zero_pad)) if register_shape: self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape)) diff --git a/src/refiners/fluxion/model_converter.py b/src/refiners/fluxion/model_converter.py index 
ef14d0238..8e47ebb49 100644 --- a/src/refiners/fluxion/model_converter.py +++ b/src/refiners/fluxion/model_converter.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torch.utils.hooks import RemovableHandle -from refiners.fluxion.utils import norm, save_to_safetensors +from refiners.fluxion.utils import no_grad, norm, save_to_safetensors TORCH_BASIC_LAYERS: list[type[nn.Module]] = [ nn.Conv1d, @@ -512,7 +512,7 @@ def _verify_missing_basic_layers(self) -> bool: return True - @torch.no_grad() + @no_grad() def _trace_module_execution_order( self, module: nn.Module, @@ -603,7 +603,7 @@ def _convert_state_dict( return converted_state_dict - @torch.no_grad() + @no_grad() def _collect_layers_outputs( self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str] ) -> list[tuple[str, Tensor]]: diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 7c4f5e06d..deb0d4693 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Iterable, Literal, TypeVar +from typing import Any, Iterable, Literal, TypeVar import torch from jaxtyping import Float @@ -7,7 +7,14 @@ from PIL import Image from safetensors import safe_open as _safe_open # type: ignore from safetensors.torch import save_file as _save_file # type: ignore -from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm # type: ignore +from torch import ( + Tensor, + device as Device, + dtype as DType, + manual_seed as _manual_seed, # type: ignore + no_grad as _no_grad, # type: ignore + norm as _norm, # type: ignore +) from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore T = TypeVar("T") @@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None: _manual_seed(seed) +class no_grad(_no_grad): + def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore + return object.__new__(cls) + + def pad(x: Tensor, pad: 
Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor: return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore diff --git a/src/refiners/foundationals/clip/image_encoder.py b/src/refiners/foundationals/clip/image_encoder.py index ed6db3d7b..270d1beb9 100644 --- a/src/refiners/foundationals/clip/image_encoder.py +++ b/src/refiners/foundationals/clip/image_encoder.py @@ -1,4 +1,6 @@ -from torch import device as Device, dtype as DType +from typing import Callable + +from torch import Tensor, device as Device, dtype as DType import refiners.fluxion.layers as fl from refiners.foundationals.clip.common import FeedForward, PositionalEncoder @@ -126,6 +128,7 @@ def __init__( self.num_layers = num_layers self.num_attention_heads = num_attention_heads self.feedforward_dim = feedforward_dim + cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :] super().__init__( ViTEmbeddings( image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype @@ -142,7 +145,7 @@ def __init__( ) for _ in range(num_layers) ), - fl.Lambda(func=lambda x: x[:, 0, :]), + fl.Lambda(func=cls_token_pooling), fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype), fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype), ) diff --git a/src/refiners/foundationals/latent_diffusion/freeu.py b/src/refiners/foundationals/latent_diffusion/freeu.py index 61726bd81..3e2580f2e 100644 --- a/src/refiners/foundationals/latent_diffusion/freeu.py +++ b/src/refiners/foundationals/latent_diffusion/freeu.py @@ -1,5 +1,5 @@ import math -from typing import Any, Generic, TypeVar +from typing import Any, Callable, Generic, TypeVar import torch from torch import Tensor @@ -54,9 +54,10 @@ def forward(self, x: Tensor) -> Tensor: class FreeUSkipFeatures(fl.Chain): def __init__(self, n: int, skip_scale: float) -> None: + apply_filter: Callable[[Tensor], Tensor] = 
lambda x: fourier_filter(x, scale=skip_scale) super().__init__( fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]), - fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)), + fl.Lambda(apply_filter), ) diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 9c4bac019..b0cdfdd21 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -1,17 +1,15 @@ import math -from enum import IntEnum -from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from jaxtyping import Float from PIL import Image -from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like +from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter -from refiners.fluxion.adapters.lora import Lora from refiners.fluxion.context import Contexts from refiners.fluxion.layers.attentions import ScaledDotProductAttention +from refiners.fluxion.layers.chain import Distribute from refiners.fluxion.utils import image_to_tensor, normalize from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH @@ -236,120 +234,89 @@ def init_context(self) -> Contexts: return {"perceiver_resampler": {"x": None}} -class _CrossAttnIndex(IntEnum): - TXT_CROSS_ATTN = 0 # text cross-attention - IMG_CROSS_ATTN = 1 # image cross-attention - - class InjectionPoint(fl.Chain): pass +class ImageCrossAttention(fl.Chain): + def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None: + self.scale = scale + super().__init__( + fl.Distribute( + fl.UseContext(context="ip_adapter", key="query_projection"), + fl.Chain( + fl.UseContext(context="ip_adapter", 
key="clip_image_embedding"), + fl.Linear( + in_features=text_cross_attention.key_embedding_dim, + out_features=text_cross_attention.inner_dim, + bias=text_cross_attention.use_bias, + device=text_cross_attention.device, + dtype=text_cross_attention.dtype, + ), + ), + fl.Chain( + fl.UseContext(context="ip_adapter", key="clip_image_embedding"), + fl.Linear( + in_features=text_cross_attention.key_embedding_dim, + out_features=text_cross_attention.inner_dim, + bias=text_cross_attention.use_bias, + device=text_cross_attention.device, + dtype=text_cross_attention.dtype, + ), + ), + ), + ScaledDotProductAttention( + num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal + ), + fl.Multiply(self.scale), + ) + + +class SetQueryProjection(fl.Passthrough): + def __init__(self) -> None: + super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection")) + + class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): def __init__( self, target: fl.Attention, - text_sequence_length: int = 77, - image_sequence_length: int = 4, scale: float = 1.0, ) -> None: - self.text_sequence_length = text_sequence_length - self.image_sequence_length = image_sequence_length - self.scale = scale - with self.setup_adapter(target): super().__init__( - fl.Distribute( - # Note: the same query is used for image cross-attention as for text cross-attention - InjectionPoint(), # Wq - fl.Parallel( - fl.Chain( - fl.Slicing(dim=1, end=text_sequence_length), - InjectionPoint(), # Wk - ), - fl.Chain( - fl.Slicing(dim=1, start=text_sequence_length), - fl.Linear( - in_features=self.target.key_embedding_dim, - out_features=self.target.inner_dim, - bias=self.target.use_bias, - device=target.device, - dtype=target.dtype, - ), # Wk' - ), - ), - fl.Parallel( - fl.Chain( - fl.Slicing(dim=1, end=text_sequence_length), - InjectionPoint(), # Wv - ), - fl.Chain( - fl.Slicing(dim=1, start=text_sequence_length), - fl.Linear( - 
in_features=self.target.key_embedding_dim, - out_features=self.target.inner_dim, - bias=self.target.use_bias, - device=target.device, - dtype=target.dtype, - ), # Wv' - ), - ), - ), fl.Sum( - fl.Chain( - fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)), - ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal), - ), - fl.Chain( - fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)), - ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal), - fl.Lambda(func=self.scale_outputs), - ), + target[:-1], # original text cross attention + ImageCrossAttention(text_cross_attention=target, scale=scale), ), - InjectionPoint(), # proj + target[-1], # projection ) + self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection()) - def select_qkv( - self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex - ) -> tuple[Tensor, Tensor, Tensor]: - return (query, keys[index.value], values[index.value]) - - def scale_outputs(self, x: Tensor) -> Tensor: - return x * self.scale - - def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]: - def f(m: fl.Module, _: fl.Chain) -> bool: - if isinstance(m, Lora): # do not adapt LoRAs - raise StopIteration - return isinstance(m, k) - - return f - - def _target_linears(self) -> list[fl.Linear]: - return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)] - - def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter": - linears = self._target_linears() - assert len(linears) == 4 # Wq, Wk, Wv and Proj - - injection_points = list(self.layers(InjectionPoint)) - assert len(injection_points) == 4 + @property + def image_cross_attention(self) -> ImageCrossAttention: + return self.ensure_find(ImageCrossAttention) - for linear, ip in zip(linears, injection_points): - 
ip.append(linear) - assert len(ip) == 1 + @property + def image_key_projection(self) -> fl.Linear: + return self.image_cross_attention.Distribute[1].Linear - return super().inject(parent) + @property + def image_value_projection(self) -> fl.Linear: + return self.image_cross_attention.Distribute[2].Linear - def eject(self) -> None: - injection_points = list(self.layers(InjectionPoint)) - assert len(injection_points) == 4 + @property + def scale(self) -> float: + return self.image_cross_attention.scale - for ip in injection_points: - ip.pop() - assert len(ip) == 0 + @scale.setter + def scale(self, value: float) -> None: + self.image_cross_attention.scale = value - super().eject() + def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None: + self.image_key_projection.weight = nn.Parameter(key_tensor) + self.image_value_projection.weight = nn.Parameter(value_tensor) + self.image_cross_attention.to(self.device, self.dtype) class IPAdapter(Generic[T], fl.Chain, Adapter[T]): @@ -377,7 +344,7 @@ def __init__( self._image_proj = [image_proj] self.sub_adapters = [ - CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens) + CrossAttentionAdapter(target=cross_attn, scale=scale) for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention)) ] @@ -388,14 +355,15 @@ def __init__( self.image_proj.load_state_dict(image_proj_state_dict) for i, cross_attn in enumerate(self.sub_adapters): - cross_attn_state_dict: dict[str, Tensor] = {} + cross_attention_weights: list[Tensor] = [] for k, v in weights.items(): prefix = f"ip_adapter.{i:03d}." 
if not k.startswith(prefix): continue - cross_attn_state_dict[k.removeprefix(prefix)] = v + cross_attention_weights.append(v) - cross_attn.load_state_dict(state_dict=cross_attn_state_dict) + assert len(cross_attention_weights) == 2 + cross_attn.load_weights(*cross_attention_weights) @property def clip_image_encoder(self) -> CLIPImageEncoderH: @@ -420,10 +388,22 @@ def eject(self) -> None: adapter.eject() super().eject() + @property + def scale(self) -> float: + return self.sub_adapters[0].scale + + @scale.setter + def scale(self, value: float) -> None: + for cross_attn in self.sub_adapters: + cross_attn.scale = value + def set_scale(self, scale: float) -> None: for cross_attn in self.sub_adapters: cross_attn.scale = scale + def set_clip_image_embedding(self, image_embedding: Tensor) -> None: + self.set_context("ip_adapter", {"clip_image_embedding": image_embedding}) + # These should be concatenated to the CLIP text embedding before setting the UNet context def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor: image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index bc041c8b1..d8820ab6f 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -122,7 +122,7 @@ def from_safetensors( assert metadata is not None, "Invalid safetensors checkpoint: missing metadata" tensors = load_from_safetensors(checkpoint_path, device=target.device) - sub_targets = {} + sub_targets: dict[str, list[LoraTarget]] = {} for model_name in MODELS: if not (v := metadata.get(f"{model_name}_targets", "")): continue diff --git a/src/refiners/foundationals/latent_diffusion/reference_only_control.py b/src/refiners/foundationals/latent_diffusion/reference_only_control.py index bf17bc724..1f0e049ca 100644 --- 
a/src/refiners/foundationals/latent_diffusion/reference_only_control.py +++ b/src/refiners/foundationals/latent_diffusion/reference_only_control.py @@ -1,3 +1,5 @@ +from typing import Callable + from torch import Tensor from refiners.fluxion.adapters.adapter import Adapter @@ -45,8 +47,9 @@ def __init__( ) with self.setup_adapter(target): + slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1] super().__init__( - Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)), + Parallel(sa_guided, Chain(Lambda(slice_tensor), target)), Lambda(self.compute_averaged_unconditioned_x), ) diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index f8abfb71e..905c4b67f 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -3,11 +3,12 @@ import numpy as np import torch +from jaxtyping import Float from PIL import Image from torch import Tensor, device as Device, dtype as DType import refiners.fluxion.layers as fl -from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad +from refiners.fluxion.utils import interpolate, no_grad, normalize, pad from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder @@ -39,7 +40,7 @@ def __init__( self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype) self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype) - @torch.no_grad() + @no_grad() def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: original_size = (image.height, image.width) target_size = self.compute_target_size(original_size) @@ -48,14 +49,14 @@ def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: original_image_size=original_size, ) - @torch.no_grad() 
+ @no_grad() def predict( self, input: Image.Image | ImageEmbedding, foreground_points: Sequence[tuple[float, float]] | None = None, background_points: Sequence[tuple[float, float]] | None = None, box_points: Sequence[Sequence[tuple[float, float]]] | None = None, - masks: Sequence[Image.Image] | None = None, + low_res_mask: Float[Tensor, "1 1 256 256"] | None = None, binarize: bool = True, ) -> tuple[Tensor, Tensor, Tensor]: if isinstance(input, ImageEmbedding): @@ -74,15 +75,13 @@ def predict( ) self.point_encoder.set_type_mask(type_mask=type_mask) - if masks is not None: - mask_tensor = torch.stack( - tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks] - ) - mask_embedding = self.mask_encoder(mask_tensor) + if low_res_mask is not None: + mask_embedding = self.mask_encoder(low_res_mask) else: mask_embedding = self.mask_encoder.get_no_mask_dense_embedding( image_embedding_size=self.image_encoder.image_embedding_size ) + point_embedding = self.point_encoder( self.normalize(coordinates, target_size=target_size, original_size=original_size) ) diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index dff85fca7..f4f8ccf0a 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -250,7 +250,7 @@ def on_compute_loss_end(self, trainer: LatentDiffusionTrainer[Any]) -> None: self.timestep_bins[bin_index].append(loss_value) def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None: - log_data = {} + log_data: dict[str, WandbLoggable] = {} for bin_index, losses in self.timestep_bins.items(): if losses: avg_loss = sum(losses) / len(losses) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 87276d8ca..730bb8ae7 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -6,7 +6,7 @@ import numpy as np from loguru import 
logger -from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack +from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack from torch.autograd import backward from torch.nn import Parameter from torch.optim import Optimizer @@ -26,7 +26,7 @@ from torch.utils.data import DataLoader, Dataset from refiners.fluxion import layers as fl -from refiners.fluxion.utils import manual_seed +from refiners.fluxion.utils import manual_seed, no_grad from refiners.training_utils.callback import ( Callback, ClockCallback, diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 201c1b0e4..4850e79c3 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -6,7 +6,7 @@ import torch from PIL import Image -from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed +from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.latent_diffusion import ( SD1ControlnetAdapter, @@ -501,7 +501,7 @@ def sdxl_ddim( return sdxl -@torch.no_grad() +@no_grad() def test_diffusion_std_random_init( sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device ): @@ -529,7 +529,7 @@ def test_diffusion_std_random_init( ensure_similar_images(predicted_image, expected_image_std_random_init) -@torch.no_grad() +@no_grad() def test_diffusion_karras_random_init( sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device ): @@ -554,7 +554,7 @@ def test_diffusion_karras_random_init( ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_std_random_init_float16( sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device ): @@ -583,7 
+583,7 @@ def test_diffusion_std_random_init_float16( ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_std_random_init_sag( sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device ): @@ -612,7 +612,7 @@ def test_diffusion_std_random_init_sag( ensure_similar_images(predicted_image, expected_image_std_random_init_sag) -@torch.no_grad() +@no_grad() def test_diffusion_std_init_image( sd15_std: StableDiffusion_1, cutecat_init: Image.Image, @@ -643,7 +643,7 @@ def test_diffusion_std_init_image( ensure_similar_images(predicted_image, expected_image_std_init_image) -@torch.no_grad() +@no_grad() def test_rectangular_init_latents( sd15_std: StableDiffusion_1, cutecat_init: Image.Image, @@ -658,7 +658,7 @@ def test_rectangular_init_latents( assert sd15.lda.decode_latents(x).size == (width, height) -@torch.no_grad() +@no_grad() def test_diffusion_inpainting( sd15_inpainting: StableDiffusion_1_Inpainting, kitchen_dog: Image.Image, @@ -692,7 +692,7 @@ def test_diffusion_inpainting( ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95) -@torch.no_grad() +@no_grad() def test_diffusion_inpainting_float16( sd15_inpainting_float16: StableDiffusion_1_Inpainting, kitchen_dog: Image.Image, @@ -727,7 +727,7 @@ def test_diffusion_inpainting_float16( ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet( sd15_std: StableDiffusion_1, controlnet_data: tuple[str, Image.Image, Image.Image, Path], @@ -770,7 +770,7 @@ def test_diffusion_controlnet( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet_structural_copy( sd15_std: StableDiffusion_1, controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path], @@ 
-814,7 +814,7 @@ def test_diffusion_controlnet_structural_copy( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet_float16( sd15_std_float16: StableDiffusion_1, controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path], @@ -857,7 +857,7 @@ def test_diffusion_controlnet_float16( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet_stack( sd15_std: StableDiffusion_1, controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path], @@ -912,7 +912,7 @@ def test_diffusion_controlnet_stack( ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_lora( sd15_std: StableDiffusion_1, lora_data_pokemon: tuple[Image.Image, Path], @@ -949,7 +949,7 @@ def test_diffusion_lora( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_lora_float16( sd15_std_float16: StableDiffusion_1, lora_data_pokemon: tuple[Image.Image, Path], @@ -986,7 +986,7 @@ def test_diffusion_lora_float16( ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_lora_twice( sd15_std: StableDiffusion_1, lora_data_pokemon: tuple[Image.Image, Path], @@ -1025,7 +1025,7 @@ def test_diffusion_lora_twice( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_refonly( sd15_ddim: StableDiffusion_1, condition_image_refonly: Image.Image, @@ -1061,7 +1061,7 @@ def test_diffusion_refonly( ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99) -@torch.no_grad() +@no_grad() def test_diffusion_inpainting_refonly( sd15_inpainting: StableDiffusion_1_Inpainting, 
scene_image_inpainting_refonly: Image.Image, @@ -1106,7 +1106,7 @@ def test_diffusion_inpainting_refonly( ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99) -@torch.no_grad() +@no_grad() def test_diffusion_textual_inversion_random_init( sd15_std: StableDiffusion_1, expected_image_textual_inversion_random_init: Image.Image, @@ -1141,7 +1141,7 @@ def test_diffusion_textual_inversion_random_init( ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_ip_adapter( sd15_ddim_lda_ft_mse: StableDiffusion_1, ip_adapter_weights: Path, @@ -1168,16 +1168,7 @@ def test_diffusion_ip_adapter( clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) - - negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) - negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) - - clip_text_embedding = torch.cat( - ( - torch.cat([negative_text_embedding, negative_image_embedding], dim=1), - torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), - ) - ) + ip_adapter.set_clip_image_embedding(clip_image_embedding) sd15.set_num_inference_steps(n_steps) @@ -1196,7 +1187,7 @@ def test_diffusion_ip_adapter( ensure_similar_images(predicted_image, expected_image_ip_adapter_woman) -@torch.no_grad() +@no_grad() def test_diffusion_sdxl_ip_adapter( sdxl_ddim: StableDiffusion_XL, sdxl_ip_adapter_weights: Path, @@ -1215,28 +1206,20 @@ def test_diffusion_sdxl_ip_adapter( ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() - with torch.no_grad(): + with no_grad(): clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) 
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) + ip_adapter.set_clip_image_embedding(clip_image_embedding) - negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) - negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) - - clip_text_embedding = torch.cat( - ( - torch.cat([negative_text_embedding, negative_image_embedding], dim=1), - torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), - ) - ) time_ids = sdxl.default_time_ids sdxl.set_num_inference_steps(n_steps) manual_seed(2) x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16) - with torch.no_grad(): + with no_grad(): for step in sdxl.steps: x = sdxl( x, @@ -1254,7 +1237,7 @@ def test_diffusion_sdxl_ip_adapter( ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman) -@torch.no_grad() +@no_grad() def test_diffusion_ip_adapter_controlnet( sd15_ddim: StableDiffusion_1, ip_adapter_weights: Path, @@ -1285,16 +1268,7 @@ def test_diffusion_ip_adapter_controlnet( clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image)) - - negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) - negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) - - clip_text_embedding = torch.cat( - ( - torch.cat([negative_text_embedding, negative_image_embedding], dim=1), - torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), - ) - ) + ip_adapter.set_clip_image_embedding(clip_image_embedding) depth_cn_condition = image_to_tensor( depth_condition_image.convert("RGB"), @@ -1320,7 +1294,7 @@ def test_diffusion_ip_adapter_controlnet( ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet) -@torch.no_grad() +@no_grad() def 
test_diffusion_ip_adapter_plus( sd15_ddim_lda_ft_mse: StableDiffusion_1, ip_adapter_plus_weights: Path, @@ -1343,16 +1317,7 @@ def test_diffusion_ip_adapter_plus( clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image)) - - negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) - negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) - - clip_text_embedding = torch.cat( - ( - torch.cat([negative_text_embedding, negative_image_embedding], dim=1), - torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), - ) - ) + ip_adapter.set_clip_image_embedding(clip_image_embedding) sd15.set_num_inference_steps(n_steps) @@ -1371,7 +1336,7 @@ def test_diffusion_ip_adapter_plus( ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_sdxl_ip_adapter_plus( sdxl_ddim: StableDiffusion_XL, sdxl_ip_adapter_plus_weights: Path, @@ -1396,16 +1361,8 @@ def test_diffusion_sdxl_ip_adapter_plus( text=prompt, negative_text=negative_prompt ) clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) + ip_adapter.set_clip_image_embedding(clip_image_embedding) - negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) - negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) - - clip_text_embedding = torch.cat( - ( - torch.cat([negative_text_embedding, negative_image_embedding], dim=1), - torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), - ) - ) time_ids = sdxl.default_time_ids sdxl.set_num_inference_steps(n_steps) @@ -1427,7 +1384,7 @@ def test_diffusion_sdxl_ip_adapter_plus( ensure_similar_images(predicted_image, 
expected_image_sdxl_ip_adapter_plus_woman) -@torch.no_grad() +@no_grad() def test_sdxl_random_init( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device ) -> None: @@ -1462,7 +1419,7 @@ def test_sdxl_random_init( ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_sdxl_random_init_sag( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device ) -> None: @@ -1498,7 +1455,7 @@ def test_sdxl_random_init_sag( ensure_similar_images(img_1=predicted_image, img_2=expected_image) -@torch.no_grad() +@no_grad() def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None: manual_seed(seed=2) sd = sd15_ddim @@ -1529,7 +1486,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_t2i_adapter_depth( sd15_std: StableDiffusion_1, t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path], @@ -1570,7 +1527,7 @@ def test_t2i_adapter_depth( ensure_similar_images(predicted_image, expected_image) -@torch.no_grad() +@no_grad() def test_t2i_adapter_xl_canny( sdxl_ddim: StableDiffusion_XL, t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path], @@ -1619,7 +1576,7 @@ def test_t2i_adapter_xl_canny( ensure_similar_images(predicted_image, expected_image) -@torch.no_grad() +@no_grad() def test_restart( sd15_ddim: StableDiffusion_1, expected_restart: Image.Image, @@ -1659,7 +1616,7 @@ def test_restart( ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_freeu( sd15_std: StableDiffusion_1, expected_freeu: Image.Image, diff --git a/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png 
b/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png index e838df167..41a90cc4e 100644 Binary files a/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png and b/tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png differ diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py index 69131b21c..4492638d0 100644 --- a/tests/e2e/test_preprocessors.py +++ b/tests/e2e/test_preprocessors.py @@ -5,7 +5,7 @@ import torch from PIL import Image -from refiners.fluxion.utils import image_to_tensor, tensor_to_image +from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings from tests.utils import ensure_similar_images @@ -41,7 +41,7 @@ def informative_drawings_model(informative_drawings_weights: Path, test_device: return model -@torch.no_grad() +@no_grad() def test_preprocessor_informative_drawing( informative_drawings_model: InformativeDrawings, cutecat_init: Image.Image, diff --git a/tests/fluxion/layers/test_converter.py b/tests/fluxion/layers/test_converter.py index 8a3318838..356cdb1fd 100644 --- a/tests/fluxion/layers/test_converter.py +++ b/tests/fluxion/layers/test_converter.py @@ -1,3 +1,4 @@ +from typing import Any, Callable from warnings import warn import pytest @@ -60,8 +61,9 @@ def test_converter_multiple_tensors(test_device: torch.device) -> None: def test_converter_no_parent_device_or_dtype() -> None: + identity: Callable[[Any], Any] = lambda x: x chain = fl.Chain( - fl.Lambda(func=(lambda x: x)), + fl.Lambda(func=identity), fl.Converter(set_device=True, set_dtype=False), ) diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index 8811c4716..883755089 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -7,7 +7,7 @@ from torch import device as Device, dtype as DType from torchvision.transforms.functional import gaussian_blur as 
torch_gaussian_blur # type: ignore -from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image +from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image @dataclass @@ -62,3 +62,18 @@ def test_tensor_to_image() -> None: assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" + + +def test_no_grad() -> None: + x = torch.randn(1, 1, requires_grad=True) + + with torch.no_grad(): + y = x + 1 + assert not y.requires_grad + + with no_grad(): + z = x + 1 + assert not z.requires_grad + + w = x + 1 + assert w.requires_grad diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py index ed8656145..9c4ed749f 100644 --- a/tests/foundationals/clip/test_concepts.py +++ b/tests/foundationals/clip/test_concepts.py @@ -7,7 +7,7 @@ from diffusers import StableDiffusionPipeline # type: ignore import refiners.fluxion.layers as fl -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer @@ -124,7 +124,7 @@ def test_encoder( our_tokens = tokenizer(prompt) assert torch.equal(our_tokens, ref_tokens) - with torch.no_grad(): + with no_grad(): ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0] our_embeddings = our_encoder_with_new_concepts(prompt) diff --git a/tests/foundationals/clip/test_image_encoder.py b/tests/foundationals/clip/test_image_encoder.py index 3aac6684e..ff990bda5 100644 --- a/tests/foundationals/clip/test_image_encoder.py +++ b/tests/foundationals/clip/test_image_encoder.py @@ -5,7 +5,7 @@ import torch from 
transformers import CLIPVisionModelWithProjection # type: ignore -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH @@ -44,7 +44,7 @@ def test_encoder( ): x = torch.randn(1, 3, 224, 224).to(test_device) - with torch.no_grad(): + with no_grad(): ref_embeddings = ref_encoder(x).image_embeds our_embeddings = our_encoder(x) diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py index 0e108b7f8..f1b6f07c6 100644 --- a/tests/foundationals/clip/test_text_encoder.py +++ b/tests/foundationals/clip/test_text_encoder.py @@ -5,7 +5,7 @@ import torch import transformers # type: ignore -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer @@ -89,7 +89,7 @@ def test_encoder( our_tokens = tokenizer(prompt) assert torch.equal(our_tokens, ref_tokens) - with torch.no_grad(): + with no_grad(): ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0] our_embeddings = our_encoder(prompt) diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py index 9b5f40e77..7bcf81863 100644 --- a/tests/foundationals/dinov2/test_dinov2.py +++ b/tests/foundationals/dinov2/test_dinov2.py @@ -7,7 +7,7 @@ from transformers import AutoModel # type: ignore from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore -from refiners.fluxion.utils import load_from_safetensors, manual_seed +from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad from refiners.foundationals.dinov2 import ( DINOv2_base, DINOv2_base_reg, @@ -124,7 +124,7 @@ def test_encoder( x = torch.randn(1, 3, 518, 518).to(test_device) - with torch.no_grad(): + 
with no_grad(): ref_features = ref_backbone(x).last_hidden_state our_features = our_backbone(x) diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py index 2ddca248e..462c40776 100644 --- a/tests/foundationals/latent_diffusion/test_auto_encoder.py +++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py @@ -6,7 +6,7 @@ from PIL import Image from tests.utils import ensure_similar_images -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder @@ -38,7 +38,7 @@ def sample_image(ref_path: Path) -> Image.Image: return img -@torch.no_grad() +@no_grad() def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image): encoded = encoder.encode_image(sample_image) decoded = encoder.decode_latents(encoded) diff --git a/tests/foundationals/latent_diffusion/test_controlnet.py b/tests/foundationals/latent_diffusion/test_controlnet.py index 36f3b04b7..4bfc5e600 100644 --- a/tests/foundationals/latent_diffusion/test_controlnet.py +++ b/tests/foundationals/latent_diffusion/test_controlnet.py @@ -1,10 +1,10 @@ from typing import Iterator import pytest -import torch import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import lookup_top_adapter +from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet @@ -18,7 +18,7 @@ def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]: yield unet -@torch.no_grad() +@no_grad() def test_single_controlnet(unet: SD1UNet) -> None: original_parent = unet.parent cn = SD1ControlnetAdapter(unet, name="cn") @@ -43,7 +43,7 @@ def test_single_controlnet(unet: SD1UNet) -> None: assert 
len(list(unet.walk(Controlnet))) == 0 -@torch.no_grad() +@no_grad() def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None: original_parent = unet.parent cn1 = SD1ControlnetAdapter(unet, name="cn1").inject() @@ -71,7 +71,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None: assert len(list(unet.walk(Controlnet))) == 0 -@torch.no_grad() +@no_grad() def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None: original_parent = unet.parent cn1 = SD1ControlnetAdapter(unet, name="cn1").inject() @@ -86,7 +86,7 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None: assert len(list(unet.walk(Controlnet))) == 0 -@torch.no_grad() +@no_grad() def test_two_controlnets_same_name(unet: SD1UNet) -> None: SD1ControlnetAdapter(unet, name="cnx").inject() cn2 = SD1ControlnetAdapter(unet, name="cnx") diff --git a/tests/foundationals/latent_diffusion/test_freeu.py b/tests/foundationals/latent_diffusion/test_freeu.py index 6b7001b24..3e4553ec2 100644 --- a/tests/foundationals/latent_diffusion/test_freeu.py +++ b/tests/foundationals/latent_diffusion/test_freeu.py @@ -4,6 +4,7 @@ import torch from refiners.fluxion import manual_seed +from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter @@ -52,14 +53,14 @@ def test_freeu_identity_scales() -> None: unet = SD1UNet(in_channels=4) unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_1 = unet(x.clone()) freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0]) freeu.inject() - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_2 = unet(x.clone()) diff --git a/tests/foundationals/latent_diffusion/test_image_prompt.py b/tests/foundationals/latent_diffusion/test_image_prompt.py 
deleted file mode 100644 index 612de2e58..000000000 --- a/tests/foundationals/latent_diffusion/test_image_prompt.py +++ /dev/null @@ -1,29 +0,0 @@ -import refiners.fluxion.layers as fl -from refiners.foundationals.latent_diffusion.image_prompt import CrossAttentionAdapter, InjectionPoint - - -def test_cross_attention_adapter() -> None: - base = fl.Chain(fl.Attention(embedding_dim=4)) - adapter = CrossAttentionAdapter(base.Attention).inject() - - assert list(base) == [adapter] - assert len(list(adapter.layers(fl.Linear))) == 6 - assert len(list(base.layers(fl.Linear))) == 6 - - injection_points = list(adapter.layers(InjectionPoint)) - assert len(injection_points) == 4 - for ip in injection_points: - assert len(ip) == 1 - assert isinstance(ip[0], fl.Linear) - - adapter.eject() - - assert len(base) == 1 - assert isinstance(base[0], fl.Attention) - assert len(list(adapter.layers(fl.Linear))) == 2 - assert len(list(base.layers(fl.Linear))) == 4 - - injection_points = list(adapter.layers(InjectionPoint)) - assert len(injection_points) == 4 - for ip in injection_points: - assert len(ip) == 0 diff --git a/tests/foundationals/latent_diffusion/test_reference_only_control.py b/tests/foundationals/latent_diffusion/test_reference_only_control.py index 68833b3ae..d0ed8a3f5 100644 --- a/tests/foundationals/latent_diffusion/test_reference_only_control.py +++ b/tests/foundationals/latent_diffusion/test_reference_only_control.py @@ -1,6 +1,6 @@ import pytest -import torch +from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock from refiners.foundationals.latent_diffusion.reference_only_control import ( @@ -11,7 +11,7 @@ ) -@torch.no_grad() +@no_grad() def test_refonly_inject_eject() -> None: unet = SD1UNet(in_channels=9) adapter = ReferenceOnlyControlAdapter(unet) diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py 
b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index cb51253bb..9435b89da 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -7,7 +7,7 @@ from torch import Tensor import refiners.fluxion.layers as fl -from refiners.fluxion.utils import manual_seed +from refiners.fluxion.utils import manual_seed, no_grad from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder @@ -65,7 +65,7 @@ def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder: return double_text_encoder -@torch.no_grad() +@no_grad() def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None: manual_seed(seed=0) prompt = "A photo of a pizza." diff --git a/tests/foundationals/latent_diffusion/test_sdxl_unet.py b/tests/foundationals/latent_diffusion/test_sdxl_unet.py index 95b031bc9..c3d0f10a4 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_unet.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_unet.py @@ -6,7 +6,7 @@ import torch from refiners.fluxion.model_converter import ConversionStage, ModelConverter -from refiners.fluxion.utils import manual_seed +from refiners.fluxion.utils import manual_seed, no_grad from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet @@ -37,7 +37,7 @@ def refiners_sdxl_unet() -> SDXLUNet: return unet -@torch.no_grad() +@no_grad() def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None: source = diffusers_sdxl_unet target = refiners_sdxl_unet diff --git a/tests/foundationals/latent_diffusion/test_unet.py b/tests/foundationals/latent_diffusion/test_unet.py index 4fe09e34a..210eca720 100644 --- a/tests/foundationals/latent_diffusion/test_unet.py +++ b/tests/foundationals/latent_diffusion/test_unet.py @@ -1,6 +1,7 @@ import torch from refiners.fluxion import manual_seed +from 
refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1UNet @@ -13,11 +14,11 @@ def test_unet_context_flush(): unet = SD1UNet(in_channels=4) unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_1 = unet(x.clone()) - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_2 = unet(x.clone()) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 0c5fbf978..1e5668542 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -18,12 +18,12 @@ from refiners.fluxion import manual_seed from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import image_to_tensor +from refiners.fluxion.utils import image_to_tensor, no_grad from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer -# See predictor_example.ipynb official notebook (note: mask_input is not yet properly supported) +# See predictor_example.ipynb official notebook PROMPTS: list[SAMPrompt] = [ SAMPrompt(foreground_points=((500, 375),)), SAMPrompt(background_points=((500, 375),)), @@ -41,7 +41,9 @@ def prompt(request: pytest.FixtureRequest) -> SAMPrompt: @pytest.fixture def one_prompt() -> SAMPrompt: - return PROMPTS[0] + # Using the third prompt of the PROMPTS list in order to strictly do the same test as the official notebook in the + # test_predictor_dense_mask test. 
+ return PROMPTS[2] @pytest.fixture(scope="module") @@ -83,8 +85,7 @@ def facebook_sam_h_predictor(facebook_sam_h: FacebookSAM) -> FacebookSAMPredicto @pytest.fixture(scope="module") def sam_h(sam_h_weights: Path, test_device: torch.device) -> SegmentAnythingH: sam_h = SegmentAnythingH(device=test_device) - # TODO: make strict=True when the MasKEncoder conversion is done - sam_h.load_from_safetensors(tensors_path=sam_h_weights, strict=False) + sam_h.load_from_safetensors(tensors_path=sam_h_weights) return sam_h @@ -98,7 +99,7 @@ def truck(ref_path: Path) -> Image.Image: return Image.open(ref_path / "truck.jpg").convert("RGB") -@torch.no_grad() +@no_grad() def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None: manual_seed(seed=0) x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device) @@ -124,7 +125,7 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None: assert torch.equal(input=y_1, other=y_2) -@torch.no_grad() +@no_grad() def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None: image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device) y_1 = facebook_sam_h.image_encoder(image_tensor) @@ -133,7 +134,7 @@ def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, tru assert torch.allclose(input=y_1, other=y_2, atol=1e-4) -@torch.no_grad() +@no_grad() def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None: facebook_prompt_encoder = facebook_sam_h.prompt_encoder refiners_prompt_encoder = sam_h.point_encoder @@ -144,7 +145,7 @@ def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe) -@torch.no_grad() +@no_grad() def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None: facebook_prompt_encoder = facebook_sam_h.prompt_encoder 
refiners_prompt_encoder = sam_h.mask_encoder @@ -155,7 +156,7 @@ def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe) -@torch.no_grad() +@no_grad() def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None: facebook_prompt_encoder = facebook_sam_h.prompt_encoder refiners_prompt_encoder = sam_h.point_encoder @@ -164,7 +165,14 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro **prompt.facebook_prompt_encoder_kwargs(device=facebook_sam_h.device) ) - coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt.__dict__) + prompt_dict = prompt.__dict__ + # Skip mask prompt, if any, since the point encoder only consumes points and boxes + # TODO: split `SAMPrompt` and introduce a dedicated one for dense prompts + prompt_dict.pop("low_res_mask", None) + + assert prompt_dict is not None, "`test_point_encoder` cannot be called with just a `low_res_mask`" + + coordinates, type_mask = refiners_prompt_encoder.points_to_tensor(**prompt_dict) # Shift to center of pixel + normalize in [0, 1] (see `_embed_points` in segment-anything official repo) coordinates[:, :, 0] = (coordinates[:, :, 0] + 0.5) / 1024.0 coordinates[:, :, 1] = (coordinates[:, :, 1] + 0.5) / 1024.0 @@ -174,7 +182,7 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe) -@torch.no_grad() +@no_grad() def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None: dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device) dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device) @@ -223,7 +231,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None: assert torch.equal(input=y_1, other=y_2) -@torch.no_grad() +@no_grad() def test_mask_decoder(facebook_sam_h: 
FacebookSAM, sam_h: SegmentAnythingH) -> None: manual_seed(seed=0) facebook_mask_decoder = facebook_sam_h.mask_decoder @@ -319,3 +327,91 @@ def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, assert torch.equal(masks, masks_ref) assert torch.equal(scores_ref, scores) + + +def test_predictor_dense_mask( + facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt +) -> None: + """ + NOTE : Binarizing intermediate masks isn't necessary, as per SamPredictor.predict_torch docstring: + > mask_input (np.ndarray): A low resolution mask input to the model, typically + > coming from a previous prediction iteration. Has form Bx1xHxW, where + > for SAM, H=W=256. Masks returned by a previous iteration of the + > predict method do not need further transformation. + """ + predictor = facebook_sam_h_predictor + predictor.set_image(np.array(truck)) + facebook_masks, facebook_scores, facebook_logits = predictor.predict( + **one_prompt.facebook_predict_kwargs(), # type: ignore + multimask_output=True, + ) + + assert len(facebook_masks) == 3 + + facebook_mask_input = facebook_logits[np.argmax(facebook_scores)] # shape: HxW + + # Using the same mask coordinates inputs as the official notebook + facebook_prompt = SAMPrompt( + foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=facebook_mask_input[None, ...] + ) + facebook_dense_masks, _, _ = predictor.predict(**facebook_prompt.facebook_predict_kwargs(), multimask_output=True) # type: ignore + + assert len(facebook_dense_masks) == 3 + + masks, scores, logits = sam_h.predict(truck, **one_prompt.__dict__) + masks = masks.squeeze(0) + scores = scores.squeeze(0) + + assert len(masks) == 3 + + mask_input = logits[:, scores.max(dim=0).indices, ...] 
# shape: 1xHxW + + assert np.allclose( + mask_input.cpu().numpy(), facebook_mask_input, atol=1e-1 + ) # Lower doesn't pass, but it's close enough for logits + + refiners_prompt = SAMPrompt( + foreground_points=((500, 375),), background_points=((1125, 625),), low_res_mask=mask_input.unsqueeze(0) + ) + dense_masks, _, _ = sam_h.predict(truck, **refiners_prompt.__dict__) + dense_masks = dense_masks.squeeze(0) + + assert len(dense_masks) == 3 + + for i in range(3): + dense_mask_prediction = dense_masks[i].cpu() + facebook_dense_mask = torch.as_tensor(facebook_dense_masks[i]) + assert dense_mask_prediction.shape == facebook_dense_mask.shape + assert isclose(intersection_over_union(dense_mask_prediction, facebook_dense_mask), 1.0, rel_tol=5e-05) + + +def test_mask_encoder( + facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt +) -> None: + predictor = facebook_sam_h_predictor + predictor.set_image(np.array(truck)) + _, facebook_scores, facebook_logits = predictor.predict( + **one_prompt.facebook_predict_kwargs(), # type: ignore + multimask_output=True, + ) + facebook_mask_input = facebook_logits[np.argmax(facebook_scores)] + facebook_mask_input = ( + torch.from_numpy(facebook_mask_input) # type: ignore + .to(device=predictor.model.device) + .unsqueeze(0) + .unsqueeze(0) # shape: 1x1xHxW + ) + + _, fb_dense_embeddings = predictor.model.prompt_encoder( + points=None, + boxes=None, + masks=facebook_mask_input, + ) + + _, scores, logits = sam_h.predict(truck, **one_prompt.__dict__) + scores = scores.squeeze(0) + mask_input = logits[:, scores.max(dim=0).indices, ...].unsqueeze(0) # shape: 1x1xHxW + dense_embeddings = sam_h.mask_encoder(mask_input) + + assert facebook_mask_input.shape == mask_input.shape + assert torch.allclose(dense_embeddings, fb_dense_embeddings, atol=1e-4, rtol=1e-4) diff --git a/tests/foundationals/segment_anything/utils.py b/tests/foundationals/segment_anything/utils.py index 
ef73e36ae..fa18e8841 100644 --- a/tests/foundationals/segment_anything/utils.py +++ b/tests/foundationals/segment_anything/utils.py @@ -63,8 +63,7 @@ class SAMPrompt: foreground_points: Sequence[tuple[float, float]] | None = None background_points: Sequence[tuple[float, float]] | None = None box_points: Sequence[Sequence[tuple[float, float]]] | None = None - # TODO: support masks - # masks: Sequence[Image.Image] | None = None + low_res_mask: Tensor | None = None def facebook_predict_kwargs(self) -> dict[str, NDArray]: prompt: dict[str, NDArray] = {} @@ -85,13 +84,18 @@ def facebook_predict_kwargs(self) -> dict[str, NDArray]: prompt["box"] = np.array([coord for batch in self.box_points for xy in batch for coord in xy]).reshape( len(self.box_points), 4 ) + if self.low_res_mask is not None: + prompt["mask_input"] = np.array(self.low_res_mask) return prompt - def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None): + def facebook_prompt_encoder_kwargs( + self, device: torch.device | None = None + ) -> dict[str, Tensor | tuple[Tensor, Tensor | None] | None]: prompt = self.facebook_predict_kwargs() coords: Tensor | None = None labels: Tensor | None = None boxes: Tensor | None = None + masks: Tensor | None = None if "point_coords" in prompt: coords = torch.as_tensor(prompt["point_coords"], dtype=torch.float, device=device).unsqueeze(0) if "point_labels" in prompt: @@ -99,8 +103,9 @@ def facebook_prompt_encoder_kwargs(self, device: torch.device | None = None): if "box" in prompt: boxes = torch.as_tensor(prompt["box"], dtype=torch.float, device=device).unsqueeze(0) points = (coords, labels) if coords is not None else None - # TODO: support masks - return {"points": points, "boxes": boxes, "masks": None} + if "mask_input" in prompt: + masks = torch.as_tensor(prompt["mask_input"], dtype=torch.float, device=device).unsqueeze(0) + return {"points": points, "boxes": boxes, "masks": masks} def intersection_over_union(