From c9ed5e821ff6eae103f6d8386a1d440912e7c814 Mon Sep 17 00:00:00 2001 From: Jonas Frey Date: Sun, 18 Feb 2024 12:05:27 +0100 Subject: [PATCH] match mm branch --- tests/test_feature_extractor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py index e98dace2..abce5bf6 100644 --- a/tests/test_feature_extractor.py +++ b/tests/test_feature_extractor.py @@ -12,7 +12,7 @@ def test_feature_extractor(): device = "cuda" if torch.cuda.is_available() else "cpu" segmentation_types = ["none", "grid", "slic", "random", "stego"] feature_types = ["dino", "dinov2", "stego"] - backbone_types = ["vit_small", "vit_base", "vit_small_reg", "vit_base_reg"] + backbone_types = ["vit_small", "vit_base"] # "vit_small_reg", "vit_base_reg"] for seg_type, feat_type, back_type in itertools.product(segmentation_types, feature_types, backbone_types): if seg_type == "stego" and feat_type != "stego": @@ -39,7 +39,7 @@ def test_feature_extractor(): ax[0].imshow(transform(img).permute(0, 2, 3, 1)[0].cpu()) ax[0].set_title("Image") - ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("inferno")) + ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("gray")) ax[1].set_title("Segmentation") plt.tight_layout()