diff --git a/bioimageio_colab/register_sam_service.py b/bioimageio_colab/register_sam_service.py index d699607..ee8ab9a 100644 --- a/bioimageio_colab/register_sam_service.py +++ b/bioimageio_colab/register_sam_service.py @@ -181,6 +181,19 @@ def segment( return features +def compute_embedding(model_cache: TTLCache, model_name, image, context=None): + user_id = context["user"].get("id") + sam_predictor = _load_model(model_cache, model_name, user_id) + logger.info( + f"User {user_id} - Computing image embedding (model: '{model_name}')..." + ) + sam_predictor.set_image(_to_image(image)) + return { + "model_name": model_name, + "original_size": sam_predictor.original_size, + "input_size": sam_predictor.input_size, + "features": sam_predictor.get_image_embedding().cpu().numpy(), + } def clear_cache(embedding_cache: TTLCache, context: dict = None) -> bool: user_id = context["user"].get("id") @@ -239,6 +252,13 @@ async def register_service(args: dict) -> None: # Returns: # - a list of XY coordinates of the segmented polygon in the format (1, N, 2) "segment": partial(segment, model_cache, embedding_cache), + # **Compute the embedding of an image** + # Params: + # - model name + # - image to compute the embeddings on + # Returns: + # - a dictionary containing the computed embedding, original size, and input size + "compute_embedding": partial(compute_embedding, model_cache), # **Clear the embedding cache** # Returns: # - True if the embedding was removed successfully diff --git a/pyproject.toml b/pyproject.toml index 7fb3208..5bc71ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"] [project] name = "bioimageio-colab" -version = "0.1.3" +version = "0.1.4" readme = "README.md" description = "Collaborative image annotation and model training with human in the loop." dependencies = [ diff --git a/test/test_model_service.py b/test/test_model_service.py index ee3fc15..be49ecf 100644 --- a/test/test_model_service.py +++ b/test/test_model_service.py @@ -6,6 +6,7 @@ SERVER_URL = "https://hypha.aicell.io" WORKSPACE_NAME = "bioimageio-colab" SERVICE_ID = "microsam" +MODEL_NAME = "vit_b" def test_service_available(): @@ -24,6 +25,19 @@ def test_get_service(): assert segment_svc.get("segment") assert segment_svc.get("clear_cache") - features = segment_svc.segment(model_name="vit_b", image=np.random.rand(256, 256), point_coordinates=[[128, 128]], point_labels=[1]) + # Test segmentation + image = np.random.rand(256, 256) + features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[128, 128]], point_labels=[1]) assert features + + # Test embedding caching + features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[20, 50]], point_labels=[1]) + features = segment_svc.segment(model_name=MODEL_NAME, image=image, point_coordinates=[[180, 10]], point_labels=[1]) + + # Test embedding computation for running SAM client-side + result = segment_svc.compute_embedding(model_name=MODEL_NAME, image=image) + assert result + embedding = result["features"] + assert embedding.shape == (1, 256, 64, 64) + assert segment_svc.clear_cache()