Skip to content

Commit

Permalink
prepare_test_weights: add DINOv2
Browse files Browse the repository at this point in the history
  • Loading branch information
deltheil committed Dec 17, 2023
1 parent 9ab98ac commit feb1741
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions scripts/prepare_test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,30 @@ def download_sam():
)


def download_dinov2():
# For conversion
weights_folder = os.path.join(test_weights_dir)
urls = [
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth",
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth",
]
download_files(urls, weights_folder)

# For testing (note: versions with registers are not available yet on HuggingFace)
for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
base_folder = os.path.join(test_weights_dir, "facebook", repo)
urls = [
f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
]
download_files(urls, base_folder)


def printg(msg: str):
"""print in green color"""
print("\033[92m" + msg + "\033[0m")
Expand Down Expand Up @@ -541,6 +565,45 @@ def convert_sam():
)


def convert_dinov2():
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vits14_pretrain.pth",
"tests/weights/dinov2_vits14_pretrain.safetensors",
expected_hash="b7f9b294",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitb14_pretrain.pth",
"tests/weights/dinov2_vitb14_pretrain.safetensors",
expected_hash="d72c767b",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitl14_pretrain.pth",
"tests/weights/dinov2_vitl14_pretrain.safetensors",
expected_hash="71eb98d1",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
expected_hash="89118b46",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitb14_reg4_pretrain.pth",
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
expected_hash="b0296f77",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitl14_reg4_pretrain.pth",
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
expected_hash="b3d877dc",
)


def download_all():
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
download_sd15("runwayml/stable-diffusion-v1-5")
Expand All @@ -554,6 +617,7 @@ def download_all():
download_ip_adapter()
download_t2i_adapter()
download_sam()
download_dinov2()


def convert_all():
Expand All @@ -567,6 +631,7 @@ def convert_all():
convert_ip_adapter()
convert_t2i_adapter()
convert_sam()
convert_dinov2()


def main():
Expand Down

0 comments on commit feb1741

Please sign in to comment.