Commit
Add working version of TRUST albedo model
lukassnoek committed Aug 23, 2024
1 parent 709a4a2 commit 9e68382
Showing 8 changed files with 180 additions and 38 deletions.
1 change: 1 addition & 0 deletions medusa/albedo/__init__.py
@@ -0,0 +1 @@
from .trust import TRUST
111 changes: 111 additions & 0 deletions medusa/albedo/trust.py
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms
# in the LICENSE file included with this software distribution.
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright © 2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at [email protected]
# For commercial licensing, please contact [email protected]

import torch
from torch import nn
import torch.nn.functional as F

from ..defaults import DEVICE
from ..recon.flame.deca.encoders import Resnet50Encoder
from ..recon.flame.decoders import FlameTex


class TRUST(nn.Module):
    """TRUST albedo model: reconstructs a FLAME texture (albedo) map from a face
    crop, conditioning the prediction on scene lighting estimated from a scene crop."""

    def __init__(self, device=DEVICE):
        # avoid circular import
        from ..data import get_external_data_config
        super().__init__()
        self.device = device
        self._cfg = get_external_data_config()
        self._create_submodels()
        self.to(device).eval()

    def _create_submodels(self):

        self.param_dict = {
            'n_tex': 54,        # number of FLAME texture (albedo) parameters
            'n_light': 27,      # 9 SH coefficients x 3 color channels
            'n_scenelight': 3,  # global (RGB) scene light code
            'n_facelight': 27   # face-specific SH lighting (9 x 3)
        }

        # Encoders for the albedo (texture) code, the scene light, and the face light
        self.E_albedo = Resnet50Encoder(outsize=self.param_dict['n_tex'], version='v2').to(self.device)
        checkpoint = torch.load(self._cfg['trust_albedo_encoder_path'], map_location=self.device)
        self.E_albedo.load_state_dict(checkpoint["E_albedo"])

        self.E_scene_light = Resnet50Encoder(outsize=self.param_dict['n_scenelight']).to(self.device)
        checkpoint = torch.load(self._cfg['trust_scene_light_encoder_path'], map_location=self.device)
        self.E_scene_light.load_state_dict(checkpoint["E_scene_light"])

        self.E_face_light = Resnet50Encoder(outsize=self.param_dict['n_facelight']).to(self.device)
        checkpoint = torch.load(self._cfg['trust_face_light_encoder_path'], map_location=self.device)
        self.E_face_light.load_state_dict(checkpoint["E_face_light"])

        # decoding
        self.D_flame_tex = FlameTex(model_path=self._cfg['trust_albedo_decoder_path'],
                                    n_tex=54).to(self.device)  # texture layer

    def _fuse_light(self, E_scene_light_pred, E_face_light_pred):

        # Scale the (L1-normalized) face SH lighting parameters by the global
        # scene light code, broadcast over the 9 SH coefficients
        normalized_sh_params = F.normalize(E_face_light_pred, p=1, dim=1)
        lightcode = E_scene_light_pred.unsqueeze(1).expand(-1, 9, -1) * normalized_sh_params

        return lightcode, E_scene_light_pred, normalized_sh_params

    def _encode(self, imgs, scene_imgs):
        """Encode face crops into a FLAME texture (albedo) code, conditioned on
        the scene light predicted from the scene crops.

        :param imgs: B x 3 x H x W tensor of face crops
        :param scene_imgs: B x 3 x H x W tensor of scene crops
        :return: B x n_tex tensor of texture (albedo) parameters
        """

        B, C, H, W = imgs.size()
        E_scene_light_pred = self.E_scene_light(scene_imgs)  # B x 3
        #E_face_light_pred = self.E_face_light(imgs).reshape(B, 9, 3)

        # Broadcast the scene light code over the spatial dimensions and
        # concatenate it with the face crop as three extra input channels
        E_scene_light_pred = E_scene_light_pred[..., None, None].repeat(1, 1, H, W)
        imgs_cond = torch.cat((E_scene_light_pred, imgs), dim=1)  # B x 6 x H x W
        tex_code = self.E_albedo(imgs_cond)

        return tex_code

    def _decode(self, tex_code):
        albedo = self.D_flame_tex(tex_code)
        return albedo

    def forward(self, imgs, scene_imgs):

        # Convert uint8 inputs to float in [0, 1]; float inputs with values
        # above 1 are assumed to be in the [0, 255] range
        if imgs.dtype == torch.uint8:
            imgs = imgs.float().div(255.0)
        elif imgs.max() > 1.0:
            imgs = imgs.div(255.0)

        if scene_imgs.dtype == torch.uint8:
            scene_imgs = scene_imgs.float().div(255.0)
        elif scene_imgs.max() > 1.0:
            scene_imgs = scene_imgs.div(255.0)

        tex_code = self._encode(imgs, scene_imgs)
        albedo = self._decode(tex_code)

        out = {'albedo': albedo}

        return out
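
A minimal usage sketch of the new module (it mirrors the test added in tests/test_trust.py below; the only assumption is that the crop models return a dict whose 'imgs_crop' entry holds B x 3 x 224 x 224 tensors, as they do in that test):

import torch

from medusa.albedo import TRUST
from medusa.crop import BboxCropModel, RandomSquareCropModel
from medusa.data import get_example_image

# Crop a face and a square scene region from the same example image
img = get_example_image(n_faces=1)
face_crop = BboxCropModel(output_size=(224, 224), scale=1.6, scale_orig=1)(img)['imgs_crop']
scene_crop = RandomSquareCropModel(output_size=(224, 224))(img)['imgs_crop']

# Predict one UV albedo map per face crop, conditioned on the scene crop
model = TRUST()
with torch.inference_mode():
    albedo = model(face_crop, scene_crop)['albedo']
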
3 changes: 2 additions & 1 deletion medusa/crop/square_crop.py
@@ -39,7 +39,7 @@ def forward(self, imgs):

        out = []
        for i in range(imgs.shape[0]):
            img = imgs[i, ...]
            img = imgs[i, ...]  # 3 x h x w
            scene_h, scene_w = img.shape[1:]

            if scene_w > scene_h:
@@ -53,6 +53,7 @@
            else:
                scene = img.clone()

            #scene = img[:, :224, :224]
            scene = self.resizer(scene)
            out.append(scene)

2 changes: 1 addition & 1 deletion medusa/data/default_config.yaml
@@ -1,4 +1,4 @@
insightface_path: ~/.medusa_ext_data/antelopev2
insightface_path: ~/.medusa_ext_data/buffalo_l
deca_path: ~/.medusa_ext_data/deca_model.tar
emoca_path: ~/.medusa_ext_data/emoca.ckpt
flame_masks_path: ~/.medusa_ext_data/FLAME/FLAME_masks.pkl
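
Note: the TRUST module above reads four additional keys from this config (trust_albedo_encoder_path, trust_scene_light_encoder_path, trust_face_light_encoder_path and trust_albedo_decoder_path). The remainder of this diff is truncated; purely as an illustration of the expected shape of those entries, hypothetical values following the existing ~/.medusa_ext_data convention might look like:

# hypothetical example entries; not the actual committed paths
trust_albedo_encoder_path: ~/.medusa_ext_data/trust_albedo_encoder.pt
trust_scene_light_encoder_path: ~/.medusa_ext_data/trust_scene_light_encoder.pt
trust_face_light_encoder_path: ~/.medusa_ext_data/trust_face_light_encoder.pt
trust_albedo_decoder_path: ~/.medusa_ext_data/trust_albedo_decoder.pt
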
Binary file added medusa/data/flame/uv_face_cheek_mask.png
73 changes: 37 additions & 36 deletions tests/test_crop.py
@@ -4,7 +4,8 @@
import torch
import pytest
from conftest import _is_device_compatible

from torchvision.utils import save_image

from medusa.containers.results import BatchResults
from medusa.crop import AlignCropModel, BboxCropModel, RandomSquareCropModel
from medusa.data import get_example_image, get_example_video
@@ -19,12 +20,13 @@ def test_random_square_crop(n_faces, device):

    imgs = get_example_image(n_faces, device=device)
    model = RandomSquareCropModel(output_size=(224, 224), device=device)
    imgs_crop = model(imgs)
    assert(imgs_crop.shape == (len(n_faces), 3, 224, 224))
    out = model(imgs)
    assert(out['imgs_crop'].shape == (len(n_faces), 3, 224, 224))


@pytest.mark.parametrize("Model", [AlignCropModel, BboxCropModel])
@pytest.mark.parametrize("lm_name", ["2d106det", "1k3d68"])
@pytest.mark.parametrize("n_faces", [0, 1, 2, 3, 4, [0, 1], [0, 1, 2], [0, 1, 2, 3, 4]])
@pytest.mark.parametrize("n_faces", [2, 3, 4, [0, 1], [0, 1, 2], [0, 1, 2, 3, 4]])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_crop_model(Model, lm_name, n_faces, device):
    """Generic tests for crop models."""
@@ -65,44 +67,43 @@ def test_crop_model(Model, lm_name, n_faces, device):

    out_crop.visualize(str(f_out) + '_uncropped.jpg', imgs, template=template)
    if out_crop.imgs_crop is not None:
        from torchvision.utils import save_image
        save_image(out_crop.imgs_crop.float(), str(f_out) + '_cropped.jpg', nrow=1, normalize=True)

# @torch.inference_mode()
# @pytest.mark.parametrize("Model", [AlignCropModel, BboxCropModel])
# @pytest.mark.parametrize("n_faces", [0, 1, 2, 3, 4])
# def test_crop_model_vid(Model, n_faces):
#     """Test of crop model applied to videos and the visualization thereof."""

#     video_test = get_example_video(n_faces)
@torch.inference_mode()
@pytest.mark.parametrize("Model", [AlignCropModel, BboxCropModel])
@pytest.mark.parametrize("n_faces", [0, 1, 2, 3, 4])
def test_crop_model_vid(Model, n_faces):
    """Test of crop model applied to videos and the visualization thereof."""

#     if Model == BboxCropModel:
#         crop_size = (224, 224)
#         model = Model("2d106det", crop_size)
#     else:
#         crop_size = (448, 448)
#         model = Model(crop_size)
    video_test = get_example_video(n_faces)

#     results = model.crop_faces_video(video_test, save_imgs=True)
    if Model == BboxCropModel:
        crop_size = (224, 224)
        model = Model("2d106det", crop_size)
    else:
        crop_size = (448, 448)
        model = Model(crop_size)

#     if getattr(results, "lms", None) is None:
#         return
    results = model.crop_faces_video(video_test, save_imgs=True)

#     results.sort_faces(attr="lms")
    if getattr(results, "lms", None) is None:
        return

#     if 'GITHUB_ACTIONS' in os.environ:
#         # Too slow for Github Actions
#         return
    results.sort_faces(attr="lms")

#     f_out = Path(__file__).parent / f"test_viz/crop/{str(model)}_{video_test.stem}.mp4"
#     template = getattr(model, "template", None)
#     results.visualize(
#         f_out,
#         results.imgs,
#         template=template,
#         video=True,
#         crop_size=crop_size,
#         show_cropped=True,
#     )
    if 'GITHUB_ACTIONS' in os.environ:
        # Too slow for Github Actions
        return

#     torch.cuda.empty_cache()
    f_out = Path(__file__).parent / f"test_viz/crop/{str(model)}_{video_test.stem}.mp4"
    template = getattr(model, "template", None)
    results.visualize(
        f_out,
        results.imgs,
        template=template,
        video=True,
        crop_size=crop_size,
        show_cropped=True,
    )

    torch.cuda.empty_cache()
28 changes: 28 additions & 0 deletions tests/test_trust.py
@@ -0,0 +1,28 @@
import torch
import pytest
from pathlib import Path
from torchvision.utils import save_image
from medusa.albedo import TRUST
from medusa.crop import RandomSquareCropModel, BboxCropModel
from medusa.data import get_example_image


@pytest.mark.parametrize("n_faces", [1, 2, 3, 4])
def test_trust(n_faces):

    face_crop_model = BboxCropModel(output_size=(224, 224), scale=1.6, scale_orig=1)
    scene_crop_model = RandomSquareCropModel(output_size=(224, 224))

    img = get_example_image(n_faces=n_faces)
    face_crop = face_crop_model(img)['imgs_crop']
    scene_crop = scene_crop_model(img)['imgs_crop']
    scene_crop = scene_crop.repeat(face_crop.shape[0], 1, 1, 1)

    model = TRUST()
    with torch.inference_mode():
        out = model(face_crop, scene_crop)

    dir_out = Path(__file__).parent / "test_viz/albedo"
    save_image(out['albedo'].float(), dir_out / f'albedo_n-{n_faces}.png', normalize=True)
    save_image(face_crop.float(), dir_out / f'crop-face_n-{n_faces}.png', normalize=True)
    save_image(scene_crop.float(), dir_out / f'crop-scene_n-{n_faces}.png', normalize=True)
Empty file added tests/test_viz/albedo/.gitkeep
Empty file.
