diff --git a/main.py b/main.py new file mode 100644 index 0000000..a15f1a2 --- /dev/null +++ b/main.py @@ -0,0 +1,219 @@ +import argparse +import csv +import time + +import cv2 +import numpy as np +import tqdm +from PIL import Image +from face_detection import RetinaFace +from torch.utils.data import DataLoader + +from utils.datasets import Datasets +from utils.util import * + + +def train(args): + model = load_model(args, True).cuda() + dataset = Datasets(f'{args.data_dir}', '300W_LP', get_transforms(True), True) + loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) + + criterion = GeodesicLoss().cuda() + optimizer = torch.optim.Adam(model.parameters(), args.lr) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.5) + + best_loss = float('inf') + with open('outputs/weights/step.csv', 'w') as log: + logger = csv.DictWriter(log, fieldnames=['epoch', 'Loss', 'Pitch', 'Yaw', 'Roll']) + logger.writeheader() + for epoch in range(args.epochs): + print(('\n' + '%10s' * 3) % ('epoch', 'memory', 'loss')) + p_bar = tqdm.tqdm(loader, total=len(loader)) + model.train() + total_loss = 0 + for i, (samples, labels) in enumerate(p_bar): + samples = samples.cuda() + labels = labels.cuda() + optimizer.zero_grad() + outputs = model(samples) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + total_loss += loss.item() + memory = f'{torch.cuda.memory_reserved() / 1E9:.3g}G' + s = ('%10s' * 2 + '%10.3g') % (f'{epoch + 1}/{args.epochs}', memory, loss.item()) + p_bar.set_description(s) + + avg_loss = total_loss / len(loader) + val_loss, val_pitch, val_yaw, val_roll = test(args, model) + scheduler.step() + + logger.writerow({'Pitch': str(f'{val_pitch:.3f}'), + 'Yaw': str(f'{val_yaw:.3f}'), + 'Roll': str(f'{val_roll:.3f}'), + 'Loss': str(f'{avg_loss:.3f}'), + 'epoch': str(epoch + 1).zfill(3)}) + log.flush() + if val_loss < best_loss: + best_loss = val_loss + torch.save(model.state_dict(), f'{args.save_dir}/weights/best.pt') + print(f'Epoch {epoch + 1}: New best model saved with val_loss: {best_loss:.3f}') + + torch.save(model.state_dict(), f'{args.save_dir}/weights/last.pt') + scheduler.step() + + torch.cuda.empty_cache() + print('Training completed.') + + +@torch.no_grad() +def test(args, model=None): + dataset = Datasets(f'{args.data_dir}', 'AFLW2K', get_transforms(False), False) + loader = DataLoader(dataset, batch_size=64) + if model is None: + model = load_model(args, False).cuda() + # model = model.float() + model.half() + model.eval() + + total, y_error, p_error, r_error = 0, 0.0, 0.0, 0.0 + for sample, label in tqdm.tqdm(loader, ('%10s' * 3) % ('Pitch', 'Yaw', 'Roll')): + sample = sample.cuda() + sample = sample.half() + total += label.size(0) + + p_gt = label[:, 0].float() * 180 / np.pi + y_gt = label[:, 1].float() * 180 / np.pi + r_gt = label[:, 2].float() * 180 / np.pi + + output = model(sample) + euler = compute_euler(output) * 180 / np.pi + + p_pred = euler[:, 0].cpu() + y_pred = euler[:, 1].cpu() + r_pred = euler[:, 2].cpu() + + p_error += torch.sum(torch.min(torch.stack((torch.abs(p_gt - p_pred), + torch.abs(p_pred + 360 - p_gt), + torch.abs(p_pred - 360 - p_gt), + torch.abs(p_pred + 180 - p_gt), + torch.abs(p_pred - 180 - p_gt))), 0)[0]) + + y_error += torch.sum(torch.min(torch.stack((torch.abs(y_gt - y_pred), + torch.abs(y_pred + 360 - y_gt), + torch.abs(y_pred - 360 - y_gt), + torch.abs(y_pred + 180 - y_gt), + torch.abs(y_pred - 180 - y_gt))), 0)[0]) + + r_error += torch.sum(torch.min(torch.stack((torch.abs(r_gt - r_pred), + torch.abs(r_pred + 360 - r_gt), + torch.abs(r_pred - 360 - r_gt), + torch.abs(r_pred + 180 - r_gt), + torch.abs(r_pred - 180 - r_gt))), 0)[0]) + + p_error, y_error, r_error = p_error / total, y_error / total, r_error / total + avg_error = (p_error + y_error + r_error) / (3 * total) + print(('%10.3g' * 3) % (p_error, y_error, r_error)) + + model.float() # for training + return avg_error, p_error, y_error, r_error + + +@torch.no_grad() +def inference(args): + model = load_model(args, False).cuda() + model.eval() + detector = RetinaFace(0) + + cap = cv2.VideoCapture(0) + frame_width = int(cap.get(3)) + frame_height = int(cap.get(4)) + out = cv2.VideoWriter(f'{args.save_dir}/output.avi', cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 25, + (frame_width, frame_height)) + # Check if the webcam is opened correctly + if not cap.isOpened(): + raise IOError("Cannot open webcam") + + with torch.no_grad(): + while True: + ret, frame = cap.read() + + faces = detector(frame) + + for box, landmarks, score in faces: + + # Print the location of each face in this image + if score < .95: + continue + x_min = int(box[0]) + y_min = int(box[1]) + x_max = int(box[2]) + y_max = int(box[3]) + bbox_width = abs(x_max - x_min) + bbox_height = abs(y_max - y_min) + + x_min = max(0, x_min - int(0.2 * bbox_height)) + y_min = max(0, y_min - int(0.2 * bbox_width)) + x_max = x_max + int(0.2 * bbox_height) + y_max = y_max + int(0.2 * bbox_width) + + img = frame[y_min:y_max, x_min:x_max] + img = Image.fromarray(img) + img = img.convert('RGB') + img = get_transforms(False)(img) + + img = torch.Tensor(img[None, :]).cuda() + + c = cv2.waitKey(1) + if c == 27: + break + + start = time.time() + R_pred = model(img) + end = time.time() + print('Head pose estimation: %2f ms' % ((end - start) * 1000.)) + + euler = compute_euler( + R_pred) * 180 / np.pi + p_pred_deg = euler[:, 0].cpu() + y_pred_deg = euler[:, 1].cpu() + r_pred_deg = euler[:, 2].cpu() + + # utils.draw_axis(frame, y_pred_deg, p_pred_deg, r_pred_deg, left+int(.5*(right-left)), top, size=100) + plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg, x_min + int(.5 * ( + x_max - x_min)), y_min + int(.5 * (y_max - y_min)), size=bbox_width) + + cv2.imshow("Demo", frame) + out.write(frame) + cv2.waitKey(5) + cap.release() + out.release() + + # Closes all the frames + cv2.destroyAllWindows() + + +def main(): + parser = argparse.ArgumentParser(description='Head Pose Estimation') + parser.add_argument('--model_name', type=str, default='RepVGG-A2') + parser.add_argument('--data_dir', type=str, default='../../Datasets/HPE') + parser.add_argument('--save-dir', type=str, default='./outputs') + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--train', action='store_true') + parser.add_argument('--test', action='store_true') + parser.add_argument('--inference', default=True, action='store_true') + + args = parser.parse_args() + if args.train: + train(args) + if args.test: + test(args) + if args.inference: + inference(args) + + +if __name__ == "__main__": + main() diff --git a/models/nets.py b/models/nets.py new file mode 100644 index 0000000..601653d --- /dev/null +++ b/models/nets.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +from utils import util + + +def conv_bn(inp, oup, kernel_size, stride, padding, groups=1): + result = nn.Sequential() + result.add_module('conv', nn.Conv2d(inp, oup, kernel_size, stride, padding, groups=groups, bias=False)) + result.add_module('bn', nn.BatchNorm2d(oup)) + return result + + +class RepVGGBlock(nn.Module): + def __init__(self, inp, oup, k, s=1, p=0, d=1, gr=1, padding_mode='zeros', deploy=False): + super(RepVGGBlock, self).__init__() + self.inp = inp + self.groups = gr + self.deploy = deploy + self.nonlinearity = nn.ReLU() + self.se = nn.Identity() + + assert k == 3 + assert p == 1 + + padding = p - k // 2 + + if deploy: + self.rbr_reparam = nn.Conv2d(inp, oup, k, s, p, d, gr, bias=True, padding_mode=padding_mode) + else: + self.rbr_identity = nn.BatchNorm2d(inp) if oup == inp and s == 1 else None + self.rbr_dense = conv_bn(inp, oup, k, s, p, groups=gr) + self.rbr_1x1 = conv_bn(inp, oup, 1, s, padding, groups=gr) + + def forward(self, x): + if hasattr(self, 'rbr_reparam'): + return self.nonlinearity(self.se(self.rbr_reparam(x))) + + if self.rbr_identity is None: + out = 0 + else: + out = self.rbr_identity(x) + + return self.nonlinearity(self.se(self.rbr_dense(x) + self.rbr_1x1(x) + out)) + + +class RepVGG(nn.Module): + def __init__(self, layers, width=None, num_cls=1000, gr_map=None, deploy=False): + super(RepVGG, self).__init__() + self.deploy = deploy + self.cur_layer_idx = 1 + self.gr_map = gr_map or dict() + + assert len(width) == 4 + assert 0 not in self.gr_map + + self.inp = min(64, int(64 * width[0])) + + self.stage0 = RepVGGBlock(3, self.inp, 3, 2, 1, deploy=self.deploy) + self.stage1 = self._make_stage(int(64 * width[0]), layers[0], stride=2) + self.stage2 = self._make_stage(int(128 * width[1]), layers[1], stride=2) + self.stage3 = self._make_stage(int(256 * width[2]), layers[2], stride=2) + self.stage4 = self._make_stage(int(512 * width[3]), layers[3], stride=2) + self.gap = nn.AdaptiveAvgPool2d(output_size=1) + self.linear = nn.Linear(int(512 * width[3]), num_cls) + + def _make_stage(self, oup, layer, stride): + strides = [stride] + [1] * (layer - 1) + layers = [] + for stride in strides: + cur_groups = self.gr_map.get(self.cur_layer_idx, 1) + layers.append(RepVGGBlock(self.inp, oup, 3, stride, p=1, gr=cur_groups, deploy=self.deploy, )) + self.inp = oup + self.cur_layer_idx += 1 + return nn.Sequential(*layers) + + def forward(self, x): + out = self.stage0(x) + out = self.stage1(out) + out = self.stage2(out) + out = self.stage3(out) + out = self.stage4(out) + out = self.gap(out) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def create_model(backbone_name, num_cls=1000): + optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] + g2_map = {l: 2 for l in optional_groupwise_layers} + g4_map = {l: 4 for l in optional_groupwise_layers} + default_group_map = None + net_configs = { + 'RepVGG-A0': ([2, 4, 14, 1], [0.75, 0.75, 0.75, 2.5], default_group_map), + 'RepVGG-A1': ([2, 4, 14, 1], [1, 1, 1, 2.5], default_group_map), + 'RepVGG-A2': ([2, 4, 14, 1], [1.5, 1.5, 1.5, 2.75], default_group_map), + 'RepVGG-B0': ([4, 6, 16, 1], [1, 1, 1, 2.5], default_group_map), + 'RepVGG-B1': ([4, 6, 16, 1], [2, 2, 2, 4], default_group_map), + 'RepVGG-B1g2': ([4, 6, 16, 1], [2, 2, 2, 4], g2_map), + 'RepVGG-B1g4': ([4, 6, 16, 1], [2, 2, 2, 4], g4_map), + 'RepVGG-B2': ([4, 6, 16, 1], [2.5, 2.5, 2.5, 5], default_group_map), + 'RepVGG-B2g2': ([4, 6, 16, 1], [2.5, 2.5, 2.5, 5], g2_map), + 'RepVGG-B2g4': ([4, 6, 16, 1], [2.5, 2.5, 2.5, 5], g4_map), + 'RepVGG-B3': ([4, 6, 16, 1], [3, 3, 3, 5], default_group_map), + 'RepVGG-B3g2': ([4, 6, 16, 1], [3, 3, 3, 5], g2_map), + 'RepVGG-B3g4': ([4, 6, 16, 1], [3, 3, 3, 5], g4_map), + } + + def model_constructor(deploy): + configs = net_configs.get(backbone_name) + if configs is None: + raise ValueError(f"Network {backbone_name} is not supported.") + layers, width, gr_map = configs[:3] + return RepVGG(layers, width, num_cls, gr_map, deploy=deploy) + + return model_constructor + + +class HPE(nn.Module): + def __init__(self, model_name, weight, deploy, pretrained=True): + super(HPE, self).__init__() + repvgg = create_model(model_name) + backbone = repvgg(deploy) + if pretrained: + checkpoint = torch.load(weight) + if 'state_dict' in checkpoint: + checkpoint = checkpoint['state_dict'] + ckpt = {k.replace('module.', ''): v for k, + v in checkpoint.items()} # strip the names + backbone.load_state_dict(ckpt) + + self.layer0 = backbone.stage0 + self.layer1 = backbone.stage1 + self.layer2 = backbone.stage2 + self.layer3 = backbone.stage3 + self.layer4 = backbone.stage4 + self.gap = nn.AdaptiveAvgPool2d(output_size=1) + + last_channel = 0 + for n, m in self.layer4.named_modules(): + if ('rbr_dense' in n or 'rbr_reparam' in n) and isinstance(m, nn.Conv2d): + last_channel = m.out_channels + + fea_dim = last_channel + + self.linear_reg = nn.Linear(fea_dim, 6) + + def forward(self, x): + x = self.layer0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.gap(x) + x = torch.flatten(x, 1) + x = self.linear_reg(x) + + return util.compute_rotation(x) diff --git a/outputs/weights/step.csv b/outputs/weights/step.csv new file mode 100644 index 0000000..b951718 --- /dev/null +++ b/outputs/weights/step.csv @@ -0,0 +1,101 @@ +epoch,Loss,Pitch,Yaw,Roll +001,0.085,6.628,4.732,4.674 +002,0.050,5.570,5.128,4.183 +003,0.042,5.877,4.997,4.244 +004,0.037,5.616,4.860,3.812 +005,0.034,5.476,4.771,4.200 +006,0.027,5.533,4.533,4.070 +007,0.024,5.290,4.494,3.708 +008,0.023,5.322,4.355,3.793 +009,0.022,5.228,4.404,3.722 +010,0.021,5.098,4.121,3.480 +011,0.018,5.210,4.278,3.661 +012,0.018,5.186,4.184,3.592 +013,0.017,5.276,4.076,3.689 +014,0.017,5.085,4.072,3.553 +015,0.016,5.130,4.017,3.524 +016,0.016,5.434,4.132,3.880 +017,0.016,5.309,4.076,3.722 +018,0.016,5.055,4.021,3.434 +019,0.015,5.132,3.937,3.636 +020,0.015,5.182,3.980,3.675 +021,0.015,5.202,3.976,3.605 +022,0.015,5.021,3.922,3.524 +023,0.015,5.176,4.018,3.565 +024,0.014,5.133,3.986,3.590 +025,0.014,5.237,3.940,3.723 +026,0.014,5.212,3.895,3.682 +027,0.014,5.387,3.916,3.841 +028,0.014,5.118,3.864,3.610 +029,0.014,5.197,3.937,3.592 +030,0.014,5.205,3.959,3.683 +031,0.014,5.143,3.860,3.634 +032,0.013,5.055,3.821,3.536 +033,0.013,5.117,3.830,3.580 +034,0.013,5.277,3.886,3.652 +035,0.013,5.099,3.850,3.637 +036,0.013,5.030,3.985,3.523 +037,0.013,5.209,3.901,3.692 +038,0.013,5.103,3.998,3.571 +039,0.013,5.182,3.949,3.746 +040,0.013,5.028,3.906,3.533 +041,0.013,5.184,3.972,3.767 +042,0.012,5.144,3.861,3.618 +043,0.012,5.164,3.878,3.705 +044,0.012,5.042,3.854,3.533 +045,0.012,5.109,3.853,3.587 +046,0.012,5.184,3.917,3.717 +047,0.012,5.135,3.833,3.691 +048,0.012,5.185,3.892,3.697 +049,0.012,5.163,3.960,3.681 +050,0.012,5.107,3.794,3.569 +051,0.012,5.082,3.852,3.578 +052,0.012,5.048,3.820,3.557 +053,0.012,4.965,3.761,3.511 +054,0.012,4.977,3.853,3.533 +055,0.012,4.992,3.805,3.502 +056,0.011,4.989,3.854,3.527 +057,0.011,4.986,3.845,3.572 +058,0.011,5.043,3.796,3.561 +059,0.011,4.987,3.748,3.466 +060,0.011,4.915,3.724,3.444 +061,0.011,4.974,3.869,3.507 +062,0.011,5.030,3.767,3.556 +063,0.011,5.021,3.832,3.573 +064,0.011,5.038,3.883,3.549 +065,0.011,5.126,3.878,3.691 +066,0.011,5.014,3.778,3.557 +067,0.011,5.088,3.871,3.641 +068,0.011,5.033,3.835,3.491 +069,0.011,4.966,3.803,3.530 +070,0.011,5.098,3.755,3.625 +071,0.011,5.087,3.823,3.562 +072,0.011,5.043,3.739,3.557 +073,0.011,5.138,3.912,3.673 +074,0.010,5.146,3.827,3.642 +075,0.010,5.037,3.795,3.567 +076,0.010,5.047,3.767,3.585 +077,0.010,4.999,3.807,3.559 +078,0.010,5.034,3.801,3.585 +079,0.010,5.008,3.737,3.507 +080,0.010,5.048,3.795,3.558 +081,0.010,4.974,3.880,3.528 +082,0.010,5.083,3.826,3.601 +083,0.010,5.197,3.936,3.738 +084,0.010,4.969,3.852,3.505 +085,0.010,4.976,3.783,3.489 +086,0.010,4.953,3.787,3.524 +087,0.010,5.139,3.810,3.682 +088,0.010,5.017,3.808,3.525 +089,0.010,4.979,3.896,3.539 +090,0.010,5.059,3.828,3.631 +091,0.010,5.166,3.854,3.761 +092,0.010,4.990,3.772,3.567 +093,0.010,5.001,3.798,3.601 +094,0.010,4.894,3.771,3.507 +095,0.010,5.076,3.859,3.608 +096,0.010,5.099,3.784,3.677 +097,0.010,4.964,3.844,3.531 +098,0.010,5.101,3.849,3.647 +099,0.010,5.015,3.921,3.635 +100,0.010,5.142,3.810,3.655 diff --git a/utils/datasets.py b/utils/datasets.py new file mode 100644 index 0000000..9690724 --- /dev/null +++ b/utils/datasets.py @@ -0,0 +1,111 @@ +import os +from pathlib import Path + +import numpy as np +import scipy.io as sio +import torch +from PIL import Image, ImageFilter +from torch.utils.data import Dataset + + +class Datasets(Dataset): + def __init__(self, data_dir, data_name, transform=None, train_mode=True): + self.data_dir = data_dir + self.transform = transform + self.train_mode = train_mode + file_path = Path(f'{self.data_dir}/{data_name}/files.txt') + if not file_path.exists(): + self.load_label(f'{self.data_dir}/{data_name}/') + self.samples = open(file_path).read().splitlines() + + def __getitem__(self, idx): + image = Image.open(f'{self.samples[idx]}.jpg') + image = image.convert('RGB') + label = sio.loadmat(f'{self.samples[idx]}.mat') + pt2d = label['pt2d'] + x_min = min(pt2d[0, :]) + y_min = min(pt2d[1, :]) + x_max = max(pt2d[0, :]) + y_max = max(pt2d[1, :]) + + if self.train_mode: + k = np.random.random_sample() * 0.2 + 0.2 + x_min -= 0.6 * k * abs(x_max - x_min) + y_min -= 2 * k * abs(y_max - y_min) + x_max += 0.6 * k * abs(x_max - x_min) + y_max += 0.6 * k * abs(y_max - y_min) + else: + k = 0.20 + x_min -= 2 * k * abs(x_max - x_min) + y_min -= 2 * k * abs(y_max - y_min) + x_max += 2 * k * abs(x_max - x_min) + y_max += 0.6 * k * abs(y_max - y_min) + + image = image.crop((int(x_min), int(y_min), int(x_max), int(y_max))) + pre_pose_params = label['Pose_Para'][0] + pose = pre_pose_params[:3] + pitch, yaw, roll = pose[0], pose[1], pose[2] + + rnd = np.random.random_sample() + if self.train_mode and rnd < 0.5: + yaw = -yaw + roll = -roll + image = image.transpose(Image.FLIP_LEFT_RIGHT) + rnd = np.random.random_sample() + + if self.train_mode and rnd < 0.05: + image = image.filter(ImageFilter.BLUR) + + if self.transform is not None: + image = self.transform(image) + + if self.train_mode: + return image, torch.FloatTensor(self.get_rotation(pitch, yaw, roll)) + else: + return image, torch.FloatTensor([pitch, yaw, roll]) + + def __len__(self): + return len(self.samples) + + @staticmethod + def get_rotation(x, y, z): + rotate_x = np.array([[1, 0, 0], + [0, np.cos(x), -np.sin(x)], + [0, np.sin(x), np.cos(x)]]) + # y + rotate_y = np.array([[np.cos(y), 0, np.sin(y)], + [0, 1, 0], + [-np.sin(y), 0, np.cos(y)]]) + # z + rotate_z = np.array([[np.cos(z), -np.sin(z), 0], + [np.sin(z), np.cos(z), 0], + [0, 0, 1]]) + + rotation = rotate_z.dot(rotate_y.dot(rotate_x)) + return rotation + + @staticmethod + def load_label(data_dir): + f_counter, rej_counter = 0, 0 + file = open(f'{data_dir}files.txt', 'w') + + for root, dirs, files in os.walk(data_dir): + for f in files: + if f[-4:].lower().endswith('.jpg'): + mat_path = os.path.join(root, f.replace('.jpg', '.mat')) + mat = sio.loadmat(mat_path) + pre_pose_ = mat['Pose_Para'][0] + pose = pre_pose_[:3] + pitch = pose[0] * 180 / np.pi + yaw = pose[1] * 180 / np.pi + roll = pose[2] * 180 / np.pi + if all(abs(angle) <= 99 for angle in (pitch, yaw, roll)): + if f_counter > 0: + file.write('\n') + file.write(os.path.join(root, f[:-4])) + f_counter += 1 + else: + rej_counter += 1 + +# data_dir = '../../../Datasets/HPE' +# Datasets(data_dir, transform=None)[0] diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..a63cdb8 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,190 @@ +from pathlib import Path + +import torch +import torchvision.transforms as T + + +class GeodesicLoss(torch.nn.Module): + def __init__(self, eps=1e-7): + super().__init__() + self.eps = eps + + def forward(self, m1, m2): + m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 + + cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 + theta = torch.acos(torch.clamp(cos, -1 + self.eps, 1 - self.eps)) + + return torch.mean(theta) + + +def compute_euler(rotation_matrices): + batch = rotation_matrices.shape[0] + R = rotation_matrices + sy = torch.sqrt(R[:, 0, 0] * R[:, 0, 0] + R[:, 1, 0] * R[:, 1, 0]) + singular = sy < 1e-6 + singular = singular.float() + + x = torch.atan2(R[:, 2, 1], R[:, 2, 2]) + y = torch.atan2(-R[:, 2, 0], sy) + z = torch.atan2(R[:, 1, 0], R[:, 0, 0]) + + xs = torch.atan2(-R[:, 1, 2], R[:, 1, 1]) + ys = torch.atan2(-R[:, 2, 0], sy) + zs = R[:, 1, 0] * 0 + + gpu = rotation_matrices.get_device() + if gpu < 0: + out_euler = torch.autograd.Variable(torch.zeros(batch, 3)).to(torch.device('cpu')) + else: + out_euler = torch.autograd.Variable(torch.zeros(batch, 3)).to(torch.device('cuda:%d' % gpu)) + out_euler[:, 0] = x * (1 - singular) + xs * singular + out_euler[:, 1] = y * (1 - singular) + ys * singular + out_euler[:, 2] = z * (1 - singular) + zs * singular + + return out_euler + + +def normalize_vector(v): + batch = v.shape[0] + v_mag = torch.sqrt(v.pow(2).sum(1)) # batch + gpu = v_mag.get_device() + if gpu < 0: + eps = torch.autograd.Variable(torch.FloatTensor([1e-8])).to(torch.device('cpu')) + else: + eps = torch.autograd.Variable(torch.FloatTensor([1e-8])).to(torch.device('cuda:%d' % gpu)) + v_mag = torch.max(v_mag, eps) + v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1]) + v = v / v_mag + return v + + +# u, v batch*n +def cross_product(u, v): + batch = u.shape[0] + # print (u.shape) + # print (v.shape) + i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1] + j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2] + k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0] + + out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1) # batch*3 + + return out + + +def compute_rotation(poses): + x_raw = poses[:, 0:3] # batch*3 + y_raw = poses[:, 3:6] # batch*3 + + x = normalize_vector(x_raw) # batch*3 + z = cross_product(x, y_raw) # batch*3 + z = normalize_vector(z) # batch*3 + y = cross_product(z, x) # batch*3 + + x = x.view(-1, 3, 1) + y = y.view(-1, 3, 1) + z = z.view(-1, 3, 1) + matrix = torch.cat((x, y, z), 2) # batch*3*3 + return matrix + + +def download_weights(model_name, save_dir): + import os + import gdown + net_name = { + 'RepVGG-A0': '13Gn8rq1PztoMEgK7rCOPMUYHjGzk-w11', + 'RepVGG-A1': '19lX6lNKSwiO5STCvvu2xRTKcPsSfWAO1', + 'RepVGG-A2': '1PvtYTOX4gd-1VHX8LoT7s6KIyfTKOf8G', + 'RepVGG-B0': '18g7YziprUky7cX6L6vMJ_874PP8tbtKx', + 'RepVGG-B1': '1VlCfXXiaJjNjzQBy3q7C3H2JcxoL0fms', + 'RepVGG-B1g2': '1PL-m9n3g0CEPrSpf3KwWEOf9_ZG-Ux1Z', + 'RepVGG-B1g4': '1WXxhyRDTgUjgkofRV1bLnwzTsFWRwZ0k', + 'RepVGG-B2': '1cFgWJkmf9U1L1UmJsA8UT__kyd3xuY_y', + 'RepVGG-B2g4': '1LZ61o5XH6u1n3_tXIgKII7XqKoqqracI', + 'RepVGG-B3': '1wBpq5317iPKk3-qblBHnx35bY_WumAlU', + 'RepVGG-B3g4': '1s7PxIP-oYB1a94_qzHyzfXAbbI24GYQ8', + } + weight_id = net_name.get(model_name) + url = f'https://drive.google.com/uc?id={weight_id}' + output = os.path.join(save_dir, 'weights', model_name + '.pth') + gdown.download(url, output, quiet=False) + + +def load_model(args, for_training=True): + from models.nets import HPE + + file_path = Path(f'{args.save_dir}/weights/{args.model_name}.pth') + if not file_path.exists(): + print('Downloading ImageNet weights...') + download_weights(args.model_name, args.save_dir) + + model = HPE(args.model_name, file_path, False, for_training).cuda() + if not for_training: + saved_dict = torch.load(f'{args.save_dir}/weights/best.pt') + if 'model_state_dict' in saved_dict: + model.load_state_dict(saved_dict['model_state_dict']) + else: + model.load_state_dict(saved_dict) + return model + + +def get_transforms(for_training=True): + if for_training: + return T.Compose([ + T.RandomResizedCrop(size=224, scale=(0.8, 1)), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + else: + return T.Compose([ + T.Resize(256), + T.CenterCrop(224), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + +def plot_pose_cube(img, yaw, pitch, roll, tdx=None, tdy=None, size=150.): + import cv2 + import numpy as np + from math import sin, cos + p = pitch * np.pi / 180 + y = -(yaw * np.pi / 180) + r = roll * np.pi / 180 + if tdx != None and tdy != None: + face_x = tdx - 0.50 * size + face_y = tdy - 0.50 * size + + else: + height, width = img.shape[:2] + face_x = width / 2 - 0.5 * size + face_y = height / 2 - 0.5 * size + + x1 = size * (cos(y) * cos(r)) + face_x + y1 = size * (cos(p) * sin(r) + cos(r) * sin(p) * sin(y)) + face_y + x2 = size * (-cos(y) * sin(r)) + face_x + y2 = size * (cos(p) * cos(r) - sin(p) * sin(y) * sin(r)) + face_y + x3 = size * (sin(y)) + face_x + y3 = size * (-cos(y) * sin(p)) + face_y + + # Draw base in red + cv2.line(img, (int(face_x), int(face_y)), (int(x1), int(y1)), (0, 0, 255), 3) + cv2.line(img, (int(face_x), int(face_y)), (int(x2), int(y2)), (0, 0, 255), 3) + cv2.line(img, (int(x2), int(y2)), (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), (0, 0, 255), 3) + cv2.line(img, (int(x1), int(y1)), (int(x1 + x2 - face_x), int(y1 + y2 - face_y)), (0, 0, 255), 3) + # Draw pillars in blue + cv2.line(img, (int(face_x), int(face_y)), (int(x3), int(y3)), (255, 0, 0), 2) + cv2.line(img, (int(x1), int(y1)), (int(x1 + x3 - face_x), int(y1 + y3 - face_y)), (255, 0, 0), 2) + cv2.line(img, (int(x2), int(y2)), (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), (255, 0, 0), 2) + cv2.line(img, (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), + (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (255, 0, 0), 2) + # Draw top in green + cv2.line(img, (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), + (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2) + cv2.line(img, (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), + (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2) + cv2.line(img, (int(x3), int(y3)), (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), (0, 255, 0), 2) + cv2.line(img, (int(x3), int(y3)), (int(x3 + x2 - face_x), int(y3 + y2 - face_y)), (0, 255, 0), 2) + + return img