Skip to content

Commit

Permalink
match mm branch
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Feb 18, 2024
1 parent ae0d176 commit c9ed5e8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_feature_extractor():
device = "cuda" if torch.cuda.is_available() else "cpu"
segmentation_types = ["none", "grid", "slic", "random", "stego"]
feature_types = ["dino", "dinov2", "stego"]
backbone_types = ["vit_small", "vit_base", "vit_small_reg", "vit_base_reg"]
backbone_types = ["vit_small", "vit_base"] # "vit_small_reg", "vit_base_reg"]

for seg_type, feat_type, back_type in itertools.product(segmentation_types, feature_types, backbone_types):
if seg_type == "stego" and feat_type != "stego":
Expand All @@ -39,7 +39,7 @@ def test_feature_extractor():

ax[0].imshow(transform(img).permute(0, 2, 3, 1)[0].cpu())
ax[0].set_title("Image")
ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("inferno"))
ax[1].imshow(seg.cpu(), cmap=plt.colormaps.get("gray"))
ax[1].set_title("Segmentation")
plt.tight_layout()

Expand Down

0 comments on commit c9ed5e8

Please sign in to comment.