From 86f007a7ca07382bf4696bba6ba28338a0564ce2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 14 Dec 2024 19:01:18 +0100 Subject: [PATCH] Refactor peft test (#810) Refactor PEFT test --- test/test_models/test_peft_sam.py | 69 +++++++------------------------ 1 file changed, 14 insertions(+), 55 deletions(-) diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index f480f314..059b08c2 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -8,12 +8,7 @@ class TestPEFTSam(unittest.TestCase): model_type = "vit_b" - def test_lora_sam(self): - from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery - - _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") - peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery) - + def _check_output(self, peft_sam): shape = (3, 1024, 1024) expected_shape = (1, 3, 1024, 1024) with torch.no_grad(): @@ -22,90 +17,54 @@ def test_lora_sam(self): masks = output[0]["masks"] self.assertEqual(masks.shape, expected_shape) + def test_lora_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery) + self._check_output(peft_sam) + def test_fact_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, FacTSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=FacTSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) def test_attention_layer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) def test_norm_layer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) def test_bias_layer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) def test_ssf_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery) - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) def test_adaptformer_peft_sam(self): from micro_sam.models.peft_sam import PEFT_Sam, AdaptFormer _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") peft_sam = PEFT_Sam(sam, rank=2, peft_module=AdaptFormer, projection_size=64, alpha=2.0, dropout=0.5) - - - shape = (3, 1024, 1024) - expected_shape = (1, 3, 1024, 1024) - with torch.no_grad(): - batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] - output = peft_sam(batched_input, multimask_output=True) - masks = output[0]["masks"] - self.assertEqual(masks.shape, expected_shape) + self._check_output(peft_sam) if __name__ == "__main__":