Commit
To be completed with tests using image preprocessing, e.g. test cosine similarity on a relevant pair of images
Showing 3 changed files with 96 additions and 0 deletions.
@@ -0,0 +1,92 @@
from pathlib import Path
from warnings import warn

import pytest
import torch
from transformers import AutoModel  # type: ignore
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model  # type: ignore

from refiners.fluxion.utils import load_from_safetensors, manual_seed
from refiners.foundationals.dinov2 import DINOv2_base, DINOv2_large, DINOv2_small
from refiners.foundationals.dinov2.vit import ViT

# TODO: add DINOv2 with registers ("dinov2_vits14_reg", etc.). At the time of writing, those are not yet
# supported in transformers (https://github.com/huggingface/transformers/issues/27379). Alternatively, it
# is also possible to use facebookresearch/dinov2 directly
# (https://github.com/finegrain-ai/refiners/pull/132), as sketched below.
FLAVORS = [
    "dinov2_vits14",
    "dinov2_vitb14",
    "dinov2_vitl14",
]
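

# As the TODO above notes, the register variants can be exercised against facebookresearch/dinov2
# directly while transformers support is pending. A minimal sketch, assuming network access and the
# hub entry-point names documented by the DINOv2 repository ("dinov2_vits14_reg", etc.); nothing in
# this module uses this helper yet:
def _load_reference_registers_backbone(flavor: str) -> torch.nn.Module:
    # e.g. flavor="dinov2_vits14" resolves to the "dinov2_vits14_reg" hub entry point
    return torch.hub.load("facebookresearch/dinov2", f"{flavor}_reg")  # type: ignore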


@pytest.fixture(scope="module", params=FLAVORS)
def flavor(request: pytest.FixtureRequest) -> str:
    return request.param


@pytest.fixture(scope="module")
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
    # TODO: parameterize the various checkpoints
    weights = test_weights_path / f"{flavor}_pretrain.safetensors"
    if not weights.is_file():
        warn(f"could not find weights at {weights}, skipping")
        pytest.skip(allow_module_level=True)
    match flavor:
        case "dinov2_vits14":
            backbone = DINOv2_small(device=test_device)
        case "dinov2_vitb14":
            backbone = DINOv2_base(device=test_device)
        case "dinov2_vitl14":
            backbone = DINOv2_large(device=test_device)
        case _:
            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
    tensors = load_from_safetensors(weights)
    backbone.load_state_dict(tensors)
    return backbone


@pytest.fixture(scope="module")
def dinov2_weights_path(test_weights_path: Path, flavor: str) -> Path:
    match flavor:
        case "dinov2_vits14":
            name = "dinov2-small"
        case "dinov2_vitb14":
            name = "dinov2-base"
        case "dinov2_vitl14":
            name = "dinov2-large"
        case _:
            raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
    r = test_weights_path / "facebook" / name
    if not r.is_dir():
        warn(f"could not find DINOv2 weights at {r}, skipping")
        pytest.skip(allow_module_level=True)
    return r


@pytest.fixture(scope="module")
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
    backbone = AutoModel.from_pretrained(dinov2_weights_path)  # type: ignore
    assert isinstance(backbone, Dinov2Model)
    return backbone.to(test_device)  # type: ignore


def test_encoder(
    ref_backbone: Dinov2Model,
    our_backbone: ViT,
    test_device: torch.device,
):
    manual_seed(42)

    # Position encoding interpolation [1] at runtime is not supported yet, so stick to the default image
    # resolution: using e.g. (224, 224) pixels as input would raise a runtime error (sequence size mismatch).
    # [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
    assert our_backbone.image_size == 518  # 518 = 37 patches per side x 14 pixels per patch

    x = torch.randn(1, 3, 518, 518).to(test_device)

    with torch.no_grad():
        ref_features = ref_backbone(x).last_hidden_state
        our_features = our_backbone(x)

    assert (our_features - ref_features).abs().max() < 1e-3
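

# The commit message leaves image-preprocessing tests as follow-up work. Below is a minimal sketch of
# such a test: cosine similarity between our features and the reference features on a real, preprocessed
# image. The image path and the 0.999 threshold are placeholders, and standard ImageNet normalization
# (as used by the DINOv2 reference code) is assumed; this is a sketch, not the final test.
import torchvision.transforms as T  # assumed available in the test environment
from PIL import Image


def preprocess(image: Image.Image, size: int = 518) -> torch.Tensor:
    # Resize to the backbone's native resolution and normalize with ImageNet statistics.
    transform = T.Compose(
        [
            T.Resize((size, size)),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )
    return transform(image).unsqueeze(0)  # type: ignore


def test_cosine_similarity(
    ref_backbone: Dinov2Model,
    our_backbone: ViT,
    test_device: torch.device,
):
    # "cute_cat.png" is a placeholder name: any relevant test image works here.
    image = Image.open("tests/fixtures/cute_cat.png").convert("RGB")
    x = preprocess(image).to(test_device)

    with torch.no_grad():
        ref_features = ref_backbone(x).last_hidden_state
        our_features = our_backbone(x)

    # Flatten token embeddings and require near-perfect alignment between implementations.
    similarity = torch.nn.functional.cosine_similarity(our_features.flatten(1), ref_features.flatten(1))
    assert similarity.item() > 0.999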