Skip to content

Commit

Permalink
fix: Fixed jina clip image preprocessor
Browse files Browse the repository at this point in the history
  • Loading branch information
hh-space-invader committed Nov 20, 2024
1 parent 8e3b331 commit c7583f7
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 19 deletions.
6 changes: 3 additions & 3 deletions fastembed/image/transform/functional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sized, Union
from typing import Sized, Union, Optional

import numpy as np
from PIL import Image, ImageOps
Expand Down Expand Up @@ -126,8 +126,8 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]):

def pad2square(
image: Image,
fill_color: str | int | tuple[int, ...] | None = None,
resample: Image.Resampling = Image.Resampling.BILINEAR,
fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
resample: Union[Image.Resampling, int] = Image.Resampling.BILINEAR,
):
width, height = image.size
max_dim = max(width, height)
Expand Down
55 changes: 39 additions & 16 deletions fastembed/image/transform/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
pil2ndarray,
rescale,
resize,
pad2sqaure,
pad2square,
)


Expand Down Expand Up @@ -71,14 +71,14 @@ class PadtoSquare(Transform):
def __init__(
self,
fill_color: Optional[Union[str, int, tuple[int, ...]]] = None,
resample: Image.Resampling = Image.Resampling.BICUBIC,
resample: Union[Image.Resampling, int] = Image.Resampling.BICUBIC,
):
self.fill_color = fill_color
self.resample = resample

def __call__(self, images: list[np.ndarray]) -> list[np.ndarray]:
return [
pad2sqaure(image=image, fill_color=self.fill_color, resample=self.resample)
pad2square(image=image, fill_color=self.fill_color, resample=self.resample)
for image in images
]

Expand Down Expand Up @@ -125,8 +125,8 @@ def from_config(cls, config: dict[str, Any]) -> "Compose":
"""
transforms = []
cls._get_convert_to_rgb(transforms, config)
cls._get_pad2square(transforms, config)
cls._get_resize(transforms, config)
cls._get_padtosquare(transforms, config)
cls._get_center_crop(transforms, config)
cls._get_pil2ndarray(transforms, config)
cls._get_rescale(transforms, config)
Expand Down Expand Up @@ -188,7 +188,11 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]):
transforms.append(
Resize(
size=config["size"],
resample=config.get("interpolation", Image.Resampling.BICUBIC),
resample=(
Compose._interpolation_resolver(config.get("interpolation"))
if isinstance(config.get("interpolation"), str)
else config.get("interpolation") or Image.Resampling.BICUBIC
),
)
)
else:
Expand Down Expand Up @@ -229,19 +233,38 @@ def _get_normalize(transforms: list[Transform], config: dict[str, Any]):
if config.get("do_normalize", False) or ("mean" in config and "std" in config):
transforms.append(
Normalize(
mean=config["image_mean"] or config["mean"],
std=config["image_std"] or config["std"],
mean=config.get("image_mean", config.get("mean")),
std=config.get("image_std", config.get("std")),
)
)

@staticmethod
def _get_padtosquare(transforms: list[Transform], config: dict[str, Any]):
if config.get("do_pad_to_square", False):
transforms.append(
PadtoSquare(
fill_color=config["fill_color"],
resample=config.get("interpolation")
or config.get("resample")
or Image.Resampling.BICUBIC,
)
def _get_pad2square(transforms: list[Transform], config: dict[str, Any]):
mode = config.get("image_processor_type", "CLIPImageProcessor")
if mode == "CLIPImageProcessor":
pass
elif mode == "ConvNextFeatureExtractor":
pass
elif mode == "JinaCLIPImageProcessor":
resample = (
Compose._interpolation_resolver(config.get("interpolation"))
if isinstance(config.get("interpolation"), str)
else config.get("interpolation") or Image.Resampling.BICUBIC
)
transforms.append(PadtoSquare(fill_color=config["fill_color"], resample=resample))

@staticmethod
def _interpolation_resolver(resample: Optional[str] = None) -> Image.Resampling:
interpolation_map = {
"nearest": Image.Resampling.NEAREST,
"lanczos": Image.Resampling.LANCZOS,
"bilinear": Image.Resampling.BILINEAR,
"bicubic": Image.Resampling.BICUBIC,
"box": Image.Resampling.BOX,
"hamming": Image.Resampling.HAMMING,
}

if resample and (method := interpolation_map.get(resample.lower())):
return method

raise ValueError(f"Unknown interpolation method: {resample}")

0 comments on commit c7583f7

Please sign in to comment.