diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py
index 6f83f720e..f8abfb71e 100644
--- a/src/refiners/foundationals/segment_anything/model.py
+++ b/src/refiners/foundationals/segment_anything/model.py
@@ -39,6 +39,7 @@ def __init__(
         self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
         self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
 
+    @torch.no_grad()
     def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
         original_size = (image.height, image.width)
         target_size = self.compute_target_size(original_size)
@@ -47,6 +48,7 @@ def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
             original_image_size=original_size,
         )
 
+    @torch.no_grad()
     def predict(
         self,
         input: Image.Image | ImageEmbedding,
diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py
index 1e00e646d..0c5fbf978 100644
--- a/tests/foundationals/segment_anything/test_sam.py
+++ b/tests/foundationals/segment_anything/test_sam.py
@@ -289,7 +289,6 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
     assert torch.equal(input=iou_prediction, other=facebook_prediction)
 
 
-@torch.no_grad()
 def test_predictor(
     facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
 ) -> None:
@@ -312,7 +311,6 @@ def test_predictor(
     assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)
 
 
-@torch.no_grad()
 def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
     masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)
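
A minimal, self-contained sketch of the pattern this diff applies: putting @torch.no_grad() on the model's inference methods so call sites, such as the tests above, no longer need to disable gradients themselves. ToyEncoder and compute_embedding are hypothetical names for illustration, not part of the library.

# Hypothetical toy module illustrating the decorator pattern from the diff above.
import torch


class ToyEncoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @torch.no_grad()  # gradients are disabled inside this method, as in model.py above
    def compute_embedding(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


encoder = ToyEncoder()
out = encoder.compute_embedding(torch.randn(2, 4))
assert not out.requires_grad  # no graph is built, so callers need no @torch.no_grad() of their own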