Skip to content

Commit

Permalink
Small update (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthijsBurgh authored Jun 28, 2024
2 parents a7b88d7 + e13f787 commit 8a25ffc
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 271 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenv-create: true
virtualenvs-create: true
virtualenvs-in-project: true
- name: Poetry cache
id: poetry_cache
Expand Down
102 changes: 57 additions & 45 deletions facenet_pytorch/models/inception_resnet_v1.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import functional as tv_functional

from .utils.download import download_url_to_file

if TYPE_CHECKING:
from torch.nn.common_types import _size_2_t


class BasicConv2d(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
def __init__( # noqa: PLR0913
self, in_planes: int, out_planes: int, kernel_size: _size_2_t, stride: _size_2_t, padding: _size_2_t = 0
) -> None:
super().__init__()
self.conv = nn.Conv2d(
in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False
Expand All @@ -21,15 +30,15 @@ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
)
self.relu = nn.ReLU(inplace=False)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
return x # noqa: RET504


class Block35(nn.Module):
def __init__(self, scale=1.0):
def __init__(self, scale: float = 1.0) -> None:
super().__init__()

self.scale = scale
Expand All @@ -49,19 +58,18 @@ def __init__(self, scale=1.0):
self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1)
self.relu = nn.ReLU(inplace=False)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
out = torch.cat((x0, x1, x2), 1)
out = self.conv2d(out)
out = out * self.scale + x
out = self.relu(out)
return out
return self.relu(out)


class Block17(nn.Module):
def __init__(self, scale=1.0):
def __init__(self, scale: float = 1.0) -> None:
super().__init__()

self.scale = scale
Expand All @@ -77,18 +85,17 @@ def __init__(self, scale=1.0):
self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1)
self.relu = nn.ReLU(inplace=False)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = self.branch0(x)
x1 = self.branch1(x)
out = torch.cat((x0, x1), 1)
out = self.conv2d(out)
out = out * self.scale + x
out = self.relu(out)
return out
return self.relu(out)


class Block8(nn.Module):
def __init__(self, scale=1.0, noReLU=False):
def __init__(self, scale: float = 1.0, *, noReLU: bool = False) -> None: # noqa: N803
super().__init__()

self.scale = scale
Expand All @@ -106,7 +113,7 @@ def __init__(self, scale=1.0, noReLU=False):
if not self.noReLU:
self.relu = nn.ReLU(inplace=False)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = self.branch0(x)
x1 = self.branch1(x)
out = torch.cat((x0, x1), 1)
Expand All @@ -117,8 +124,8 @@ def forward(self, x):
return out


class Mixed_6a(nn.Module):
def __init__(self):
class Mixed6a(nn.Module):
def __init__(self) -> None:
super().__init__()

self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2)
Expand All @@ -131,16 +138,15 @@ def __init__(self):

self.branch2 = nn.MaxPool2d(3, stride=2)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
out = torch.cat((x0, x1, x2), 1)
return out
return torch.cat((x0, x1, x2), 1)


class Mixed_7a(nn.Module):
def __init__(self):
class Mixed7a(nn.Module):
def __init__(self) -> None:
super().__init__()

self.branch0 = nn.Sequential(
Expand All @@ -159,13 +165,12 @@ def __init__(self):

self.branch3 = nn.MaxPool2d(3, stride=2)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out
return torch.cat((x0, x1, x2, x3), 1)


class InceptionResnetV1(nn.Module):
Expand All @@ -187,7 +192,15 @@ class InceptionResnetV1(nn.Module):
dropout_prob {float} -- Dropout probability. (default: {0.6})
"""

def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
def __init__( # noqa: PLR0913
self,
pretrained: str | None = None,
*,
classify: bool = False,
num_classes: int | None = None,
dropout_prob: float = 0.6,
device: torch.device | str | None = None,
) -> None:
super().__init__()

# Set simple attributes
Expand All @@ -200,7 +213,8 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
elif pretrained == "casia-webface":
tmp_classes = 10575
elif pretrained is None and self.classify and self.num_classes is None:
raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')
msg = "If 'pretrained' is not specified and 'classify' is True, 'num_classes' must be specified"
raise ValueError(msg)

# Define layers
self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
Expand All @@ -213,7 +227,7 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
self.repeat_1 = nn.Sequential(
Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17), Block35(scale=0.17)
)
self.mixed_6a = Mixed_6a()
self.mixed_6a = Mixed6a()
self.repeat_2 = nn.Sequential(
Block17(scale=0.10),
Block17(scale=0.10),
Expand All @@ -226,7 +240,7 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
Block17(scale=0.10),
Block17(scale=0.10),
)
self.mixed_7a = Mixed_7a()
self.mixed_7a = Mixed7a()
self.repeat_3 = nn.Sequential(
Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20), Block8(scale=0.20)
)
Expand All @@ -248,7 +262,7 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
self.device = device
self.to(device)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate embeddings or logits given a batch of input image tensors.
Arguments:
Expand All @@ -274,14 +288,11 @@ def forward(self, x):
x = self.dropout(x)
x = self.last_linear(x.view(x.shape[0], -1))
x = self.last_bn(x)
if self.classify:
x = self.logits(x)
else:
x = F.normalize(x, p=2, dim=1)
x = self.logits(x) if self.classify else tv_functional.normalize(x, p=2, dim=1)
return x


def load_weights(mdl, name):
def load_weights(mdl: torch.nn.Module, name: str) -> None:
"""Download pretrained state_dict and load into model.
Arguments:
Expand All @@ -293,24 +304,25 @@ def load_weights(mdl, name):
"""
if name == "vggface2":
path = "https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt"
hash_prefix = "281cebca8662831adb987a874bdcb36e73f5b1c6dc5ee5878f305e985625d99b"
elif name == "casia-webface":
path = "https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt"
hash_prefix = "7a67afdbbc995fce5e10128675e318799a70698c2f433ba75dd7eb9a2f096e7d"
else:
raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')
msg = "Pretrained models only exist for 'vggface2' and 'casia-webface'"
raise ValueError(msg)

model_dir = os.path.join(get_torch_home(), "checkpoints")
os.makedirs(model_dir, exist_ok=True)
model_dir = get_torch_home() / "checkpoints"
model_dir.mkdir(parents=True, exist_ok=True)

cached_file = os.path.join(model_dir, os.path.basename(path))
if not os.path.exists(cached_file):
download_url_to_file(path, cached_file)
cached_file = model_dir / Path(path).name
if not cached_file.exists():
download_url_to_file(path, cached_file, hash_prefix)

state_dict = torch.load(cached_file)
mdl.load_state_dict(state_dict)


def get_torch_home():
torch_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
return torch_home
def get_torch_home() -> Path:
"""Get the torch cache directory."""
return Path(os.getenv("TORCH_HOME", Path(os.getenv("XDG_CACHE_HOME", "~/.cache")) / "torch")).expanduser()
5 changes: 4 additions & 1 deletion facenet_pytorch/models/mtcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

import importlib.resources
import os
from typing import TYPE_CHECKING

import numpy as np
import torch
from PIL.Image import Image
from torch import nn

import facenet_pytorch.data
from facenet_pytorch.models.utils.detect_face import detect_face, extract_face, get_size

if TYPE_CHECKING:
from PIL.Image import Image


class PNet(nn.Module):
"""MTCNN PNet.
Expand Down
Loading

0 comments on commit 8a25ffc

Please sign in to comment.