diff --git a/.github/download-models-weights.py b/.github/download-models-weights.py index fa7918bc6e..20efa2a308 100644 --- a/.github/download-models-weights.py +++ b/.github/download-models-weights.py @@ -1,11 +1,15 @@ import argparse +import os +import diffusers import torch -fonts = { - "sold2_wireframe": "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth", +models = { + "sold2_wireframe": ("torchhub", "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"), + "stabilityai/stable-diffusion-2-1": ("diffusers", "StableDiffusionPipeline"), } + if __name__ == "__main__": parser = argparse.ArgumentParser("WeightsDownloader") parser.add_argument("--target_directory", "-t", required=False, default="target_directory") @@ -13,9 +17,18 @@ args = parser.parse_args() torch.hub.set_dir(args.target_directory) + # For HuggingFace model caching + os.environ["HF_HOME"] = args.target_directory - for name, url in fonts.items(): - print(f"Downloading weights of `{name}` from `url`. Caching to dir `{args.target_directory}`") - torch.hub.load_state_dict_from_url(url, model_dir=args.target_directory, map_location=torch.device("cpu")) + for name, (src, path) in models.items(): + if src == "torchhub": + print(f"Downloading weights of `{name}` from `{path}`. Caching to dir `{args.target_directory}`") + torch.hub.load_state_dict_from_url(path, model_dir=args.target_directory, map_location=torch.device("cpu")) + elif src == "diffusers": + print(f"Downloading `{name}` from diffusers. Caching to dir `{args.target_directory}`") + if path == "StableDiffusionPipeline": + diffusers.StableDiffusionPipeline.from_pretrained( + name, cache_dir=args.target_directory, device_map="balanced" + ) raise SystemExit(0) diff --git a/README.md b/README.md index fb790b4026..3150bab19c 100644 --- a/README.md +++ b/README.md @@ -24,11 +24,53 @@ English | [简体中文](README_zh-CN.md)

-**Kornia** is a differentiable computer vision library for [PyTorch](https://pytorch.org). +**Kornia** is a differentiable computer vision library that provides a rich set of differentiable image processing and geometric vision algorithms. Built on top of [PyTorch](https://pytorch.org), Kornia integrates seamlessly into existing AI workflows, allowing you to leverage powerful [batch transformations](), [auto-differentiation]() and [GPU acceleration](). Whether you’re working on image transformations, augmentations, or AI-driven image processing, Kornia equips you with the tools you need to bring your ideas to life. + +## Key Components +1. **Differentiable Image Processing**
+ Kornia provides a comprehensive suite of image processing operators, all differentiable and ready to integrate into deep learning pipelines. + - **Filters**: Gaussian, Sobel, Median, Box Blur, etc. + - **Transformations**: Affine, Homography, Perspective, etc. + - **Enhancements**: Histogram Equalization, CLAHE, Gamma Correction, etc. + - **Edge Detection**: Canny, Laplacian, Sobel, etc. + - ... check our [docs](https://kornia.readthedocs.io) for more. +2. **Advanced Augmentations**
+Perform powerful data augmentation with Kornia’s built-in functions, ideal for training AI models with complex augmentation pipelines. + - **Augmentation Pipeline**: AugmentationSequential, PatchSequential, VideoSequential, etc. + - **Automatic Augmentation**: AutoAugment, RandAugment, TrivialAugment. +3. **AI Models**
+Leverage pre-trained AI models optimized for a variety of vision tasks, all within the Kornia ecosystem. + - **Face Detection**: YuNet + - **Feature Matching**: LoFTR, LightGlue + - **Feature Descriptor**: DISK, DeDoDe, SOLD2 + - **Segmentation**: SAM + - **Classification**: MobileViT, VisionTransformer. -It consists of a set of routines and differentiable modules to solve generic computer vision problems. At its core, the package uses *PyTorch* as its main backend both for efficiency and to take advantage of the reverse-mode auto-differentiation to define and compute the gradient of complex functions. +
+See here for some of the methods that we support! (>500 ops in total !) + +| **Category** | **Methods/Models** | +|----------------------------|---------------------------------------------------------------------------------------------------------------------| +| **Image Processing** | - Color conversions (RGB, Grayscale, HSV, etc.)
- Geometric transformations (Affine, Homography, Resizing, etc.)
- Filtering (Gaussian blur, Median blur, etc.)
- Edge detection (Sobel, Canny, etc.)
- Morphological operations (Erosion, Dilation, etc.) | +| **Augmentation** | - Random cropping, Erasing
- Random geometric transformations (Affine, flipping, Fish Eye, Perspecive, Thin plate spline, Elastic)
- Random noises (Gaussian, Median, Motion, Box, Rain, Snow, Salt and Pepper)
- Random color jittering (Contrast, Brightness, CLAHE, Equalize, Gamma, Hue, Invert, JPEG, Plasma, Posterize, Saturation, Sharpness, Solarize)
- Random MixUp, CutMix, Mosaic, Transplantation, etc. | +| **Feature Detection** | - Detector (Harris, GFTT, Hessian, DoG, KeyNet, DISK and DeDoDe)
- Descriptor (SIFT, HardNet, TFeat, HyNet, SOSNet, and LAFDescriptor)
- Matching (nearest neighbor, mutual nearest neighbor, geometrically aware matching, AdaLAM LightGlue, and LoFTR) | +| **Geometry** | - Camera models and calibration
- Stereo vision (epipolar geometry, disparity, etc.)
- Homography estimation
- Depth estimation from disparity
- 3D transformations | +| **Deep Learning Layers** | - Custom convolution layers
- Recurrent layers for vision tasks
- Loss functions (e.g., SSIM, PSNR, etc.)
- Vision-specific optimizers | +| **Photometric Functions** | - Photometric loss functions
- Photometric augmentations | +| **Filtering** | - Bilateral filtering
- DexiNed
- Dissolving
- Guided Blur
- Laplacian
- Gaussian
- Non-local means
- Sobel
- Unsharp masking | +| **Color** | - Color space conversions
- Brightness/contrast adjustment
- Gamma correction | +| **Stereo Vision** | - Disparity estimation
- Depth estimation
- Rectification | +| **Image Registration** | - Affine and homography-based registration
- Image alignment using feature matching | +| **Pose Estimation** | - Essential and Fundamental matrix estimation
- PnP problem solvers
- Pose refinement | +| **Optical Flow** | - Farneback optical flow
- Dense optical flow
- Sparse optical flow | +| **3D Vision** | - Depth estimation
- Point cloud operations
- Nerf
| +| **Image Denoising** | - Gaussian noise removal
- Poisson noise removal | +| **Edge Detection** | - Sobel operator
- Canny edge detection | | +| **Transformations** | - Rotation
- Translation
- Scaling
- Shearing | +| **Loss Functions** | - SSIM (Structural Similarity Index Measure)
- PSNR (Peak Signal-to-Noise Ratio)
- Cauchy
- Charbonnier
- Depth Smooth
- Dice
- Hausdorff
- Tversky
- Welsch
| | +| **Morphological Operations**| - Dilation
- Erosion
- Opening
- Closing | -Inspired by existing packages, this library is composed by a subset of packages containing operators that can be inserted within neural networks to train models to perform image transformations, epipolar geometry, depth estimation, and low-level image processing such as filtering and edge detection that operate directly on tensors. +
## Sponsorship @@ -66,6 +108,62 @@ Kornia is an open-source project that is developed and maintained by volunteers. +## Quick Start + +Kornia is not just another computer vision library — it's your gateway to effortless Computer Vision and AI. + +```python +import numpy as np +import kornia_rs as kr + +from kornia.augmentation import AugmentationSequential, RandomAffine, RandomBrightness +from kornia.filters import StableDiffusionDissolving + +# Load and prepare your image +img: np.ndarray = kr.read_image_any("img.jpeg") +img = kr.resize(img, (256, 256), interpolation="bilinear") + +# alternatively, load image with PIL +# img = Image.open("img.jpeg").resize((256, 256)) +# img = np.array(img) + +img = np.stack([img] * 2) # batch images + +# Define an augmentation pipeline +augmentation_pipeline = AugmentationSequential( + RandomAffine((-45., 45.), p=1.), + RandomBrightness((0.,1.), p=1.) +) + +# Leveraging StableDiffusion models +dslv_op = StableDiffusionDissolving() + +img = augmentation_pipeline(img) +dslv_op(img, step_number=500) + +dslv_op.save("Kornia-enhanced.jpg") +``` + +## Call For Contributors + +Are you passionate about computer vision, AI, and open-source development? Join us in shaping the future of Kornia! We are actively seeking contributors to help expand and enhance our library, making it even more powerful, accessible, and versatile. Whether you're an experienced developer or just starting, there's a place for you in our community. + +### Accessible AI Models + +We are excited to announce our latest advancement: a new initiative designed to seamlessly integrate lightweight AI models into Kornia. +We aim to run any models as smooth as big models such as StableDiffusion, to support them well in many perspectives. +We have already included a selection of lightweight AI models like [YuNet (Face Detection)](), [Loftr (Feature Matching)](), and [SAM (Segmentation)](). Now, we're looking for contributors to help us: + +- Expand the Model Selection: Import decent models into our library. If you are a researcher, Kornia is an excellent place for you to promote your model! +- Model Optimization: Work on optimizing models to reduce their computational footprint while maintaining accuracy and performance. You may start from offering ONNX support! +- Model Documentation: Create detailed guides and examples to help users get the most out of these models in their projects. + + +### Documentation And Tutorial Optimization + +Kornia's foundation lies in its extensive collection of classic computer vision operators, providing robust tools for image processing, feature extraction, and geometric transformations. We continuously seek for contributors to help us improve our documentation and present nice tutorials to our users. + + ## Cite If you are using kornia in your research-related documents, it is recommended that you cite the paper. See more in [CITATION](./CITATION.md). diff --git a/conftest.py b/conftest.py index 05c0ff213f..41493ab107 100644 --- a/conftest.py +++ b/conftest.py @@ -196,6 +196,8 @@ def pytest_sessionstart(session): os.makedirs(WEIGHTS_CACHE_DIR, exist_ok=True) torch.hub.set_dir(WEIGHTS_CACHE_DIR) + # For HuggingFace model caching + os.environ["HF_HOME"] = WEIGHTS_CACHE_DIR def _get_env_info() -> Dict[str, Dict[str, str]]: diff --git a/docs/source/augmentation.module.rst b/docs/source/augmentation.module.rst index 5cd67d97ea..99b0d7e233 100644 --- a/docs/source/augmentation.module.rst +++ b/docs/source/augmentation.module.rst @@ -21,6 +21,7 @@ Intensity .. autoclass:: RandomClahe .. autoclass:: RandomContrast .. autoclass:: RandomEqualize +.. autoclass:: RandomDissolving .. autoclass:: RandomGamma .. autoclass:: RandomGaussianBlur .. autoclass:: RandomGaussianIllumination diff --git a/docs/source/get-started/highlights.rst b/docs/source/get-started/highlights.rst index 3fda157824..76abc4a717 100644 --- a/docs/source/get-started/highlights.rst +++ b/docs/source/get-started/highlights.rst @@ -1,6 +1,20 @@ Highlighted Features ==================== +At Kornia, we are dedicated to pushing the boundaries of computer vision by providing a robust, efficient, and versatile toolkit. Our library is built on the powerful PyTorch backend, leveraging its efficiency and auto-differentiation capabilities to deliver high-performance solutions for a wide range of vision tasks. + + +Accessible AI Models +-------------------- + +We are excited to announce our latest advancement: a new initiative designed to seamlessly integrate lightweight AI models into Kornia. This addition is crafted to empower developers, researchers, and enthusiasts to harness the full potential of accessible AI, simplifying complex vision tasks and accelerating innovation. + +We have curated a selection of lightweight AI models, including YuNet, Loftr, and SAM, optimized for performance and efficiency. These models offer efficient computations that do not require expensive GPUs, making cutting-edge AI accessible to everyone. We welcome the whole community of developers and researchers who are passionate about advancing computer vision, throwing PRs for your lightning fast models! + + +Classic Operators +----------------- + .. image:: https://github.com/kornia/data/raw/main/kornia_paper_mosaic.png :align: center diff --git a/docs/source/references.bib b/docs/source/references.bib index aee01394a6..33a41a4080 100644 --- a/docs/source/references.bib +++ b/docs/source/references.bib @@ -405,3 +405,10 @@ @article{wang2023vggsfm journal={arXiv preprint arXiv:2312.04563}, year={2023} } + +@misc{shi2024dissolving, + Author = {Jian Shi and Pengyi Zhang and Ni Zhang and Hakim Ghazzai and Peter Wonka}, + Title = {Dissolving Is Amplifying: Towards Fine-Grained Anomaly Detection}, + booktitle = {ECCV}, + Year = {2024}, +} diff --git a/kornia/augmentation/_2d/intensity/__init__.py b/kornia/augmentation/_2d/intensity/__init__.py index 47d01a0a7b..f96ae03adc 100644 --- a/kornia/augmentation/_2d/intensity/__init__.py +++ b/kornia/augmentation/_2d/intensity/__init__.py @@ -8,6 +8,7 @@ from kornia.augmentation._2d.intensity.color_jitter import ColorJitter from kornia.augmentation._2d.intensity.contrast import RandomContrast from kornia.augmentation._2d.intensity.denormalize import Denormalize +from kornia.augmentation._2d.intensity.dissolving import RandomDissolving from kornia.augmentation._2d.intensity.equalize import RandomEqualize from kornia.augmentation._2d.intensity.erasing import RandomErasing from kornia.augmentation._2d.intensity.gamma import RandomGamma diff --git a/kornia/augmentation/_2d/intensity/dissolving.py b/kornia/augmentation/_2d/intensity/dissolving.py new file mode 100644 index 0000000000..f64bc93678 --- /dev/null +++ b/kornia/augmentation/_2d/intensity/dissolving.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, Optional, Tuple + +from kornia.augmentation import random_generator as rg +from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D +from kornia.core import Tensor +from kornia.filters import StableDiffusionDissolving + + +class RandomDissolving(IntensityAugmentationBase2D): + r"""Perform dissolving transformation using StableDiffusion models. + + Based on :cite:`shi2024dissolving`, the dissolving transformation is essentially applying one-step + reverse diffusion. Our implementation currently supports HuggingFace implementations of SD 1.4, 1.5 + and 2.1. SD 1.X tends to remove more details than SD2.1. + + .. list-table:: Title + :widths: 32 32 32 + :header-rows: 1 + + * - SD 1.4 + - SD 1.5 + - SD 2.1 + * - figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.4.png + - figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.5.png + - figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-2.1.png + + Args: + p: probability of applying the transformation. + version: the version of the stable diffusion model. + step_range: the step range of the diffusion model steps. Higher the step, stronger + the dissolving effects. + keepdim: whether to keep the output shape the same as input (True) or broadcast it + to the batch form (False). + **kwargs: additional arguments for `.from_pretrained` for HF StableDiffusionPipeline. + + Shape: + - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`. + - Output: :math:`(B, C, H, W)` + """ + + def __init__( + self, + step_range: Tuple[float, float] = (100, 500), + version: str = "2.1", + p: float = 0.5, + keepdim: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(p=p, same_on_batch=True, keepdim=keepdim) + self.step_range = step_range + self._dslv = StableDiffusionDissolving(version, **kwargs) + self._param_generator = rg.PlainUniformGenerator((self.step_range, "step_range_factor", None, None)) + + def apply_transform( + self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None + ) -> Tensor: + return self._dslv(input, params["step_range_factor"][0].long().item()) diff --git a/kornia/augmentation/__init__.py b/kornia/augmentation/__init__.py index 0210675c97..f0338975b9 100644 --- a/kornia/augmentation/__init__.py +++ b/kornia/augmentation/__init__.py @@ -18,6 +18,7 @@ RandomContrast, RandomCrop, RandomCutMixV2, + RandomDissolving, RandomElasticTransform, RandomEqualize, RandomErasing, @@ -115,6 +116,7 @@ "RandomChannelShuffle", "RandomContrast", "RandomCrop", + "RandomDissolving", "RandomErasing", "RandomElasticTransform", "RandomFisheye", diff --git a/kornia/core/__init__.py b/kornia/core/__init__.py index 200bbf6a82..2e1ebc2f7c 100644 --- a/kornia/core/__init__.py +++ b/kornia/core/__init__.py @@ -1,3 +1,4 @@ +from . import external from ._backend import ( Device, Dtype, @@ -35,6 +36,7 @@ from .tensor_wrapper import TensorWrapper # type: ignore __all__ = [ + "external", "arange", "concatenate", "Device", diff --git a/kornia/core/external.py b/kornia/core/external.py index 8127719f3d..4efdbfe189 100644 --- a/kornia/core/external.py +++ b/kornia/core/external.py @@ -69,3 +69,4 @@ def __dir__(self) -> List[str]: numpy = LazyLoader("numpy") PILImage = LazyLoader("PIL.Image") +diffusers = LazyLoader("diffusers") diff --git a/kornia/core/module.py b/kornia/core/module.py index 49cd8ed69c..97132caf52 100644 --- a/kornia/core/module.py +++ b/kornia/core/module.py @@ -209,7 +209,7 @@ def save(self, name: Optional[str] = None, n_row: Optional[int] = None) -> None: n_row: Number of images displayed in each row of the grid. """ if name is None: - name = f"Kornia-{datetime.datetime.now(tz=datetime.UTC).strftime('%Y%m%d%H%M%S')!s}.jpg" + name = f"Kornia-{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}.jpg" if len(self._output_image.shape) == 3: out_image = self._output_image if len(self._output_image.shape) == 4: diff --git a/kornia/filters/__init__.py b/kornia/filters/__init__.py index fc1a4e44e0..43dcc15ac7 100644 --- a/kornia/filters/__init__.py +++ b/kornia/filters/__init__.py @@ -12,6 +12,7 @@ ) from .canny import Canny, canny from .dexined import DexiNed +from .dissolving import StableDiffusionDissolving from .filter import filter2d, filter2d_separable, filter3d from .gaussian import GaussianBlur2d, gaussian_blur2d, gaussian_blur2d_t from .guided import GuidedBlur, guided_blur @@ -92,6 +93,7 @@ "Canny", "BoxBlur", "BlurPool2D", + "StableDiffusionDissolving", "MaxBlurPool2D", "EdgeAwareBlurPool2D", "MedianBlur", diff --git a/kornia/filters/dissolving.py b/kornia/filters/dissolving.py new file mode 100644 index 0000000000..154ac82404 --- /dev/null +++ b/kornia/filters/dissolving.py @@ -0,0 +1,136 @@ +from typing import Any + +import torch + +from kornia.core import ImageModule, Module, Tensor +from kornia.core.external import diffusers + + +class _DissolvingWraper_HF: + def __init__(self, model: Module, num_ddim_steps: int = 50) -> None: + self.model = model + self.num_ddim_steps = num_ddim_steps + self.tokenizer = self.model.tokenizer + self.model.scheduler.set_timesteps(self.num_ddim_steps) + self.total_steps = len(self.model.scheduler.timesteps) # Total number of sampling steps. + self.prompt: str + self.context: Tensor + + def predict_start_from_noise(self, noise_pred: Tensor, timestep: int, latent: Tensor) -> Tensor: + return ( + torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod[timestep]) * latent + - torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod[timestep] - 1) * noise_pred + ) + + @torch.no_grad() + def init_prompt(self, prompt: str) -> None: + uncond_input = self.model.tokenizer( + [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, return_tensors="pt" + ) + uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] + text_input = self.model.tokenizer( + [prompt], + padding="max_length", + max_length=self.model.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] + self.context = torch.cat([uncond_embeddings, text_embeddings]) + self.prompt = prompt + + # Encode the image to latent using the VAE. + @torch.no_grad() + def encode_tensor_to_latent(self, image: Tensor) -> Tensor: + with torch.no_grad(): + image = (image / 0.5 - 1).to(self.model.device) + latents = self.model.vae.encode(image)["latent_dist"].sample() + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def decode_tensor_to_latent(self, latents: Tensor) -> Tensor: + latents = 1 / 0.18215 * latents.detach() + image = self.model.vae.decode(latents)["sample"] + image = (image / 2 + 0.5).clamp(0, 1) + return image + + @torch.no_grad() + def one_step_dissolve(self, latent: Tensor, i: int) -> Tensor: + _, cond_embeddings = self.context.chunk(2) + latent = latent.clone().detach() + # NOTE: This implementation use a reversed timesteps but can reach to + # a stable dissolving effect. + t = self.num_ddim_steps - self.model.scheduler.timesteps[i] + latent = self.model.scheduler.scale_model_input(latent, t) + cond_embeddings = cond_embeddings.repeat(latent.size(0), 1, 1) + noise_pred = self.model.unet(latent, t, cond_embeddings).sample + pred_x0 = self.predict_start_from_noise(noise_pred, t, latent) + return pred_x0 + + @torch.no_grad() + def dissolve(self, image: Tensor, t: int) -> Tensor: + self.init_prompt("") + latent = self.encode_tensor_to_latent(image) + ddim_latents = self.one_step_dissolve(latent, t) + dissolved = self.decode_tensor_to_latent(ddim_latents) + return dissolved + + +class StableDiffusionDissolving(ImageModule): + r"""Perform dissolving transformation using StableDiffusion models. + + Based on :cite:`shi2024dissolving`, the dissolving transformation is essentially applying one-step + reverse diffusion. Our implementation currently supports HuggingFace implementations of SD 1.4, 1.5 + and 2.1. SD 1.X tends to remove more details than SD2.1. + + .. list-table:: Title + :widths: 32 32 32 + :header-rows: 1 + + * - SD 1.4 + - SD 1.5 + - SD 2.1 + * - figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.4.png + - figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.5.png + - figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-2.1.png + + Args: + version: the version of the stable diffusion model. + **kwargs: additional arguments for `.from_pretrained`. + """ + + def __init__(self, version: str = "2.1", **kwargs: Any): + super().__init__() + StableDiffusionPipeline = diffusers.StableDiffusionPipeline + DDIMScheduler = diffusers.DDIMScheduler + + # Load the scheduler and model pipeline from diffusers library + scheduler = DDIMScheduler( # type:ignore + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + + if version == "1.4": + self._sdm_model = StableDiffusionPipeline.from_pretrained( # type:ignore + "CompVis/stable-diffusion-v1-4", scheduler=scheduler, **kwargs + ) + elif version == "1.5": + self._sdm_model = StableDiffusionPipeline.from_pretrained( # type:ignore + "runwayml/stable-diffusion-v1-5", scheduler=scheduler, **kwargs + ) + elif version == "2.1": + self._sdm_model = StableDiffusionPipeline.from_pretrained( # type:ignore + "stabilityai/stable-diffusion-2-1", scheduler=scheduler, **kwargs + ) + else: + raise NotImplementedError + + self.model = _DissolvingWraper_HF(self._sdm_model, num_ddim_steps=1000) + + def forward(self, input: Tensor, step_number: int) -> Tensor: + return self.model.dissolve(input, step_number) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 9f92712ea9..d6993c7702 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,4 +1,6 @@ +accelerate coverage +diffusers mypy numpy<3 onnx @@ -7,3 +9,4 @@ pre-commit>=2 pytest==8.3.2 pytest-timeout requests +transformers diff --git a/tests/augmentation/test_augmentation.py b/tests/augmentation/test_augmentation.py index 469be59f0e..939cdc6e3b 100644 --- a/tests/augmentation/test_augmentation.py +++ b/tests/augmentation/test_augmentation.py @@ -23,6 +23,7 @@ RandomClahe, RandomContrast, RandomCrop, + RandomDissolving, RandomElasticTransform, RandomEqualize, RandomErasing, @@ -5091,6 +5092,7 @@ def test_spawn_multiprocessing_context(self, context: str): torch.cuda.empty_cache() +@pytest.mark.slow class TestRandomJPEG(BaseTester): torch.manual_seed(0) # for random reproductibility @@ -5133,3 +5135,20 @@ def test_gradcheck(self, device): img_jpeg_mean_grad_ref = torch.tensor([0.1919]) # We use a slightly higher tolerance since our implementation varies from the reference implementation self.assert_close(img.grad.mean().view(-1), img_jpeg_mean_grad_ref, rtol=0.01, atol=0.01) + + +@pytest.mark.slow +class TestRandomDissolving(BaseTester): + torch.manual_seed(0) # for random reproductibility + + def test_batch_proc(self, device, dtype): + images = torch.rand(4, 3, 16, 16) + aug = RandomDissolving(p=1.0, version="2.1", cache_dir="weights/") + images_aug = aug(images) + assert images_aug.shape == images.shape + + def test_single_proc(self, device, dtype): + images = torch.rand(3, 16, 16) + aug = RandomDissolving(p=1.0, keepdim=True, version="2.1", cache_dir="weights/") + images_aug = aug(images) + assert images_aug.shape == images.shape diff --git a/tests/core/test_lazyloader.py b/tests/core/test_lazyloader.py new file mode 100644 index 0000000000..1025f97827 --- /dev/null +++ b/tests/core/test_lazyloader.py @@ -0,0 +1,56 @@ +import pytest + +from kornia.core.external import LazyLoader + + +class TestLazyLoader: + def test_lazy_loader_initialization(self): + # Test that the LazyLoader initializes with the correct module name and None module + loader = LazyLoader("math") + assert loader.module_name == "math" + assert loader.module is None + + def test_lazy_loader_loading_module(self): + # Test that the LazyLoader correctly loads the module upon attribute access + loader = LazyLoader("math") + assert loader.module is None # Should be None before any attribute access + + # Access an attribute to trigger module loading + assert loader.sqrt(4) == 2.0 + assert loader.module is not None # Should be loaded now + + def test_lazy_loader_invalid_module(self): + # Test that LazyLoader raises an ImportError for an invalid module + loader = LazyLoader("non_existent_module") + with pytest.raises(ImportError) as excinfo: + loader.non_existent_attribute # Accessing any attribute should raise the error + + assert "Optional dependency 'non_existent_module' is not installed" in str(excinfo.value) + + def test_lazy_loader_getattr(self): + # Test that __getattr__ works correctly for a valid module + loader = LazyLoader("math") + assert loader.sqrt(16) == 4.0 + assert loader.pi == 3.141592653589793 + + def test_lazy_loader_dir(self): + # Test that dir() returns the correct list of attributes for the module + loader = LazyLoader("math") + attributes = dir(loader) + assert "sqrt" in attributes + assert "pi" in attributes + assert loader.module is not None + + def test_lazy_loader_multiple_attributes(self): + # Test accessing multiple attributes to ensure the module is loaded only once + loader = LazyLoader("math") + assert loader.sqrt(25) == 5.0 + assert loader.pi == 3.141592653589793 + assert loader.pow(2, 3) == 8.0 + assert loader.module is not None + + def test_lazy_loader_non_existing_attribute(self): + # Test that accessing a non-existing attribute raises an AttributeError after loading + loader = LazyLoader("math") + with pytest.raises(AttributeError): + loader.non_existent_attribute diff --git a/tests/core/test_module.py b/tests/core/test_module.py index 258694feb1..f396027456 100644 --- a/tests/core/test_module.py +++ b/tests/core/test_module.py @@ -8,8 +8,6 @@ from kornia.core.module import ImageModule, ImageModuleMixIn -# Assuming ImageModuleMixIn and ImageModule have been imported from the module - class TestImageModuleMixIn: @pytest.fixture diff --git a/tests/filters/test_dissolving.py b/tests/filters/test_dissolving.py new file mode 100644 index 0000000000..d340e59d19 --- /dev/null +++ b/tests/filters/test_dissolving.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from kornia.core import Tensor +from kornia.filters.dissolving import StableDiffusionDissolving + +WEIGHTS_CACHE_DIR = "weights/" + + +@pytest.mark.slow +class TestStableDiffusionDissolving: + @pytest.fixture(scope="class") + def sdm_2_1(self): + return StableDiffusionDissolving(version="2.1", cache_dir=WEIGHTS_CACHE_DIR) + + @pytest.fixture(scope="class") + def dummy_image(self): + # Create a dummy image tensor with shape [B, C, H, W], where B is the batch size. + return torch.rand(1, 3, 64, 64) + + def test_init(self, sdm_2_1): + assert isinstance(sdm_2_1, StableDiffusionDissolving), "Initialization failed" + + def test_encode_tensor_to_latent(self, sdm_2_1, dummy_image): + latents = sdm_2_1.model.encode_tensor_to_latent(dummy_image) + assert isinstance(latents, Tensor), "Latent encoding failed" + assert latents.shape == (1, 4, 8, 8), "Latent shape mismatch" + + def test_decode_tensor_to_latent(self, sdm_2_1, dummy_image): + latents = sdm_2_1.model.encode_tensor_to_latent(dummy_image) + reconstructed_image = sdm_2_1.model.decode_tensor_to_latent(latents) + assert isinstance(reconstructed_image, Tensor), "Latent decoding failed" + assert reconstructed_image.shape == dummy_image.shape, "Reconstructed image shape mismatch" + + def test_dissolve(self, sdm_2_1, dummy_image): + step_number = 500 # Test with a middle step + dissolved_image = sdm_2_1(dummy_image, step_number) + assert isinstance(dissolved_image, Tensor), "Dissolve failed" + assert dissolved_image.shape == dummy_image.shape, "Dissolved image shape mismatch" + + def test_invalid_version(self): + with pytest.raises(NotImplementedError): + StableDiffusionDissolving(version="invalid_version")