Skip to content

Commit

Permalink
expose the dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
1adrianb committed Jun 6, 2023
1 parent c19affb commit c047411
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 39 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,14 @@ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, face_detec
In order to specify the device (GPU or CPU) on which the code will run one can explicitly pass the device flag:

```python
import torch
import face_alignment

# cuda for CUDA
# cuda for CUDA, mps for Apple M1/2 GPUs.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device='cpu')

# running using lower precision
fa = fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, dtype=torch.bfloat16, device='cuda')
```

Please also see the ``examples`` folder
Expand All @@ -85,10 +89,10 @@ Please also see the ``examples`` folder

```python

# dlib
# dlib (fast, may miss faces)
model = FaceAlignment(landmarks_type= LandmarksType.TWO_D, face_detector='dlib')

# SFD
# SFD (likely best results, but slowest)
model = FaceAlignment(landmarks_type= LandmarksType.TWO_D, face_detector='sfd')

# Blazeface (front camera model)
Expand Down
15 changes: 8 additions & 7 deletions face_alignment/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ class NetworkSize(IntEnum):

class FaceAlignment:
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
device='cuda', dtype=torch.float32, flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
self.device = device
self.flip_input = flip_input
self.landmarks_type = landmarks_type
self.verbose = verbose
self.dtype = dtype

if version.parse(torch.__version__) < version.parse('1.5.0'):
raise ImportError(f'Unsupported pytorch version detected. Minimum supported version of pytorch: 1.5.0\
Expand Down Expand Up @@ -84,15 +85,15 @@ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
self.face_alignment_net = torch.jit.load(
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)[network_name]))

self.face_alignment_net.to(device)
self.face_alignment_net.to(device, dtype=dtype)
self.face_alignment_net.eval()

# Initialiase the depth prediciton network
if landmarks_type == LandmarksType.THREE_D:
self.depth_prediciton_net = torch.jit.load(
load_file_from_url(models_urls.get(pytorch_version, default_model_urls)['depth']))

self.depth_prediciton_net.to(device)
self.depth_prediciton_net.to(device, dtype=dtype)
self.depth_prediciton_net.eval()

def get_landmarks(self, image_or_path, detected_faces=None, return_bboxes=False, return_landmark_score=False):
Expand Down Expand Up @@ -159,13 +160,13 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
inp = torch.from_numpy(inp.transpose(
(2, 0, 1))).float()

inp = inp.to(self.device)
inp = inp.to(self.device, dtype=self.dtype)
inp.div_(255.0).unsqueeze_(0)

out = self.face_alignment_net(inp).detach()
if self.flip_input:
out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
out = out.cpu().numpy()
out = out.to(device='cpu', dtype=torch.float32).numpy()

pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
Expand All @@ -181,9 +182,9 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
heatmaps = torch.from_numpy(
heatmaps).unsqueeze_(0)

heatmaps = heatmaps.to(self.device)
heatmaps = heatmaps.to(self.device, dtype=self.dtype)
depth_pred = self.depth_prediciton_net(
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1).to(dtype=torch.float32)
pts_img = torch.cat(
(pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)

Expand Down
33 changes: 4 additions & 29 deletions face_alignment/detection/blazeface/net_blazeface.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,38 +114,13 @@ def _define_back_model_layers(self):
self.backbone = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=2, padding=0, bias=True),
nn.ReLU(inplace=True),

BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
*[BlazeBlock(24, 24) for _ in range(7)],
BlazeBlock(24, 24, stride=2),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
BlazeBlock(24, 24),
*[BlazeBlock(24, 24) for _ in range(7)],
BlazeBlock(24, 48, stride=2),
BlazeBlock(48, 48),
BlazeBlock(48, 48),
BlazeBlock(48, 48),
BlazeBlock(48, 48),
BlazeBlock(48, 48),
BlazeBlock(48, 48),
BlazeBlock(48, 48),
*[BlazeBlock(48, 48) for _ in range(7)],
BlazeBlock(48, 96, stride=2),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
BlazeBlock(96, 96),
*[BlazeBlock(96, 96) for _ in range(7)],
)
self.final = FinalBlazeBlock(96)
self.classifier_8 = nn.Conv2d(96, 2, 1, bias=True)
Expand Down

0 comments on commit c047411

Please sign in to comment.