Commit

Pull missing changes from develop

radekd91 committed Feb 13, 2023
1 parent 8aa260a commit 6b1c214
Showing 3 changed files with 288 additions and 7 deletions.
108 changes: 108 additions & 0 deletions gdl/datasets/FaceAlignmentTools.py
@@ -0,0 +1,108 @@
import numpy as np
from pathlib import Path
from gdl.datasets.ImageDatasetHelpers import bbox2point, bbpoint_warp
import skvideo.io
import types


def align_face(image, landmarks, landmark_type, scale_adjustment, target_size_height, target_size_width=None,):
"""
Returns an image with the face aligned to the center of the image.
:param image: The full resolution image in which to align the face.
:param landmarks: The landmarks of the face in the image (in the original image coordinates).
:param landmark_type: The type of landmarks. Such as 'kpt68' or 'bbox' or 'mediapipe'.
:param scale_adjustment: The scale adjustment to apply to the image.
:param target_size_height: The height of the output image.
:param target_size_width: The width of the output image. If not provided, it is assumed to be the same as target_size_height.
:return: The aligned face image. The image will be in range [0,1].
"""
# landmarks_for_alignment = "mediapipe"
left = landmarks[:,0].min()
top = landmarks[:,1].min()
right = landmarks[:,0].max()
bottom = landmarks[:,1].max()

old_size, center = bbox2point(left, right, top, bottom, type=landmark_type)
size = (old_size * scale_adjustment).astype(np.int32)

img_warped, lmk_warped = bbpoint_warp(image, center, size, target_size_height, target_size_width, landmarks=landmarks)

return img_warped
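
# Editorial usage sketch (not part of this commit): a minimal call to align_face
# on a synthetic frame. The image and the 68x2 landmark array below are
# hypothetical stand-ins for a real frame and detector output.
def _align_face_usage_example():
    frame = np.zeros((400, 400, 3), dtype=np.uint8)       # fake HxWx3 frame
    lmk68 = np.random.rand(68, 2) * 200.0 + 100.0         # fake 68-point landmarks in pixels
    aligned = align_face(frame, lmk68, landmark_type="kpt68",
                         scale_adjustment=1.25, target_size_height=224)
    return aligned                                         # 224x224x3 float image in [0, 1]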


def align_video(video, centers, sizes, landmarks, target_size_height, target_size_width=None, ):
"""
Returns a video with the face aligned to the center of the image.
:param video: The full resolution video in which to align the face.
:param landmarks: The landmarks of the face in the video (in the original video coordinates).
:param target_size_height: The height of the output video.
:param target_size_width: The width of the output video. If not provided, it is assumed to be the same as target_size_height.
:return: The aligned face video. The video will be in range [0,1].
"""
if isinstance(video, (str, Path)):
video = skvideo.io.vread(video)
elif isinstance(video, (np.ndarray, types.GeneratorType)):
pass
else:
raise ValueError("video must be a string, Path, or numpy array")

aligned_video = []
warped_landmarks = []
if isinstance(video, np.ndarray):
for i in range(len(centers)):
img_warped, lmk_warped = bbpoint_warp(video[i], centers[i], sizes[i],
target_size_height=target_size_height, target_size_width=target_size_width,
landmarks=landmarks[i])
aligned_video.append(img_warped)
warped_landmarks += [lmk_warped]

elif isinstance(video, types.GeneratorType):
for i, frame in enumerate(video):
img_warped, lmk_warped = bbpoint_warp(frame, centers[i], sizes[i],
target_size_height=target_size_height, target_size_width=target_size_width,
landmarks=landmarks[i])
aligned_video.append(img_warped)
warped_landmarks += [lmk_warped]

aligned_video = np.stack(aligned_video, axis=0)
return aligned_video, warped_landmarks


def align_and_save_video(video, out_video_path, centers, sizes, landmarks, target_size_height, target_size_width=None, output_dict=None):
"""
Returns a video with the face aligned to the center of the image.
:param video: The full resolution video in which to align the face.
:param landmarks: The landmarks of the face in the video (in the original video coordinates).
:param target_size_height: The height of the output video.
:param target_size_width: The width of the output video. If not provided, it is assumed to be the same as target_size_height.
:return: The aligned face video. The video will be in range [0,1].
"""
if isinstance(video, (str, Path)):
video = skvideo.io.vread(video)
elif isinstance(video, (np.ndarray, types.GeneratorType)):
pass
else:
raise ValueError("video must be a string, Path, or numpy array")

writer = skvideo.io.FFmpegWriter(str(out_video_path), outputdict=output_dict)
warped_landmarks = []
if isinstance(video, np.ndarray):
for i in range(len(centers)):
img_warped, lmk_warped = bbpoint_warp(video[i], centers[i], sizes[i],
target_size_height=target_size_height, target_size_width=target_size_width,
landmarks=landmarks[i])
img_warped = (img_warped * 255).astype(np.uint8)
writer.writeFrame(img_warped)
warped_landmarks += [lmk_warped]

elif isinstance(video, types.GeneratorType):
for i, frame in enumerate(video):
img_warped, lmk_warped = bbpoint_warp(frame, centers[i], sizes[i],
target_size_height=target_size_height, target_size_width=target_size_width,
landmarks=landmarks[i])
img_warped = (img_warped * 255).astype(np.uint8)
writer.writeFrame(img_warped)
warped_landmarks += [lmk_warped]
writer.close()

return warped_landmarks
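
A minimal usage sketch for the video helpers above (editorial, not part of the commit), assuming per-frame crop centers, sizes and landmarks have already been computed (e.g. from bbox2point on per-frame detections). The video paths and the landmark arrays are hypothetical stand-ins.

import numpy as np
import skvideo.io
from gdl.datasets.FaceAlignmentTools import align_video, align_and_save_video

frames = skvideo.io.vread("clip.mp4")                                    # hypothetical (T, H, W, 3) uint8 video
T = len(frames)
centers = np.tile([frames.shape[2] / 2, frames.shape[1] / 2], (T, 1))   # fixed crop center per frame
sizes = np.full(T, 256, dtype=np.int32)                                  # fixed crop size per frame
landmarks = [np.random.rand(68, 2) * 200 + 100 for _ in range(T)]        # fake per-frame landmarks

# In-memory alignment ...
aligned, warped_lmks = align_video(frames, centers, sizes, landmarks, target_size_height=224)

# ... or stream the aligned frames straight to an output file.
warped_lmks = align_and_save_video("clip.mp4", "clip_aligned.mp4", centers, sizes,
                                   landmarks, target_size_height=224)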
43 changes: 36 additions & 7 deletions gdl/datasets/ImageDatasetHelpers.py
@@ -27,12 +27,25 @@ def bbox2point(left, right, top, bottom, type='bbox'):
'''
if type == 'kpt68':
old_size = (right - left + bottom - top) / 2 * 1.1
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
center_x = right - (right - left) / 2.0
center_y = bottom - (bottom - top) / 2.0
# center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
elif type == 'bbox':
old_size = (right - left + bottom - top) / 2
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size * 0.12])
center_x = right - (right - left) / 2.0
center_y = bottom - (bottom - top) / 2.0 + old_size * 0.12
# center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size * 0.12])
elif type == "mediapipe":
old_size = (right - left + bottom - top) / 2 * 1.1
center_x = right - (right - left) / 2.0
center_y = bottom - (bottom - top) / 2.0
# center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
else:
raise NotImplementedError
raise NotImplementedError(f" bbox2point not implemented for {type} ")
if isinstance(center_x, np.ndarray):
center = np.stack([center_x, center_y], axis=1)
else:
center = np.array([center_x, center_y])
return old_size, center
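
# Editorial usage sketch (not part of this commit): with the branch added above,
# bbox2point also accepts arrays of box coordinates and then stacks the centers
# into an (N, 2) array. The box values below are made up for illustration.
def _bbox2point_batched_example():
    lefts, rights = np.array([10.0, 40.0]), np.array([110.0, 160.0])
    tops, bottoms = np.array([20.0, 30.0]), np.array([120.0, 150.0])
    old_sizes, centers = bbox2point(lefts, rights, tops, bottoms, type='bbox')
    # old_sizes == array([100., 120.]); centers == array([[60., 82.], [100., 104.4]])
    return old_sizes, centers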


@@ -53,15 +66,31 @@ def point2transform(center, size, target_size_height, target_size_width):
return tform


def bbpoint_warp(image, center, size, target_size_height, target_size_width=None, output_shape=None, inv=True, landmarks=None):
def bbpoint_warp(image, center, size, target_size_height, target_size_width=None, output_shape=None, inv=True, landmarks=None,
order=3 # order of interpolation, bicubic by default
):
target_size_width = target_size_width or target_size_height
tform = point2transform(center, size, target_size_height, target_size_width)
tf = tform.inverse if inv else tform
output_shape = output_shape or (target_size_height, target_size_width)
dst_image = warp(image, tf, output_shape=output_shape, order=3)
dst_image = warp(image, tf, output_shape=output_shape, order=order)
if landmarks is None:
return dst_image
# points need the matrix
tf_lmk = tform if inv else tform.inverse
dst_landmarks = tf_lmk(landmarks)
if isinstance(landmarks, np.ndarray):
assert isinstance(landmarks, np.ndarray)
tf_lmk = tform if inv else tform.inverse
dst_landmarks = tf_lmk(landmarks[:, :2])
elif isinstance(landmarks, list):
tf_lmk = tform if inv else tform.inverse
dst_landmarks = []
for i in range(len(landmarks)):
dst_landmarks += [tf_lmk(landmarks[i][:, :2])]
elif isinstance(landmarks, dict):
tf_lmk = tform if inv else tform.inverse
dst_landmarks = {}
for key, value in landmarks.items():
dst_landmarks[key] = tf_lmk(landmarks[key][:, :2])
else:
raise ValueError("landmarks must be np.ndarray, list or dict")
return dst_image, dst_landmarks
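
A short sketch of the extended bbpoint_warp interface (editorial, not part of the commit): landmarks can now be passed as a single array, a list of arrays, or a dict of arrays, and the interpolation order is configurable. The synthetic image and landmark array below are hypothetical.

import numpy as np
from gdl.datasets.ImageDatasetHelpers import bbox2point, bbpoint_warp

image = np.zeros((480, 640, 3), dtype=np.uint8)            # synthetic frame
lmks = np.random.rand(68, 2) * 200 + 100                   # hypothetical 68x2 pixel landmarks

old_size, center = bbox2point(lmks[:, 0].min(), lmks[:, 0].max(),
                              lmks[:, 1].min(), lmks[:, 1].max(), type='kpt68')
size = int(old_size * 1.25)

# A single landmark array, warped together with the crop (bicubic by default) ...
crop, warped = bbpoint_warp(image, center, size, 224, landmarks=lmks)

# ... or a dict of landmark sets, warped with bilinear interpolation instead.
crop, warped_dict = bbpoint_warp(image, center, size, 224,
                                 landmarks={"face_0": lmks, "face_1": lmks}, order=1)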
144 changes: 144 additions & 0 deletions gdl/datasets/VideoFaceDetectionDataset.py
@@ -0,0 +1,144 @@
"""
Author: Radek Danecek
Copyright (c) 2022, Radek Danecek
All rights reserved.
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms
# in the LICENSE file included with this software distribution.
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright©2022 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at [email protected]
# For commercial licensing contact, please contact [email protected]
"""


import numpy as np
import torch
from PIL import Image
from skimage.io import imread
from torchvision.transforms import ToTensor

from gdl.utils.FaceDetector import load_landmark
from gdl.datasets.FaceAlignmentTools import align_face

from skvideo.io import vread, vreader
from types import GeneratorType
import pickle as pkl

class VideoFaceDetectionDataset(torch.utils.data.Dataset):

def __init__(self, video_name, landmark_path, image_transforms=None,
align_landmarks=False, vid_read=None, output_im_range=None,
scale_adjustment=1.25,
target_size_height=256,
target_size_width=256,
):
super().__init__()
self.video_name = video_name
self.landmark_path = landmark_path / "landmarks_original.pkl"
# if landmark_list is not None and len(lanmark_file_name) != len(image_list):
# raise RuntimeError("There must be a landmark for every image")
self.image_transforms = image_transforms
self.vid_read = vid_read or 'skvreader' # 'skvread'
self.prev_index = -1

self.scale_adjustment=scale_adjustment
self.target_size_height=target_size_height
self.target_size_width=target_size_width

self.video_frames = None
if self.vid_read == "skvread":
self.video_frames = vread(str(self.video_name))
elif self.vid_read == "skvreader":
self.video_frames = vreader(str(self.video_name))

with open(self.landmark_path, "rb") as f:
self.landmark_list = pkl.load(f)

with open(landmark_path / "landmark_types.pkl", "rb") as f:
self.landmark_types = pkl.load(f)

self.total_len = 0
        self.frame_map = {} # detection index to frame map
        self.index_for_frame_map = {} # detection index to within-frame detection index map
for i in range(len(self.landmark_list)):
for j in range(len(self.landmark_list[i])):
self.frame_map[self.total_len + j] = i
self.index_for_frame_map[self.total_len + j] = j
self.total_len += len(self.landmark_list[i])

self.output_im_range = output_im_range


def __getitem__(self, index):
# if index < len(self.image_list):
# x = self.mnist_data[index]
# raise IndexError("Out of bounds")
        if index != self.prev_index + 1 and self.vid_read != 'skvread':
            raise RuntimeError("This dataset is meant to be accessed sequentially only (and with 0 or 1 workers)")

frame_index = self.frame_map[index]
detection_in_frame_index = self.index_for_frame_map[index]
landmark = self.landmark_list[frame_index][detection_in_frame_index]
landmark_type = self.landmark_types[frame_index][detection_in_frame_index]

if isinstance(self.video_frames, np.ndarray):
img = self.video_frames[frame_index, ...]
elif isinstance(self.video_frames, GeneratorType):
img = next(self.video_frames)
else:
raise NotImplementedError()

# try:
# if self.vid_read == 'skvread':
# img = vread(self.image_list[index])
# img = img.transpose([2, 0, 1]).astype(np.float32)
# img_torch = torch.from_numpy(img)
# path = str(self.image_list[index])
# elif self.vid_read == 'pil':
# img = Image.open(self.image_list[index])
# img_torch = ToTensor()(img)
# path = str(self.image_list[index])
# # path = f"{index:05d}"
# else:
# raise ValueError(f"Invalid image reading method {self.im_read}")
# except Exception as e:
# print(f"Failed to read '{self.image_list[index]}'. File is probably corrupted. Rerun data processing")
# raise e

# crop out the face
        img = align_face(img, landmark, landmark_type, scale_adjustment=self.scale_adjustment,
                         target_size_height=self.target_size_height, target_size_width=self.target_size_width,)
if self.output_im_range == 255:
img = img * 255.0
img = img.astype(np.float32)
img_torch = ToTensor()(img)

# # plot img with pyplot
# import matplotlib.pyplot as plt
# plt.figure()
# plt.imshow(img)
# plt.show()
# # plot image with plotly
# import plotly.graph_objects as go
# fig = go.Figure(data=go.Image(z=img*255.,))
# fig.show()


if self.image_transforms is not None:
img_torch = self.image_transforms(img_torch)

batch = {"image" : img_torch,
# "path" : path
}

self.prev_index += 1
return batch

def __len__(self):
return self.total_len
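
A minimal sketch of how this dataset might be driven (editorial, not part of the commit). The paths below are hypothetical and assume landmarks_original.pkl and landmark_types.pkl exist in the landmark directory; because the default 'skvreader' mode reads frames sequentially, the loader stays ordered and single-process.

from pathlib import Path
import torch
from gdl.datasets.VideoFaceDetectionDataset import VideoFaceDetectionDataset

video_path = Path("processed/clip.mp4")            # hypothetical video file
landmark_dir = Path("processed/clip_landmarks")    # hypothetical dir with the two .pkl files

dataset = VideoFaceDetectionDataset(video_path, landmark_dir)
loader = torch.utils.data.DataLoader(dataset, batch_size=8,
                                     shuffle=False, num_workers=0)  # ordered, single worker
for batch in loader:
    images = batch["image"]                        # (B, 3, 256, 256) float tensor in [0, 1]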
