Skip to content

Commit

Permalink
fix(back): Remove top-level GitHub imports to prevent errors
Browse files Browse the repository at this point in the history
  • Loading branch information
cpvannier committed Feb 29, 2024
1 parent 96cae78 commit 0dc083d
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 25 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,12 @@ jobs:
python-version: "3.10"

# Install PyTorch and TensorFlow CPU versions manually to prevent installing CUDA
# Install GitHub models manually as they cannot be included in PyPI
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pylint
python -m pip install torch~=2.2.0 torchaudio~=2.2.0 torchvision~=0.17.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install tensorflow-cpu~=2.15.0
python -m pip install segment-anything@git+https://github.com/facebookresearch/segment-anything
python -m pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM
python -m pip install groundingdino@git+https://github.com/IDEA-Research/GroundingDINO
python -m pip install .
- name: Lint backend code with Pylint
Expand Down
34 changes: 25 additions & 9 deletions pixano_inference/github/groundingdino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

import pyarrow as pa
import shortuuid
from groundingdino.util.inference import load_image, load_model, predict
from pixano.core import BBox, Image
from pixano.models import InferenceModel
from torchvision.ops import box_convert

from pixano_inference.utils import attempt_import


class GroundingDINO(InferenceModel):
"""GroundingDINO Model
Expand Down Expand Up @@ -50,6 +51,12 @@ def __init__(
device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cuda".
"""

# Import GroundingDINO
groundingdino = attempt_import(
"groundingdino",
"groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
)

super().__init__(
name="GroundingDINO",
model_id=model_id,
Expand All @@ -58,7 +65,7 @@ def __init__(
)

# Model
self.model = load_model(
self.model = groundingdino.util.inference.load_model(
config_path.as_posix(),
checkpoint_path.as_posix(),
)
Expand Down Expand Up @@ -87,21 +94,30 @@ def preannotate(

rows = []

# Import GroundingDINO
groundingdino = attempt_import(
"groundingdino",
"groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
)

for view in views:
# Iterate manually
for x in range(batch.num_rows):
# Preprocess image
im: Image = Image.from_dict(batch[view][x].as_py())
im.uri_prefix = uri_prefix
_, image = load_image(im.path.as_posix())

_, image = groundingdino.util.inference.load_image(im.path.as_posix())

# Inference
bbox_tensor, logit_tensor, category_list = predict(
model=self.model,
image=image,
caption=prompt,
box_threshold=0.35,
text_threshold=0.25,
bbox_tensor, logit_tensor, category_list = (
groundingdino.util.inference.predict(
model=self.model,
image=image,
caption=prompt,
box_threshold=0.35,
text_threshold=0.25,
)
)

# Convert bounding boxes from cyxcywh to xywh
Expand Down
34 changes: 28 additions & 6 deletions pixano_inference/github/mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import pyarrow as pa
import shortuuid
import torch
from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from mobile_sam.utils.onnx import SamOnnxModel
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from pixano.core import BBox, CompressedRLE, Image
from pixano.models import InferenceModel

from pixano_inference.utils import attempt_import


class MobileSAM(InferenceModel):
"""MobileSAM
Expand Down Expand Up @@ -54,6 +54,11 @@ def __init__(
device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cpu".
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

super().__init__(
name="Mobile_SAM",
model_id=model_id,
Expand All @@ -62,7 +67,7 @@ def __init__(
)

# Model
self.model = sam_model_registry["vit_t"](checkpoint=checkpoint_path)
self.model = mobile_sam.sam_model_registry["vit_t"](checkpoint=checkpoint_path)
self.model.to(device=self.device)

# Model path
Expand All @@ -89,6 +94,11 @@ def preannotate(
list[dict]: Processed rows
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

rows = []
_ = prompt # This model does not use prompts

Expand All @@ -103,7 +113,7 @@ def preannotate(

# Inference
with torch.no_grad():
generator = SamAutomaticMaskGenerator(self.model)
generator = mobile_sam.SamAutomaticMaskGenerator(self.model)
output = generator.generate(im)

# Process model outputs
Expand Down Expand Up @@ -148,6 +158,11 @@ def precompute_embeddings(
pa.RecordBatch: Embedding rows
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

rows = [
{
"id": batch["id"][x].as_py(),
Expand All @@ -166,7 +181,7 @@ def precompute_embeddings(

# Inference
with torch.no_grad():
predictor = SamPredictor(self.model)
predictor = mobile_sam.SamPredictor(self.model)
predictor.set_image(im)
img_embedding = predictor.get_image_embedding().cpu().numpy()

Expand All @@ -184,6 +199,11 @@ def export_to_onnx(self, library_dir: Path):
library_dir (Path): Dataset library directory
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

# Model directory
model_dir = library_dir / "models"
model_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -192,7 +212,9 @@ def export_to_onnx(self, library_dir: Path):
self.model.to("cpu")

# Export settings
onnx_model = SamOnnxModel(self.model, return_single_mask=True)
onnx_model = mobile_sam.utils.onnx.SamOnnxModel(
self.model, return_single_mask=True
)
dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
Expand Down
40 changes: 34 additions & 6 deletions pixano_inference/github/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from onnxruntime.quantization.quantize import quantize_dynamic
from pixano.core import BBox, CompressedRLE, Image
from pixano.models import InferenceModel
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

from pixano_inference.utils import attempt_import


class SAM(InferenceModel):
Expand Down Expand Up @@ -56,6 +56,12 @@ def __init__(
device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cuda".
"""

# Import SAM
segment_anything = attempt_import(
"segment_anything",
"segment-anything@git+https://github.com/facebookresearch/segment-anything",
)

super().__init__(
name=f"SAM_ViT_{size.upper()}",
model_id=model_id,
Expand All @@ -64,7 +70,9 @@ def __init__(
)

# Model
self.model = sam_model_registry[f"vit_{size}"](checkpoint=checkpoint_path)
self.model = segment_anything.sam_model_registry[f"vit_{size}"](
checkpoint=checkpoint_path
)
self.model.to(device=self.device)

# Model path
Expand All @@ -91,6 +99,12 @@ def preannotate(
list[dict]: Processed rows
"""

# Import SAM
segment_anything = attempt_import(
"segment_anything",
"segment-anything@git+https://github.com/facebookresearch/segment-anything",
)

rows = []
_ = prompt # This model does not use prompts

Expand All @@ -105,7 +119,7 @@ def preannotate(

# Inference
with torch.no_grad():
generator = SamAutomaticMaskGenerator(self.model)
generator = segment_anything.SamAutomaticMaskGenerator(self.model)
output = generator.generate(im)

# Process model outputs
Expand Down Expand Up @@ -150,6 +164,12 @@ def precompute_embeddings(
pa.RecordBatch: Embedding rows
"""

# Import SAM
segment_anything = attempt_import(
"segment_anything",
"segment-anything@git+https://github.com/facebookresearch/segment-anything",
)

rows = [
{
"id": batch["id"][x].as_py(),
Expand All @@ -168,7 +188,7 @@ def precompute_embeddings(

# Inference
with torch.no_grad():
predictor = SamPredictor(self.model)
predictor = segment_anything.SamPredictor(self.model)
predictor.set_image(im)
img_embedding = predictor.get_image_embedding().cpu().numpy()

Expand All @@ -186,6 +206,12 @@ def export_to_onnx(self, library_dir: Path):
library_dir (Path): Dataset library directory
"""

# Import SAM
segment_anything = attempt_import(
"segment_anything",
"segment-anything@git+https://github.com/facebookresearch/segment-anything",
)

# Model directory
model_dir = library_dir / "models"
model_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -194,7 +220,9 @@ def export_to_onnx(self, library_dir: Path):
self.model.to("cpu")

# Export settings
onnx_model = SamOnnxModel(self.model, return_single_mask=True)
onnx_model = segment_anything.utils.onnx.SamOnnxModel(
self.model, return_single_mask=True
)
dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
Expand Down
18 changes: 18 additions & 0 deletions pixano_inference/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# @Copyright: CEA-LIST/DIASI/SIALV/LVA (2023)
# @Author: CEA-LIST/DIASI/SIALV/LVA <[email protected]>
# @License: CECILL-C
#
# This software is a collaborative computer program whose purpose is to
# generate and explore labeled data for computer vision applications.
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
#
# http://www.cecill.info

from .main import attempt_import

__all__ = [
"attempt_import",
]
34 changes: 34 additions & 0 deletions pixano_inference/utils/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# @Copyright: CEA-LIST/DIASI/SIALV/LVA (2023)
# @Author: CEA-LIST/DIASI/SIALV/LVA <[email protected]>
# @License: CECILL-C
#
# This software is a collaborative computer program whose purpose is to
# generate and explore labeled data for computer vision applications.
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
#
# http://www.cecill.info

import importlib
from types import ModuleType


def attempt_import(module: str, package: str = None) -> ModuleType:
"""Import specified module, or raise ImportError with a helpful message
Args:
module (str): The name of the module to import
package (str): The package to install, None if identical to module name. Defaults to None.
Returns:
ModuleType: Imported module
"""

try:
return importlib.import_module(module)
except ImportError as e:
raise ImportError(
f"Please install {module} to use this model: pip install {package or module}"
) from e

0 comments on commit 0dc083d

Please sign in to comment.