diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 86a9ffa..34d5a6b 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -41,15 +41,12 @@ jobs:
           python-version: "3.10"
 
       # Install PyTorch and TensorFlow CPU versions manually to prevent installing CUDA
-      # Install SAM and MobileSAM manually as they cannot be included in PyPI
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           python -m pip install pylint
           python -m pip install torch~=2.2.0 torchaudio~=2.2.0 torchvision~=0.17.0 --index-url https://download.pytorch.org/whl/cpu
           python -m pip install tensorflow-cpu~=2.15.0
-          python -m pip install segment-anything@git+https://github.com/facebookresearch/segment-anything
-          python -m pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM
           python -m pip install .
 
       - name: Lint backend code with Pylint
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 845209b..78fa292 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,23 @@ All notable changes to Pixano will be documented in this file.
 
 ## [Unreleased]
 
+## [0.3.1] - 2024-03-18
+
+### Added
+
+- Add **new GroundingDINO model** for semantic segmentation with text prompts (pixano/pixano-inference#6)
+
+### Changed
+
+- Update README badges with PyPI release
+
+### Fixed
+
+- Remove top-level imports for GitHub models to prevent import errors (pixano/pixano-inference#6)
+- Fix preannotation with SAM and MobileSAM (pixano/pixano-inference#6)
+- Add type hints for Image PixanoType (pixano/pixano-inference#6)
+- Update Pixano requirement from 0.5.0 beta to 0.5.0 stable
+
 ## [0.3.0] - 2024-02-29
 
 ### Added
@@ -20,7 +37,7 @@ All notable changes to Pixano will be documented in this file.
 - **Breaking:** Remove SAM and MobileSAM dependencies to allow publishing to PyPI (pixano/pixano-inference#14)
 - **Breaking:** Update to Pixano 0.5.0
 - **Breaking:** Update InferenceModel `id` attribute to `model_id` to stop redefining built-in `id`
-- **Breaking:** Update submodule names to `pytorch` and `tensorflow`
+- **Breaking:** Update submodule names to `pytorch`, `tensorflow`, and `github`
 - Update README with a small header description listing main features and more detailed installation instructions
 - Generate API reference on documentation website automatically
 - Add cross-references to Pixano, TensorFlow, and Hugging Face Transformers in the API reference
@@ -113,7 +130,8 @@ All notable changes to Pixano will be documented in this file.
 
 - Create first public release
 
-[Unreleased]: https://github.com/pixano/pixano/compare/main...develop
+[Unreleased]: https://github.com/pixano/pixano-inference/compare/main...develop
+[0.3.1]: https://github.com/pixano/pixano-inference/compare/v0.3.0...v0.3.1
 [0.3.0]: https://github.com/pixano/pixano-inference/compare/v0.2.1...v0.3.0
 [0.2.1]: https://github.com/pixano/pixano-inference/compare/v0.2.0...v0.2.1
 [0.2.0]: https://github.com/pixano/pixano-inference/compare/v0.1.6...v0.2.0
diff --git a/README.md b/README.md
index df81cfd..087cbc3 100644
--- a/README.md
+++ b/README.md
@@ -9,9 +9,10 @@
 
 **_Under active development, subject to API change_**
 
 [![GitHub version](https://img.shields.io/github/v/release/pixano/pixano-inference?label=release&logo=github)](https://github.com/pixano/pixano-inference/releases)
-[![Documentation](https://img.shields.io/website/https/pixano.github.io?up_message=online&up_color=green&down_message=offline&down_color=orange&label=docs)](https://pixano.github.io/pixano-inference)
-[![License](https://img.shields.io/badge/license-CeCILL--C-green.svg)](LICENSE)
+[![PyPI version](https://img.shields.io/pypi/v/pixano-inference?color=blue&label=release&logo=pypi&logoColor=white)](https://pypi.org/project/pixano-inference/)
+[![Documentation](https://img.shields.io/website?url=https%3A%2F%2Fpixano.github.io%2F&up_message=online&down_message=offline&label=docs)](https://pixano.github.io)
 [![Python version](https://img.shields.io/pypi/pyversions/pixano?color=important&logo=python&logoColor=white)](https://www.python.org/downloads/)
+[![License](https://img.shields.io/badge/license-CeCILL--C-blue.svg)](LICENSE)
 
 
diff --git a/docs/getting_started/installing_pixano_inference.md b/docs/getting_started/installing_pixano_inference.md
index 367ccb5..0a63412 100644
--- a/docs/getting_started/installing_pixano_inference.md
+++ b/docs/getting_started/installing_pixano_inference.md
@@ -14,9 +14,10 @@ pip install pixano
 pip install pixano-inference
 ```
 
-To use the inference models available through GitHub, install the following additional packages:
+To use the inference models available through GitHub, install the following additional packages as needed:
 
 ```shell
 python -m pip install segment-anything@git+https://github.com/facebookresearch/segment-anything
 python -m pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM
+python -m pip install groundingdino@git+https://github.com/IDEA-Research/GroundingDINO
 ```
diff --git a/pixano_inference/__version__.py b/pixano_inference/__version__.py
index 59f7192..63211b3 100644
--- a/pixano_inference/__version__.py
+++ b/pixano_inference/__version__.py
@@ -11,4 +11,4 @@
 #
 # http://www.cecill.info
 
-__version__ = "0.3.0"
+__version__ = "0.3.1"
diff --git a/pixano_inference/github/__init__.py b/pixano_inference/github/__init__.py
index 7ba4248..d22e345 100644
--- a/pixano_inference/github/__init__.py
+++ b/pixano_inference/github/__init__.py
@@ -11,10 +11,12 @@
 #
 # http://www.cecill.info
 
+from .groundingdino import GroundingDINO
 from .mobile_sam import MobileSAM
 from .sam import SAM
 
 __all__ = [
     "SAM",
+    "GroundingDINO",
     "MobileSAM",
 ]
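Not part of the diff, but a minimal sketch of what the lazy-import change buys: the `github` submodule can now be imported without the optional GitHub packages, which are only checked when a model class is instantiated. The checkpoint path below is a placeholder.

```python
# Sketch: with the lazy imports above, this top-level import succeeds even when the
# optional GitHub packages (segment-anything, mobile-sam, groundingdino) are missing.
from pathlib import Path

from pixano_inference.github import SAM, GroundingDINO, MobileSAM  # no ImportError here

# The optional dependency is only checked at instantiation time:
try:
    MobileSAM(checkpoint_path=Path("mobile_sam.pt"))  # placeholder checkpoint path
except ImportError as err:
    print(err)  # "Please install mobile_sam to use this model: pip install mobile-sam@git+..."
```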
diff --git a/pixano_inference/github/groundingdino.py b/pixano_inference/github/groundingdino.py
new file mode 100644
index 0000000..66c0b8c
--- /dev/null
+++ b/pixano_inference/github/groundingdino.py
@@ -0,0 +1,145 @@
+# @Copyright: CEA-LIST/DIASI/SIALV/LVA (2023)
+# @Author: CEA-LIST/DIASI/SIALV/LVA
+# @License: CECILL-C
+#
+# This software is a collaborative computer program whose purpose is to
+# generate and explore labeled data for computer vision applications.
+# This software is governed by the CeCILL-C license under French law and
+# abiding by the rules of distribution of free software. You can use,
+# modify and/ or redistribute the software under the terms of the CeCILL-C
+# license as circulated by CEA, CNRS and INRIA at the following URL
+#
+# http://www.cecill.info
+
+from pathlib import Path
+
+import pyarrow as pa
+import shortuuid
+from pixano.core import BBox, Image
+from pixano.models import InferenceModel
+from torchvision.ops import box_convert
+
+from pixano_inference.utils import attempt_import
+
+
+class GroundingDINO(InferenceModel):
+    """GroundingDINO Model
+
+    Attributes:
+        name (str): Model name
+        model_id (str): Model ID
+        device (str): Model GPU or CPU device
+        description (str): Model description
+        model (torch.nn.Module): PyTorch model
+        checkpoint_path (Path): Model checkpoint path
+        config_path (Path): Model config path
+    """
+
+    def __init__(
+        self,
+        checkpoint_path: Path,
+        config_path: Path,
+        model_id: str = "",
+        device: str = "cuda",
+    ) -> None:
+        """Initialize model
+
+        Args:
+            checkpoint_path (Path): Model checkpoint path (download from https://github.com/IDEA-Research/GroundingDINO)
+            config_path (Path): Model config path (download from https://github.com/IDEA-Research/GroundingDINO)
+            model_id (str, optional): Previously used ID, generate new ID if "". Defaults to "".
+            device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cuda".
+        """
+
+        # Import GroundingDINO
+        gd_inf = attempt_import(
+            "groundingdino.util.inference",
+            "groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
+        )
+
+        super().__init__(
+            name="GroundingDINO",
+            model_id=model_id,
+            device=device,
+            description="From GitHub, GroundingDINO model.",
+        )
+
+        # Model
+        self.model = gd_inf.load_model(
+            config_path.as_posix(),
+            checkpoint_path.as_posix(),
+        )
+        self.model.to(self.device)
+
+    def preannotate(
+        self,
+        batch: pa.RecordBatch,
+        views: list[str],
+        uri_prefix: str,
+        threshold: float = 0.0,
+        prompt: str = "",
+    ) -> list[dict]:
+        """Inference pre-annotation for a batch
+
+        Args:
+            batch (pa.RecordBatch): Input batch
+            views (list[str]): Dataset views
+            uri_prefix (str): URI prefix for media files
+            threshold (float, optional): Confidence threshold. Defaults to 0.0.
+            prompt (str, optional): Annotation text prompt. Defaults to "".
+
+        Returns:
+            list[dict]: Processed rows
+        """
+
+        rows = []
+
+        # Import GroundingDINO
+        gd_inf = attempt_import(
+            "groundingdino.util.inference",
+            "groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
+        )
+
+        for view in views:
+            # Iterate manually
+            for x in range(batch.num_rows):
+                # Preprocess image
+                im: Image = Image.from_dict(batch[view][x].as_py())
+                im.uri_prefix = uri_prefix
+
+                _, image = gd_inf.load_image(im.path.as_posix())
+
+                # Inference
+                bbox_tensor, logit_tensor, category_list = gd_inf.predict(
+                    model=self.model,
+                    image=image,
+                    caption=prompt,
+                    box_threshold=0.35,
+                    text_threshold=0.25,
+                )
+
+                # Convert bounding boxes from cxcywh to xywh
+                bbox_tensor = box_convert(
+                    boxes=bbox_tensor, in_fmt="cxcywh", out_fmt="xywh"
+                )
+                bbox_list = [[coord.item() for coord in bbox] for bbox in bbox_tensor]
+
+                # Process model outputs
+                rows.extend(
+                    [
+                        {
+                            "id": shortuuid.uuid(),
+                            "item_id": batch["id"][x].as_py(),
+                            "view_id": view,
+                            "bbox": BBox.from_xywh(
+                                bbox_list[i],
+                                confidence=logit_tensor[i].item(),
+                            ).to_dict(),
+                            "category": category_list[i],
+                        }
+                        for i in range(len(category_list))
+                        if logit_tensor[i].item() > threshold
+                    ]
+                )
+
+        return rows
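A hedged usage sketch of the new model above: `annotate_batch` is a hypothetical helper, the checkpoint and config file names are those distributed by IDEA-Research, and the batch is assumed to come from a Pixano dataset exposing an `image` view.

```python
# Sketch: running text-prompted pre-annotation with the new GroundingDINO model.
from pathlib import Path

import pyarrow as pa

from pixano_inference.github import GroundingDINO


def annotate_batch(batch: pa.RecordBatch, uri_prefix: str) -> list[dict]:
    """Illustrative helper: pre-annotate one Pixano batch with a text prompt."""
    model = GroundingDINO(
        checkpoint_path=Path("groundingdino_swint_ogc.pth"),  # weights from the GroundingDINO releases
        config_path=Path("GroundingDINO_SwinT_OGC.py"),  # matching config file
    )
    return model.preannotate(
        batch=batch,
        views=["image"],  # assumes the dataset exposes an "image" view
        uri_prefix=uri_prefix,
        threshold=0.3,
        prompt="a cat. a dog.",  # GroundingDINO expects period-separated phrases
    )
```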
diff --git a/pixano_inference/github/mobile_sam.py b/pixano_inference/github/mobile_sam.py
index 76f1e8e..5963398 100644
--- a/pixano_inference/github/mobile_sam.py
+++ b/pixano_inference/github/mobile_sam.py
@@ -20,13 +20,13 @@
 import pyarrow as pa
 import shortuuid
 import torch
-from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
-from mobile_sam.utils.onnx import SamOnnxModel
 from onnxruntime.quantization import QuantType
 from onnxruntime.quantization.quantize import quantize_dynamic
 from pixano.core import BBox, CompressedRLE, Image
 from pixano.models import InferenceModel
 
+from pixano_inference.utils import attempt_import
+
 
 class MobileSAM(InferenceModel):
     """MobileSAM
@@ -54,6 +54,11 @@ def __init__(
             device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cpu".
         """
 
+        # Import MobileSAM
+        mobile_sam = attempt_import(
+            "mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
+        )
+
         super().__init__(
             name="Mobile_SAM",
             model_id=model_id,
@@ -62,7 +67,7 @@ def __init__(
         )
 
         # Model
-        self.model = sam_model_registry["vit_t"](checkpoint=checkpoint_path)
+        self.model = mobile_sam.sam_model_registry["vit_t"](checkpoint=checkpoint_path)
         self.model.to(device=self.device)
 
         # Model path
@@ -74,6 +79,7 @@ def preannotate(
         views: list[str],
         uri_prefix: str,
         threshold: float = 0.0,
+        prompt: str = "",
     ) -> list[dict]:
         """Inference pre-annotation for a batch
 
@@ -82,25 +88,32 @@ def preannotate(
             views (list[str]): Dataset views
             uri_prefix (str): URI prefix for media files
             threshold (float, optional): Confidence threshold. Defaults to 0.0.
+            prompt (str, optional): Annotation text prompt. Defaults to "".
 
         Returns:
             list[dict]: Processed rows
         """
 
+        # Import MobileSAM
+        mobile_sam = attempt_import(
+            "mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
+        )
+
         rows = []
+        _ = prompt  # This model does not use prompts
 
         for view in views:
             # Iterate manually
             for x in range(batch.num_rows):
                 # Preprocess image
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im = im.as_cv2()
                 im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
 
                 # Inference
                 with torch.no_grad():
-                    generator = SamAutomaticMaskGenerator(self.model)
+                    generator = mobile_sam.SamAutomaticMaskGenerator(self.model)
                     output = generator.generate(im)
 
                 # Process model outputs
@@ -112,8 +125,8 @@ def preannotate(
                             "item_id": batch["id"][x].as_py(),
                             "view_id": view,
                             "bbox": BBox.from_xywh(
-                                [coord.item() for coord in output[i]["bbox"]],
-                                confidence=output[i]["predicted_iou"].item(),
+                                [int(coord) for coord in output[i]["bbox"]],
+                                confidence=float(output[i]["predicted_iou"]),
                             )
                             .normalize(h, w)
                             .to_dict(),
@@ -145,6 +158,11 @@ def precompute_embeddings(
             pa.RecordBatch: Embedding rows
         """
 
+        # Import MobileSAM
+        mobile_sam = attempt_import(
+            "mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
+        )
+
         rows = [
             {
                 "id": batch["id"][x].as_py(),
@@ -156,14 +174,14 @@ def precompute_embeddings(
             # Iterate manually
             for x in range(batch.num_rows):
                 # Preprocess image
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im = im.as_cv2()
                 im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
 
                 # Inference
                 with torch.no_grad():
-                    predictor = SamPredictor(self.model)
+                    predictor = mobile_sam.SamPredictor(self.model)
                     predictor.set_image(im)
                     img_embedding = predictor.get_image_embedding().cpu().numpy()
 
@@ -181,6 +199,11 @@ def export_to_onnx(self, library_dir: Path):
             library_dir (Path): Dataset library directory
         """
 
+        # Import MobileSAM
+        mobile_sam = attempt_import(
+            "mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
+        )
+
         # Model directory
         model_dir = library_dir / "models"
         model_dir.mkdir(parents=True, exist_ok=True)
@@ -189,7 +212,9 @@ def export_to_onnx(self, library_dir: Path):
         self.model.to("cpu")
 
         # Export settings
-        onnx_model = SamOnnxModel(self.model, return_single_mask=True)
+        onnx_model = mobile_sam.utils.onnx.SamOnnxModel(
+            self.model, return_single_mask=True
+        )
         dynamic_axes = {
             "point_coords": {1: "num_points"},
             "point_labels": {1: "num_points"},
""" + # Import SAM + segment_anything = attempt_import( + "segment_anything", + "segment-anything@git+https://github.com/facebookresearch/segment-anything", + ) + super().__init__( name=f"SAM_ViT_{size.upper()}", model_id=model_id, @@ -64,7 +70,9 @@ def __init__( ) # Model - self.model = sam_model_registry[f"vit_{size}"](checkpoint=checkpoint_path) + self.model = segment_anything.sam_model_registry[f"vit_{size}"]( + checkpoint=checkpoint_path + ) self.model.to(device=self.device) # Model path @@ -76,6 +84,7 @@ def preannotate( views: list[str], uri_prefix: str, threshold: float = 0.0, + prompt: str = "", ) -> list[dict]: """Inference pre-annotation for a batch @@ -84,25 +93,33 @@ def preannotate( views (list[str]): Dataset views uri_prefix (str): URI prefix for media files threshold (float, optional): Confidence threshold. Defaults to 0.0. + prompt (str, optional): Annotation text prompt. Defaults to "". Returns: list[dict]: Processed rows """ + # Import SAM + segment_anything = attempt_import( + "segment_anything", + "segment-anything@git+https://github.com/facebookresearch/segment-anything", + ) + rows = [] + _ = prompt # This model does not use prompts for view in views: # Iterate manually for x in range(batch.num_rows): # Preprocess image - im = Image.from_dict(batch[view][x].as_py()) + im: Image = Image.from_dict(batch[view][x].as_py()) im.uri_prefix = uri_prefix im = im.as_cv2() im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # Inference with torch.no_grad(): - generator = SamAutomaticMaskGenerator(self.model) + generator = segment_anything.SamAutomaticMaskGenerator(self.model) output = generator.generate(im) # Process model outputs @@ -114,8 +131,8 @@ def preannotate( "item_id": batch["id"][x].as_py(), "view_id": view, "bbox": BBox.from_xywh( - [coord.item() for coord in output[i]["bbox"]], - confidence=output[i]["predicted_iou"].item(), + [int(coord) for coord in output[i]["bbox"]], + confidence=float(output[i]["predicted_iou"]), ) .normalize(h, w) .to_dict(), @@ -147,6 +164,12 @@ def precompute_embeddings( pa.RecordBatch: Embedding rows """ + # Import SAM + segment_anything = attempt_import( + "segment_anything", + "segment-anything@git+https://github.com/facebookresearch/segment-anything", + ) + rows = [ { "id": batch["id"][x].as_py(), @@ -158,14 +181,14 @@ def precompute_embeddings( # Iterate manually for x in range(batch.num_rows): # Preprocess image - im = Image.from_dict(batch[view][x].as_py()) + im: Image = Image.from_dict(batch[view][x].as_py()) im.uri_prefix = uri_prefix im = im.as_cv2() im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # Inference with torch.no_grad(): - predictor = SamPredictor(self.model) + predictor = segment_anything.SamPredictor(self.model) predictor.set_image(im) img_embedding = predictor.get_image_embedding().cpu().numpy() @@ -183,6 +206,12 @@ def export_to_onnx(self, library_dir: Path): library_dir (Path): Dataset library directory """ + # Import SAM + segment_anything = attempt_import( + "segment_anything", + "segment-anything@git+https://github.com/facebookresearch/segment-anything", + ) + # Model directory model_dir = library_dir / "models" model_dir.mkdir(parents=True, exist_ok=True) @@ -191,7 +220,9 @@ def export_to_onnx(self, library_dir: Path): self.model.to("cpu") # Export settings - onnx_model = SamOnnxModel(self.model, return_single_mask=True) + onnx_model = segment_anything.utils.onnx.SamOnnxModel( + self.model, return_single_mask=True + ) dynamic_axes = { "point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}, diff --git 
diff --git a/pixano_inference/pytorch/deeplabv3.py b/pixano_inference/pytorch/deeplabv3.py
index b0bdef1..f19b162 100644
--- a/pixano_inference/pytorch/deeplabv3.py
+++ b/pixano_inference/pytorch/deeplabv3.py
@@ -97,6 +97,7 @@ def preannotate(
         views: list[str],
         uri_prefix: str,
         threshold: float = 0.0,
+        prompt: str = "",
     ) -> list[dict]:
         """Inference pre-annotation for a batch
 
@@ -105,18 +106,20 @@ def preannotate(
             views (list[str]): Dataset views
             uri_prefix (str): URI prefix for media files
             threshold (float, optional): Confidence threshold. Defaults to 0.0.
+            prompt (str, optional): Annotation text prompt. Defaults to "".
 
         Returns:
             list[dict]: Processed rows
         """
 
         rows = []
+        _ = prompt  # This model does not use prompts
 
         for view in views:
             # PyTorch Transforms don't support different-sized image batches, so iterate manually
             for x in range(batch.num_rows):
                 # Preprocess image
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im = im.as_pillow()
                 im_tensor = self.transforms(im).unsqueeze(0).to(self.device)
diff --git a/pixano_inference/pytorch/maskrcnnv2.py b/pixano_inference/pytorch/maskrcnnv2.py
index 810d0cc..8f4b763 100644
--- a/pixano_inference/pytorch/maskrcnnv2.py
+++ b/pixano_inference/pytorch/maskrcnnv2.py
@@ -95,6 +95,7 @@ def preannotate(
         views: list[str],
         uri_prefix: str,
         threshold: float = 0.0,
+        prompt: str = "",
    ) -> list[dict]:
         """Inference pre-annotation for a batch
 
@@ -103,18 +104,20 @@ def preannotate(
             views (list[str]): Dataset views
             uri_prefix (str): URI prefix for media files
             threshold (float, optional): Confidence threshold. Defaults to 0.0.
+            prompt (str, optional): Annotation text prompt. Defaults to "".
 
         Returns:
             list[dict]: Processed rows
         """
 
         rows = []
+        _ = prompt  # This model does not use prompts
 
         for view in views:
             # PyTorch Transforms don't support different-sized image batches, so iterate manually
             for x in range(batch.num_rows):
                 # Preprocess image
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im = im.as_pillow()
                 im_tensor = self.transforms(im).unsqueeze(0).to(self.device)
diff --git a/pixano_inference/pytorch/yolov5.py b/pixano_inference/pytorch/yolov5.py
index 5347b6d..6aeb577 100644
--- a/pixano_inference/pytorch/yolov5.py
+++ b/pixano_inference/pytorch/yolov5.py
@@ -65,6 +65,7 @@ def preannotate(
         views: list[str],
         uri_prefix: str,
         threshold: float = 0.0,
+        prompt: str = "",
     ) -> list[dict]:
         """Inference pre-annotation for a batch
 
@@ -73,18 +74,20 @@ def preannotate(
             views (list[str]): Dataset views
             uri_prefix (str): URI prefix for media files
             threshold (float, optional): Confidence threshold. Defaults to 0.0.
+            prompt (str, optional): Annotation text prompt. Defaults to "".
 
         Returns:
             list[dict]: Processed rows
         """
 
         rows = []
+        _ = prompt  # This model does not use prompts
 
         for view in views:
             # Preprocess image batch
             im_batch = []
             for x in range(batch.num_rows):
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im_batch.append(im.as_pillow())
 
diff --git a/pixano_inference/tensorflow/efficientdet.py b/pixano_inference/tensorflow/efficientdet.py
index a963245..3444202 100644
--- a/pixano_inference/tensorflow/efficientdet.py
+++ b/pixano_inference/tensorflow/efficientdet.py
@@ -60,6 +60,7 @@ def preannotate(
         views: list[str],
         uri_prefix: str,
         threshold: float = 0.0,
+        prompt: str = "",
     ) -> list[dict]:
         """Inference pre-annotation for a batch
 
@@ -68,18 +69,20 @@ def preannotate(
             views (list[str]): Dataset views
             uri_prefix (str): URI prefix for media files
             threshold (float, optional): Confidence threshold. Defaults to 0.0.
+            prompt (str, optional): Annotation text prompt. Defaults to "".
 
         Returns:
             list[dict]: Processed rows
         """
 
         rows = []
+        _ = prompt  # This model does not use prompts
 
         for view in views:
             # TF.Hub Models don't support image batches, so iterate manually
             for x in range(batch.num_rows):
                 # Preprocess image
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im = im.as_pillow()
                 im_tensor = tf.expand_dims(tf.keras.utils.img_to_array(im), 0)
diff --git a/pixano_inference/tensorflow/fasterrcnn.py b/pixano_inference/tensorflow/fasterrcnn.py
index b892baf..7260a02 100644
--- a/pixano_inference/tensorflow/fasterrcnn.py
+++ b/pixano_inference/tensorflow/fasterrcnn.py
@@ -62,6 +62,7 @@ def preannotate(
         views: list[str],
         uri_prefix: str,
         threshold: float = 0.0,
+        prompt: str = "",
     ) -> list[dict]:
         """Inference pre-annotation for a batch
 
@@ -70,18 +71,20 @@ def preannotate(
             views (list[str]): Dataset views
             uri_prefix (str): URI prefix for media files
             threshold (float, optional): Confidence threshold. Defaults to 0.0.
+            prompt (str, optional): Annotation text prompt. Defaults to "".
 
         Returns:
             list[dict]: Processed rows
         """
 
         rows = []
+        _ = prompt  # This model does not use prompts
 
         for view in views:
             # TF.Hub Models don't support image batches, so iterate manually
             for x in range(batch.num_rows):
                 # Preprocess image
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im = im.as_pillow()
                 im_tensor = tf.expand_dims(tf.keras.utils.img_to_array(im), 0)
diff --git a/pixano_inference/transformers/clip.py b/pixano_inference/transformers/clip.py
index dc621fd..cf775fa 100644
--- a/pixano_inference/transformers/clip.py
+++ b/pixano_inference/transformers/clip.py
@@ -88,7 +88,7 @@ def precompute_embeddings(
             # Iterate manually
             for x in range(batch.num_rows):
                 # Preprocess image
-                im = Image.from_dict(batch[view][x].as_py())
+                im: Image = Image.from_dict(batch[view][x].as_py())
                 im.uri_prefix = uri_prefix
                 im = im.as_pillow()
 
diff --git a/pixano_inference/utils/__init__.py b/pixano_inference/utils/__init__.py
new file mode 100644
index 0000000..24ecabe
--- /dev/null
+++ b/pixano_inference/utils/__init__.py
@@ -0,0 +1,18 @@
+# @Copyright: CEA-LIST/DIASI/SIALV/LVA (2023)
+# @Author: CEA-LIST/DIASI/SIALV/LVA
+# @License: CECILL-C
+#
+# This software is a collaborative computer program whose purpose is to
+# generate and explore labeled data for computer vision applications.
+# This software is governed by the CeCILL-C license under French law and
+# abiding by the rules of distribution of free software. You can use,
+# modify and/ or redistribute the software under the terms of the CeCILL-C
+# license as circulated by CEA, CNRS and INRIA at the following URL
+#
+# http://www.cecill.info
+
+from .main import attempt_import
+
+__all__ = [
+    "attempt_import",
+]
diff --git a/pixano_inference/utils/main.py b/pixano_inference/utils/main.py
new file mode 100644
index 0000000..956981c
--- /dev/null
+++ b/pixano_inference/utils/main.py
@@ -0,0 +1,34 @@
+# @Copyright: CEA-LIST/DIASI/SIALV/LVA (2023)
+# @Author: CEA-LIST/DIASI/SIALV/LVA
+# @License: CECILL-C
+#
+# This software is a collaborative computer program whose purpose is to
+# generate and explore labeled data for computer vision applications.
+# This software is governed by the CeCILL-C license under French law and
+# abiding by the rules of distribution of free software. You can use,
+# modify and/ or redistribute the software under the terms of the CeCILL-C
+# license as circulated by CEA, CNRS and INRIA at the following URL
+#
+# http://www.cecill.info
+
+import importlib
+from types import ModuleType
+
+
+def attempt_import(module: str, package: str = None) -> ModuleType:
+    """Import specified module, or raise ImportError with a helpful message
+
+    Args:
+        module (str): The name of the module to import
+        package (str): The package to install, None if identical to module name. Defaults to None.
+
+    Returns:
+        ModuleType: Imported module
+    """
+
+    try:
+        return importlib.import_module(module)
+    except ImportError as e:
+        raise ImportError(
+            f"Please install {module.split('.')[0]} to use this model: pip install {package or module}"
+        ) from e
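A short sketch of how the new helper behaves; the second call assumes mobile-sam is not installed in the environment.

```python
# Sketch: attempt_import returns the module when available, otherwise raises an
# ImportError that carries the pip command to run.
from pixano_inference.utils import attempt_import

json_module = attempt_import("json")  # installed module: same as importlib.import_module
print(json_module.dumps({"ok": True}))

try:
    attempt_import(
        "mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
    )
except ImportError as err:
    # "Please install mobile_sam to use this model:
    #  pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
    print(err)
```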
diff --git a/pyproject.toml b/pyproject.toml
index ae97310..694fdfa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
     "License :: CeCILL-C Free Software License Agreement (CECILL-C)",
 ]
 dependencies = [
-    "pixano ~= 0.5.0b1",
+    "pixano ~= 0.5.0",
     "torch ~= 2.2.0",
     "torchaudio ~= 2.2.0",
     "torchvision ~= 0.17.0",
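One consequence of the signature changes above, sketched with a hypothetical helper: every `preannotate` implementation now accepts the same keyword arguments, so callers can pass a text prompt unconditionally and prompt-free models simply ignore it.

```python
# Sketch: the preannotate() signature is now uniform across models, so the same call
# works for GroundingDINO (which uses the prompt) and for prompt-free models alike.
import pyarrow as pa

from pixano.models import InferenceModel


def preannotate_any(model: InferenceModel, batch: pa.RecordBatch, uri_prefix: str) -> list[dict]:
    """Illustrative helper: one call shape for every model in the package."""
    return model.preannotate(
        batch=batch,
        views=["image"],
        uri_prefix=uri_prefix,
        threshold=0.5,
        prompt="a pedestrian",  # used by GroundingDINO, ignored by the other models
    )
```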