Skip to content

Commit

Permalink
sam: wrap high-level methods with no_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
deltheil committed Dec 19, 2023
1 parent e789225 commit 22ce3fd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/refiners/foundationals/segment_anything/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)

@torch.no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size)
Expand All @@ -47,6 +48,7 @@ def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_image_size=original_size,
)

@torch.no_grad()
def predict(
self,
input: Image.Image | ImageEmbedding,
Expand Down
2 changes: 0 additions & 2 deletions tests/foundationals/segment_anything/test_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
assert torch.equal(input=iou_prediction, other=facebook_prediction)


@torch.no_grad()
def test_predictor(
facebook_sam_h_predictor: FacebookSAMPredictor, sam_h: SegmentAnythingH, truck: Image.Image, prompt: SAMPrompt
) -> None:
Expand All @@ -312,7 +311,6 @@ def test_predictor(
assert isclose(scores[i].item(), facebook_scores[i].item(), rel_tol=1e-05)


@torch.no_grad()
def test_predictor_image_embedding(sam_h: SegmentAnythingH, truck: Image.Image, one_prompt: SAMPrompt) -> None:
masks_ref, scores_ref, _ = sam_h.predict(truck, **one_prompt.__dict__)

Expand Down

0 comments on commit 22ce3fd

Please sign in to comment.