[Feat] Added dissolving transformation & updated docs #62

Open · wants to merge 37 commits into base: main

Commits (37)
a3e9ac8
updated
shijianjian Aug 7, 2024
959a508
update
shijianjian Jul 22, 2024
47b89b1
update
shijianjian Aug 5, 2024
55f33c1
update
shijianjian Aug 7, 2024
7866cfa
Added tests for LazyLoader
shijianjian Aug 8, 2024
1fb28ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2024
ac1ad7f
updated tests
shijianjian Aug 10, 2024
2d215f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2024
60383fd
updated typing
shijianjian Aug 10, 2024
a0ab5da
Added tests
shijianjian Aug 12, 2024
777129f
Updated caching for diffusers
shijianjian Aug 20, 2024
3a8ab3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
e76f57b
updated
shijianjian Aug 20, 2024
6520b33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
aa3b200
update
shijianjian Aug 20, 2024
6254123
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
c56094e
update
shijianjian Aug 21, 2024
1d8405e
updated
shijianjian Aug 21, 2024
5d154de
update
shijianjian Aug 21, 2024
97dc6cd
update
shijianjian Aug 21, 2024
b447669
update
shijianjian Aug 21, 2024
d8e3f9b
update
shijianjian Aug 21, 2024
0fca9ed
update
shijianjian Aug 22, 2024
e75380c
update
shijianjian Aug 22, 2024
c881b73
update
shijianjian Aug 22, 2024
ae64b51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
02d0dd0
update
shijianjian Aug 22, 2024
ed2d11f
update
shijianjian Aug 22, 2024
093f203
update
shijianjian Aug 23, 2024
8f0e140
update
shijianjian Aug 23, 2024
ff7819f
updated README
shijianjian Aug 23, 2024
d7fbba8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
01912fd
update
shijianjian Aug 23, 2024
7fc1f15
updated
shijianjian Aug 25, 2024
dc5c7c0
update
shijianjian Aug 25, 2024
abbf224
update
shijianjian Aug 25, 2024
001c39c
Update README.md
edgarriba Aug 26, 2024
23 changes: 18 additions & 5 deletions .github/download-models-weights.py
@@ -1,21 +1,34 @@
 import argparse
+import os
 
+import diffusers
 import torch
 
-fonts = {
-    "sold2_wireframe": "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth",
+models = {
+    "sold2_wireframe": ("torchhub", "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"),
+    "stabilityai/stable-diffusion-2-1": ("diffusers", "StableDiffusionPipeline"),
 }
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("WeightsDownloader")
     parser.add_argument("--target_directory", "-t", required=False, default="target_directory")
 
     args = parser.parse_args()
 
     torch.hub.set_dir(args.target_directory)
+    # For HuggingFace model caching
+    os.environ["HF_HOME"] = args.target_directory
 
-    for name, url in fonts.items():
-        print(f"Downloading weights of `{name}` from `url`. Caching to dir `{args.target_directory}`")
-        torch.hub.load_state_dict_from_url(url, model_dir=args.target_directory, map_location=torch.device("cpu"))
+    for name, (src, path) in models.items():
+        if src == "torchhub":
+            print(f"Downloading weights of `{name}` from `{path}`. Caching to dir `{args.target_directory}`")
+            torch.hub.load_state_dict_from_url(path, model_dir=args.target_directory, map_location=torch.device("cpu"))
+        elif src == "diffusers":
+            print(f"Downloading `{name}` from diffusers. Caching to dir `{args.target_directory}`")
+            if path == "StableDiffusionPipeline":
+                diffusers.StableDiffusionPipeline.from_pretrained(
+                    name, cache_dir=args.target_directory, device_map="balanced"
+                )
 
     raise SystemExit(0)
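
For reference, a minimal sketch of how the updated script can be invoked, e.g. from CI (the target directory name is illustrative); the torch.hub weights and the HuggingFace diffusers cache then land under the same directory:

```python
# Hypothetical invocation of the downloader above; "-t" is the --target_directory flag.
import subprocess

subprocess.run(
    ["python", ".github/download-models-weights.py", "-t", "weights/"],
    check=True,
)
```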
104 changes: 101 additions & 3 deletions README.md
@@ -24,11 +24,53 @@ English | [简体中文](README_zh-CN.md)
</p>
</div>

**Kornia** is a differentiable computer vision library for [PyTorch](https://pytorch.org).
**Kornia** is a differentiable computer vision library that provides a rich set of differentiable image processing and geometric vision algorithms. Built on top of [PyTorch](https://pytorch.org), Kornia integrates seamlessly into existing AI workflows, allowing you to leverage powerful [batch transformations](), [auto-differentiation]() and [GPU acceleration](). Whether you’re working on image transformations, augmentations, or AI-driven image processing, Kornia equips you with the tools you need to bring your ideas to life.

## Key Components
1. **Differentiable Image Processing**<br>
Kornia provides a comprehensive suite of image processing operators, all differentiable and ready to integrate into deep learning pipelines.
- **Filters**: Gaussian, Sobel, Median, Box Blur, etc.
- **Transformations**: Affine, Homography, Perspective, etc.
- **Enhancements**: Histogram Equalization, CLAHE, Gamma Correction, etc.
- **Edge Detection**: Canny, Laplacian, Sobel, etc.
- ... check our [docs](https://kornia.readthedocs.io) for more.
2. **Advanced Augmentations**<br>
Perform powerful data augmentation with Kornia’s built-in functions, ideal for training AI models with complex augmentation pipelines.
- **Augmentation Pipeline**: AugmentationSequential, PatchSequential, VideoSequential, etc.
- **Automatic Augmentation**: AutoAugment, RandAugment, TrivialAugment.
3. **AI Models**<br>
Leverage pre-trained AI models optimized for a variety of vision tasks, all within the Kornia ecosystem.
- **Face Detection**: YuNet
- **Feature Matching**: LoFTR, LightGlue
- **Feature Descriptor**: DISK, DeDoDe, SOLD2
- **Segmentation**: SAM
- **Classification**: MobileViT, VisionTransformer.

It consists of a set of routines and differentiable modules to solve generic computer vision problems. At its core, the package uses *PyTorch* as its main backend both for efficiency and to take advantage of the reverse-mode auto-differentiation to define and compute the gradient of complex functions.
<details>
<summary>See here for some of the methods that we support! (>500 ops in total!)</summary>

| **Category** | **Methods/Models** |
|----------------------------|---------------------------------------------------------------------------------------------------------------------|
| **Image Processing** | - Color conversions (RGB, Grayscale, HSV, etc.)<br>- Geometric transformations (Affine, Homography, Resizing, etc.)<br>- Filtering (Gaussian blur, Median blur, etc.)<br>- Edge detection (Sobel, Canny, etc.)<br>- Morphological operations (Erosion, Dilation, etc.) |
| **Augmentation** | - Random cropping, Erasing<br> - Random geometric transformations (Affine, flipping, Fish Eye, Perspective, Thin plate spline, Elastic)<br>- Random noises (Gaussian, Median, Motion, Box, Rain, Snow, Salt and Pepper)<br>- Random color jittering (Contrast, Brightness, CLAHE, Equalize, Gamma, Hue, Invert, JPEG, Plasma, Posterize, Saturation, Sharpness, Solarize)<br> - Random MixUp, CutMix, Mosaic, Transplantation, etc. |
| **Feature Detection** | - Detector (Harris, GFTT, Hessian, DoG, KeyNet, DISK and DeDoDe)<br> - Descriptor (SIFT, HardNet, TFeat, HyNet, SOSNet, and LAFDescriptor)<br>- Matching (nearest neighbor, mutual nearest neighbor, geometrically aware matching, AdaLAM, LightGlue, and LoFTR) |
| **Geometry** | - Camera models and calibration<br>- Stereo vision (epipolar geometry, disparity, etc.)<br>- Homography estimation<br>- Depth estimation from disparity<br>- 3D transformations |
| **Deep Learning Layers** | - Custom convolution layers<br>- Recurrent layers for vision tasks<br>- Loss functions (e.g., SSIM, PSNR, etc.)<br>- Vision-specific optimizers |
| **Photometric Functions** | - Photometric loss functions<br>- Photometric augmentations |
| **Filtering** | - Bilateral filtering<br>- DexiNed<br>- Dissolving<br>- Guided Blur<br>- Laplacian<br>- Gaussian<br>- Non-local means<br>- Sobel<br>- Unsharp masking |
| **Color** | - Color space conversions<br>- Brightness/contrast adjustment<br>- Gamma correction |
| **Stereo Vision** | - Disparity estimation<br>- Depth estimation<br>- Rectification |
| **Image Registration** | - Affine and homography-based registration<br>- Image alignment using feature matching |
| **Pose Estimation** | - Essential and Fundamental matrix estimation<br>- PnP problem solvers<br>- Pose refinement |
| **Optical Flow** | - Farneback optical flow<br>- Dense optical flow<br>- Sparse optical flow |
| **3D Vision** | - Depth estimation<br>- Point cloud operations<br>- NeRF |
| **Image Denoising** | - Gaussian noise removal<br>- Poisson noise removal |
| **Edge Detection** | - Sobel operator<br>- Canny edge detection |
| **Transformations** | - Rotation<br>- Translation<br>- Scaling<br>- Shearing |
| **Loss Functions** | - SSIM (Structural Similarity Index Measure)<br>- PSNR (Peak Signal-to-Noise Ratio)<br>- Cauchy<br>- Charbonnier<br>- Depth Smooth<br>- Dice<br>- Hausdorff<br>- Tversky<br>- Welsch |
| **Morphological Operations**| - Dilation<br>- Erosion<br>- Opening<br>- Closing |

Inspired by existing packages, this library is composed of a set of packages containing operators that can be inserted into neural networks to train models to perform image transformations, epipolar geometry, depth estimation, and low-level image processing such as filtering and edge detection, operating directly on tensors.
</details>

## Sponsorship

@@ -66,6 +108,62 @@ Kornia is an open-source project that is developed and maintained by volunteers.

</details>

## Quick Start

Kornia is not just another computer vision library — it's your gateway to effortless Computer Vision and AI.

```python
import numpy as np
import kornia_rs as kr

from kornia.augmentation import AugmentationSequential, RandomAffine, RandomBrightness
from kornia.filters import StableDiffusionDissolving

# Load and prepare your image
img: np.ndarray = kr.read_image_any("img.jpeg")
img = kr.resize(img, (256, 256), interpolation="bilinear")

# alternatively, load image with PIL
# img = Image.open("img.jpeg").resize((256, 256))
# img = np.array(img)

img = np.stack([img] * 2) # batch images

# Define an augmentation pipeline
augmentation_pipeline = AugmentationSequential(
    RandomAffine((-45., 45.), p=1.),
    RandomBrightness((0., 1.), p=1.)
)

# Leveraging StableDiffusion models
dslv_op = StableDiffusionDissolving()

img = augmentation_pipeline(img)
dslv_op(img, step_number=500)

dslv_op.save("Kornia-enhanced.jpg")
```
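
If you prefer to stay entirely in PyTorch tensors, the same augmentation pipeline also accepts a `(B, C, H, W)` float tensor directly; a minimal sketch with random data (shapes and values are illustrative):

```python
import torch

from kornia.augmentation import AugmentationSequential, RandomAffine, RandomBrightness

imgs = torch.rand(2, 3, 256, 256)  # batch of two RGB images with values in [0, 1]

aug = AugmentationSequential(
    RandomAffine((-45., 45.), p=1.),
    RandomBrightness((0., 1.), p=1.),
)
out = aug(imgs)  # augmented batch, same shape as the input
```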

## Call For Contributors

Are you passionate about computer vision, AI, and open-source development? Join us in shaping the future of Kornia! We are actively seeking contributors to help expand and enhance our library, making it even more powerful, accessible, and versatile. Whether you're an experienced developer or just starting, there's a place for you in our community.

### Accessible AI Models

We are excited to announce our latest advancement: a new initiative designed to seamlessly integrate lightweight AI models into Kornia.
We aim to make any model, from lightweight networks up to large ones such as StableDiffusion, run smoothly and be well supported from many perspectives.
We have already included a selection of lightweight AI models like [YuNet (Face Detection)](), [LoFTR (Feature Matching)](), and [SAM (Segmentation)](). Now, we're looking for contributors to help us:

- Expand the Model Selection: Bring strong, well-tested models into our library. If you are a researcher, Kornia is an excellent place to promote your model!
- Model Optimization: Work on optimizing models to reduce their computational footprint while maintaining accuracy and performance. A good starting point is adding ONNX support!
- Model Documentation: Create detailed guides and examples to help users get the most out of these models in their projects.


### Documentation And Tutorial Optimization

Kornia's foundation lies in its extensive collection of classic computer vision operators, providing robust tools for image processing, feature extraction, and geometric transformations. We are continuously looking for contributors to help us improve our documentation and provide clear tutorials for our users.


## Cite

If you use Kornia in your research, please cite the paper. See [CITATION](./CITATION.md) for details.
2 changes: 2 additions & 0 deletions conftest.py
@@ -196,6 +196,8 @@ def pytest_sessionstart(session):

    os.makedirs(WEIGHTS_CACHE_DIR, exist_ok=True)
    torch.hub.set_dir(WEIGHTS_CACHE_DIR)
    # For HuggingFace model caching
    os.environ["HF_HOME"] = WEIGHTS_CACHE_DIR


def _get_env_info() -> Dict[str, Dict[str, str]]:
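
A short note on why `HF_HOME` is set this early: HuggingFace libraries typically resolve the variable the first time they compute their cache location, so setting it at session start, right next to `torch.hub.set_dir`, keeps the diffusers weights in the same cache directory as the torch.hub weights. A hedged sketch (the path is illustrative):

```python
import os

# Set the cache location before anything from the HuggingFace stack is imported or used;
# changing HF_HOME later may not move an already-resolved cache path.
os.environ["HF_HOME"] = "/tmp/kornia-test-weights"  # illustrative path

from diffusers import StableDiffusionPipeline  # noqa: E402  (imported after the env setup on purpose)
```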
1 change: 1 addition & 0 deletions docs/source/augmentation.module.rst
@@ -21,6 +21,7 @@ Intensity
.. autoclass:: RandomClahe
.. autoclass:: RandomContrast
.. autoclass:: RandomEqualize
.. autoclass:: RandomDissolving
.. autoclass:: RandomGamma
.. autoclass:: RandomGaussianBlur
.. autoclass:: RandomGaussianIllumination
14 changes: 14 additions & 0 deletions docs/source/get-started/highlights.rst
@@ -1,6 +1,20 @@
Highlighted Features
====================

At Kornia, we are dedicated to pushing the boundaries of computer vision by providing a robust, efficient, and versatile toolkit. Our library is built on the powerful PyTorch backend, leveraging its efficiency and auto-differentiation capabilities to deliver high-performance solutions for a wide range of vision tasks.


Accessible AI Models
--------------------

We are excited to announce our latest advancement: a new initiative designed to seamlessly integrate lightweight AI models into Kornia. This addition is crafted to empower developers, researchers, and enthusiasts to harness the full potential of accessible AI, simplifying complex vision tasks and accelerating innovation.

We have curated a selection of lightweight AI models, including YuNet, LoFTR, and SAM, optimized for performance and efficiency. These models run efficiently without requiring expensive GPUs, making cutting-edge AI accessible to everyone. We welcome developers and researchers who are passionate about advancing computer vision to submit PRs for their lightning-fast models!


Classic Operators
-----------------

.. image:: https://github.com/kornia/data/raw/main/kornia_paper_mosaic.png
   :align: center

7 changes: 7 additions & 0 deletions docs/source/references.bib
@@ -405,3 +405,10 @@ @article{wang2023vggsfm
  journal={arXiv preprint arXiv:2312.04563},
  year={2023}
}

@misc{shi2024dissolving,
  author={Jian Shi and Pengyi Zhang and Ni Zhang and Hakim Ghazzai and Peter Wonka},
  title={Dissolving Is Amplifying: Towards Fine-Grained Anomaly Detection},
  booktitle={ECCV},
  year={2024},
}
1 change: 1 addition & 0 deletions kornia/augmentation/_2d/intensity/__init__.py
@@ -8,6 +8,7 @@
from kornia.augmentation._2d.intensity.color_jitter import ColorJitter
from kornia.augmentation._2d.intensity.contrast import RandomContrast
from kornia.augmentation._2d.intensity.denormalize import Denormalize
from kornia.augmentation._2d.intensity.dissolving import RandomDissolving
from kornia.augmentation._2d.intensity.equalize import RandomEqualize
from kornia.augmentation._2d.intensity.erasing import RandomErasing
from kornia.augmentation._2d.intensity.gamma import RandomGamma
57 changes: 57 additions & 0 deletions kornia/augmentation/_2d/intensity/dissolving.py
@@ -0,0 +1,57 @@
from typing import Any, Dict, Optional, Tuple

from kornia.augmentation import random_generator as rg
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
from kornia.core import Tensor
from kornia.filters import StableDiffusionDissolving


class RandomDissolving(IntensityAugmentationBase2D):
    r"""Perform dissolving transformation using StableDiffusion models.

    Based on :cite:`shi2024dissolving`, the dissolving transformation essentially applies one-step
    reverse diffusion. Our implementation currently supports the HuggingFace implementations of
    SD 1.4, 1.5 and 2.1. SD 1.x tends to remove more details than SD 2.1.

    .. list-table:: Title
       :widths: 32 32 32
       :header-rows: 1

       * - SD 1.4
         - SD 1.5
         - SD 2.1
       * - .. figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.4.png
         - .. figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-1.5.png
         - .. figure:: https://raw.githubusercontent.com/kornia/data/main/dslv-sd-2.1.png

    Args:
        p: probability of applying the transformation.
        version: the version of the stable diffusion model.
        step_range: the step range of the diffusion model steps. The higher the step, the stronger
            the dissolving effect.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False).
        **kwargs: additional arguments for `.from_pretrained` of the HF StableDiffusionPipeline.

    Shape:
        - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`.
        - Output: :math:`(B, C, H, W)`
    """

    def __init__(
        self,
        step_range: Tuple[float, float] = (100, 500),
        version: str = "2.1",
        p: float = 0.5,
        keepdim: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(p=p, same_on_batch=True, keepdim=keepdim)
        self.step_range = step_range
        self._dslv = StableDiffusionDissolving(version, **kwargs)
        self._param_generator = rg.PlainUniformGenerator((self.step_range, "step_range_factor", None, None))

    def apply_transform(
        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Tensor:
        return self._dslv(input, params["step_range_factor"][0].long().item())
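
A minimal usage sketch for the new augmentation (argument values match the defaults above, except `p=1.0` to force application; constructing the class pulls the Stable Diffusion weights through diffusers on first use, so expect a large download):

```python
import torch

from kornia.augmentation import RandomDissolving

# step_range and version match the defaults above; p=1.0 applies the transform to every batch.
aug = RandomDissolving(step_range=(100, 500), version="2.1", p=1.0)

imgs = torch.rand(2, 3, 512, 512)  # batch of images with values in [0, 1]
dissolved = aug(imgs)  # one-step reverse diffusion at a randomly sampled step (same step for the whole batch)
```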
2 changes: 2 additions & 0 deletions kornia/augmentation/__init__.py
@@ -18,6 +18,7 @@
    RandomContrast,
    RandomCrop,
    RandomCutMixV2,
    RandomDissolving,
    RandomElasticTransform,
    RandomEqualize,
    RandomErasing,
@@ -115,6 +116,7 @@
    "RandomChannelShuffle",
    "RandomContrast",
    "RandomCrop",
    "RandomDissolving",
    "RandomErasing",
    "RandomElasticTransform",
    "RandomFisheye",
2 changes: 2 additions & 0 deletions kornia/core/__init__.py
@@ -1,3 +1,4 @@
from . import external
from ._backend import (
    Device,
    Dtype,
@@ -35,6 +36,7 @@
from .tensor_wrapper import TensorWrapper  # type: ignore

__all__ = [
    "external",
    "arange",
    "concatenate",
    "Device",
1 change: 1 addition & 0 deletions kornia/core/external.py
@@ -69,3 +69,4 @@ def __dir__(self) -> List[str]:

numpy = LazyLoader("numpy")
PILImage = LazyLoader("PIL.Image")
diffusers = LazyLoader("diffusers")
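
For context, a small sketch of what these lazy handles give callers (behavior inferred from the LazyLoader pattern; exact error wording may differ): the underlying package is only imported on first attribute access, so `diffusers` stays an optional dependency and environments without it fail only when a diffusers-backed feature is actually used.

```python
# Nothing from diffusers has been imported at this point.
from kornia.core.external import diffusers

# The first attribute access triggers the real import (or a clear error if the
# package is missing), which is what keeps diffusers optional for kornia users.
pipeline_cls = diffusers.StableDiffusionPipeline
```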
2 changes: 1 addition & 1 deletion kornia/core/module.py
@@ -209,7 +209,7 @@ def save(self, name: Optional[str] = None, n_row: Optional[int] = None) -> None:
             n_row: Number of images displayed in each row of the grid.
         """
         if name is None:
-            name = f"Kornia-{datetime.datetime.now(tz=datetime.UTC).strftime('%Y%m%d%H%M%S')!s}.jpg"
+            name = f"Kornia-{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}.jpg"
         if len(self._output_image.shape) == 3:
             out_image = self._output_image
         if len(self._output_image.shape) == 4:
2 changes: 2 additions & 0 deletions kornia/filters/__init__.py
@@ -12,6 +12,7 @@
)
from .canny import Canny, canny
from .dexined import DexiNed
from .dissolving import StableDiffusionDissolving
from .filter import filter2d, filter2d_separable, filter3d
from .gaussian import GaussianBlur2d, gaussian_blur2d, gaussian_blur2d_t
from .guided import GuidedBlur, guided_blur
@@ -92,6 +93,7 @@
    "Canny",
    "BoxBlur",
    "BlurPool2D",
    "StableDiffusionDissolving",
    "MaxBlurPool2D",
    "EdgeAwareBlurPool2D",
    "MedianBlur",