Commit

Merge branch 'main' into pr/add-bunch-mkdocs-configs
Laurent2916 authored Jan 31, 2024
2 parents 5e9888b + b2c26ab commit 173307c
Showing 7 changed files with 64 additions and 62 deletions.
11 changes: 4 additions & 7 deletions pyproject.toml
@@ -22,9 +22,6 @@ training = [
     "torchvision>=0.16.1",
     "loguru>=0.7.2",
     "wandb>=0.16.0",
-    # Added scipy as a work around until this PR gets merged:
-    # https://github.com/TimDettmers/bitsandbytes/pull/525
-    "scipy>=1.11.4",
     "datasets>=2.15.0",
     "tomli>=2.0.1",
 ]
@@ -54,11 +51,11 @@ build-backend = "hatchling.build"
 [tool.rye]
 managed = true
 dev-dependencies = [
-    "pyright == 1.1.342",
-    "ruff>=0.0.292",
+    "pyright==1.1.349",
+    "ruff>=0.1.15",
     "docformatter>=1.7.5",
-    "pytest>=7.4.2",
-    "mkdocs-material>=9.5.3",
+    "pytest>=8.0.0",
+    "mkdocs-material>=9.5.6",
     "coverage>=7.4.1",
     "mkdocstrings[python]>=0.24.0",
 ]
74 changes: 37 additions & 37 deletions requirements.lock
@@ -8,38 +8,38 @@
 # with-sources: false
 
 -e file:.
-aiohttp==3.9.1
+aiohttp==3.9.3
 aiosignal==1.3.1
 annotated-types==0.6.0
 appdirs==1.4.4
 async-timeout==4.0.3
-attrs==23.1.0
-bitsandbytes==0.41.3
+attrs==23.2.0
+bitsandbytes==0.42.0
 certifi==2023.11.17
 charset-normalizer==3.3.2
 click==8.1.7
-datasets==2.15.0
-diffusers==0.24.0
+datasets==2.16.1
+diffusers==0.25.1
 dill==0.3.7
 docker-pycreds==0.4.0
 filelock==3.13.1
-frozenlist==1.4.0
+frozenlist==1.4.1
 fsspec==2023.10.0
 gitdb==4.0.11
-gitpython==3.1.40
-huggingface-hub==0.19.4
+gitpython==3.1.41
+huggingface-hub==0.20.3
 idna==3.6
-importlib-metadata==7.0.0
+importlib-metadata==7.0.1
 invisible-watermark==0.2.0
-jaxtyping==0.2.24
-jinja2==3.1.2
+jaxtyping==0.2.25
+jinja2==3.1.3
 loguru==0.7.2
-markupsafe==2.1.3
+markupsafe==2.1.4
 mpmath==1.3.0
 multidict==6.0.4
 multiprocess==0.70.15
 networkx==3.2.1
-numpy==1.26.2
+numpy==1.26.3
 nvidia-cublas-cu12==12.1.3.1
 nvidia-cuda-cupti-cu12==12.1.105
 nvidia-cuda-nvrtc-cu12==12.1.105
@@ -49,49 +49,49 @@ nvidia-cufft-cu12==11.0.2.54
 nvidia-curand-cu12==10.3.2.106
 nvidia-cusolver-cu12==11.4.5.107
 nvidia-cusparse-cu12==12.1.0.106
-nvidia-nccl-cu12==2.18.1
+nvidia-nccl-cu12==2.19.3
 nvidia-nvjitlink-cu12==12.3.101
 nvidia-nvtx-cu12==12.1.105
-opencv-python==4.8.1.78
+opencv-python==4.9.0.80
 packaging==23.2
-pandas==2.1.4
-pillow==10.1.0
+pandas==2.2.0
+pillow==10.2.0
 piq==0.8.0
 prodigyopt==1.0
-protobuf==4.25.1
-psutil==5.9.6
-pyarrow==14.0.1
+protobuf==4.25.2
+psutil==5.9.8
+pyarrow==15.0.0
 pyarrow-hotfix==0.6
-pydantic==2.5.2
-pydantic-core==2.14.5
+pydantic==2.6.0
+pydantic-core==2.16.1
 python-dateutil==2.8.2
-pytz==2023.3.post1
+pytz==2023.4
 pywavelets==1.5.0
 pyyaml==6.0.1
-regex==2023.10.3
+regex==2023.12.25
 requests==2.31.0
-safetensors==0.4.1
-scipy==1.11.4
+safetensors==0.4.2
+scipy==1.12.0
 segment-anything-py==1.0
-sentry-sdk==1.38.0
+sentry-sdk==1.40.0
 setproctitle==1.3.3
 six==1.16.0
 smmap==5.0.1
 sympy==1.12
-tokenizers==0.15.0
+tokenizers==0.15.1
 tomli==2.0.1
-torch==2.1.1
-torchvision==0.16.1
+torch==2.2.0
+torchvision==0.17.0
 tqdm==4.66.1
-transformers==4.35.2
-triton==2.1.0
+transformers==4.37.2
+triton==2.2.0
 typeguard==2.13.3
-typing-extensions==4.8.0
-tzdata==2023.3
-urllib3==2.1.0
-wandb==0.16.1
+typing-extensions==4.9.0
+tzdata==2023.4
+urllib3==2.2.0
+wandb==0.16.2
 xxhash==3.4.1
 yarl==1.9.4
 zipp==3.17.0
 # The following packages are considered to be unsafe in a requirements file:
-setuptools==69.0.2
+setuptools==69.0.3
10 changes: 5 additions & 5 deletions src/refiners/foundationals/clip/concepts.py
@@ -21,10 +21,11 @@ def __init__(
     ) -> None:
         with self.setup_adapter(target):
             super().__init__(fl.Lambda(func=self.lookup))
-        self.old_weight = cast(Parameter, target.weight)
-        self.new_weight = Parameter(
+        p = Parameter(
             zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
         ) # requires_grad=True by default
+        self.old_weight = cast(Parameter, target.weight)
+        self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
 
     # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
     def lookup(self, x: Tensor) -> Tensor:
@@ -33,9 +34,8 @@ def lookup(self, x: Tensor) -> Tensor:
 
     def add_embedding(self, embedding: Tensor) -> None:
         assert embedding.shape == (self.old_weight.shape[1],)
-        self.new_weight = Parameter(
-            cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
-        )
+        p = Parameter(cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)]))
+        self.new_weight = cast(Parameter, p) # PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
 
     @property
     def num_embeddings(self) -> int:
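
Note on the recurring change in this commit: the added casts reference https://github.com/pytorch/pytorch/issues/118736. Judging by the diff comments, the PyTorch 2.2 type stubs can make pyright infer a freshly constructed nn.Parameter as a plain Tensor, so the new code builds the parameter first and casts it back to nn.Parameter before assigning it. A minimal standalone sketch of the pattern (the Scale module below is an illustrative stand-in, not code from this repository):

from typing import cast

import torch
from torch import nn


class Scale(nn.Module):  # illustrative stand-in, not part of the diff
    def __init__(self, dim: int) -> None:
        super().__init__()
        # Under the PyTorch 2.2 stubs, pyright may report `p` as Tensor rather than Parameter.
        p = nn.Parameter(torch.ones(dim))
        # cast() is a no-op at runtime; it only restores the Parameter type for the checker.
        self.weight = cast(nn.Parameter, p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight
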
20 changes: 11 additions & 9 deletions src/refiners/foundationals/dinov2/vit.py
@@ -1,3 +1,5 @@
+from typing import cast
+
 import torch
 from torch import Tensor
 
@@ -60,18 +62,18 @@ def __init__(
         super().__init__()
         self.embedding_dim = embedding_dim
 
-        self.register_parameter(
-            name="weight",
-            param=torch.nn.Parameter(
-                torch.full(
-                    size=(embedding_dim,),
-                    fill_value=init_value,
-                    dtype=dtype,
-                    device=device,
-                ),
+        p = torch.nn.Parameter(
+            torch.full(
+                size=(embedding_dim,),
+                fill_value=init_value,
+                dtype=dtype,
+                device=device,
             ),
         )
 
+        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.register_parameter(name="weight", param=cast(torch.nn.Parameter, p))
+
     def forward(self, x: Tensor) -> Tensor:
         return x * self.weight
 
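The register_parameter flavor of the same workaround appears above (and again in prompt_encoder.py below): register_parameter expects an nn.Parameter, so the value is cast back before registration. Another hedged, standalone sketch rather than the repository's code:

from typing import cast

import torch
from torch import nn


class Bias(nn.Module):  # illustrative stand-in, not part of the diff
    def __init__(self, dim: int) -> None:
        super().__init__()
        p = nn.Parameter(torch.zeros(dim))
        # The cast keeps pyright happy when the 2.2 stubs report `p` as a plain Tensor.
        self.register_parameter(name="bias", param=cast(nn.Parameter, p))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.bias
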
7 changes: 4 additions & 3 deletions src/refiners/foundationals/segment_anything/prompt_encoder.py
@@ -1,5 +1,6 @@
 from collections.abc import Sequence
 from enum import Enum, auto
+from typing import cast
 
 import torch
 from jaxtyping import Float, Int
@@ -180,9 +181,9 @@ def __init__(
                 dtype=dtype,
             ),
         )
-        self.register_parameter(
-            "no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
-        )
+        p = nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
+        # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+        self.register_parameter("no_mask_embedding", cast(nn.Parameter, p))
 
     def get_no_mask_dense_embedding(
         self, image_embedding_size: tuple[int, int], batch_size: int = 1
(One of the seven changed files could not be displayed by the diff viewer.)
4 changes: 3 additions & 1 deletion tests/training_utils/test_trainer.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
+from typing import cast
 from warnings import warn
 
 import pytest
@@ -100,7 +101,8 @@ def test_count_learnable_parameters_with_params() -> None:
         nn.Parameter(torch.randn(5), requires_grad=False),
         nn.Parameter(torch.randn(3, 3), requires_grad=True),
     ]
-    assert count_learnable_parameters(params) == 13
+    # cast because of PyTorch 2.2, see https://github.com/pytorch/pytorch/issues/118736
+    assert count_learnable_parameters(cast(list[nn.Parameter], params)) == 13
 
 
 def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None:
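
For context, count_learnable_parameters comes from refiners' training utilities; the assertion suggests it sums element counts over parameters that require gradients. A hedged sketch under that assumption (the real implementation may differ), with an illustrative parameter list:

from collections.abc import Iterable

import torch
from torch import nn


def count_learnable_parameters(parameters: Iterable[nn.Parameter]) -> int:
    # Assumed behavior: only parameters with requires_grad=True contribute.
    return sum(p.numel() for p in parameters if p.requires_grad)


params = [
    nn.Parameter(torch.randn(2, 2), requires_grad=True),  # 4 elements, counted
    nn.Parameter(torch.randn(5), requires_grad=False),  # 5 elements, skipped
    nn.Parameter(torch.randn(3, 3), requires_grad=True),  # 9 elements, counted
]
assert count_learnable_parameters(params) == 13  # 4 + 9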
