Skip to content

Commit

Permalink
add function to calculate image embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsmechtel committed Dec 9, 2024
1 parent 928abcb commit bb22306
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
20 changes: 20 additions & 0 deletions bioimageio_colab/register_sam_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,19 @@ def segment(

return features

def compute_embedding(model_cache: TTLCache, model_name, image, context=None):
user_id = context["user"].get("id")
sam_predictor = _load_model(model_cache, model_name, user_id)
logger.info(
f"User {user_id} - Computing image embedding (model: '{model_name}')..."
)
sam_predictor.set_image(_to_image(image))
return {
"model_name": model_name,
"original_size": sam_predictor.original_size,
"input_size": sam_predictor.input_size,
"features": sam_predictor.get_image_embedding().cpu().numpy(),
}

def clear_cache(embedding_cache: TTLCache, context: dict = None) -> bool:
user_id = context["user"].get("id")
Expand Down Expand Up @@ -239,6 +252,13 @@ async def register_service(args: dict) -> None:
# Returns:
# - a list of XY coordinates of the segmented polygon in the format (1, N, 2)
"segment": partial(segment, model_cache, embedding_cache),
# **Compute the embedding of an image**
# Params:
# - model name
# - image to compute the embeddings on
# Returns:
# - a dictionary containing the computed embedding, original size, and input size
"compute_embedding": partial(compute_embedding, model_cache),
# **Clear the embedding cache**
# Returns:
# - True if the embedding was removed successfully
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"]

[project]
name = "bioimageio-colab"
version = "0.1.3"
version = "0.1.4"
readme = "README.md"
description = "Collaborative image annotation and model training with human in the loop."
dependencies = [
Expand Down
16 changes: 15 additions & 1 deletion test/test_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SERVER_URL = "https://hypha.aicell.io"
WORKSPACE_NAME = "bioimageio-colab"
SERVICE_ID = "microsam"
MODEL_NAME = "vit_b"


def test_service_available():
Expand All @@ -24,6 +25,19 @@ def test_get_service():
assert segment_svc.get("segment")
assert segment_svc.get("clear_cache")

features = segment_svc.segment(model_name="vit_b", image=np.random.rand(256, 256), point_coordinates=[[128, 128]], point_labels=[1])
# Test segmentation
image = np.random.rand(256, 256)
features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[128, 128]], point_labels=[1])
assert features

# Test embedding caching
features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[20, 50]], point_labels=[1])
features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[180, 10]], point_labels=[1])

# Test embedding computation for running SAM client-side
result = segment_svc.compute_embedding(model_name=MODEL_NAME, image=image)
assert result
embedding = result["features"]
assert embedding.shape == (1, 256, 64, 64)

assert segment_svc.clear_cache()

0 comments on commit bb22306

Please sign in to comment.