-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add working version of TRUST albedo model
- Loading branch information
1 parent
709a4a2
commit 9e68382
Showing
8 changed files
with
180 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .trust import TRUST |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | ||
# holder of all proprietary rights on this computer program. | ||
# Using this computer program means that you agree to the terms | ||
# in the LICENSE file included with this software distribution. | ||
# Any use not explicitly granted by the LICENSE is prohibited. | ||
# | ||
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung | ||
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | ||
# for Intelligent Systems. All rights reserved. | ||
# | ||
# For comments or questions, please email us at [email protected] | ||
# For commercial licensing contact, please contact [email protected] | ||
|
||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
from torchvision.utils import save_image | ||
|
||
from ..defaults import DEVICE | ||
from ..recon.flame.deca.encoders import Resnet50Encoder | ||
from ..recon.flame.decoders import FlameTex | ||
|
||
|
||
class TRUST(nn.Module): | ||
def __init__(self, device=DEVICE): | ||
# avoid circular import | ||
from ..data import get_external_data_config | ||
super().__init__() | ||
self.device = device | ||
self._cfg = get_external_data_config() | ||
self._create_submodels() | ||
self.to(device).eval() | ||
|
||
def _create_submodels(self): | ||
|
||
self.param_dict = { | ||
'n_tex': 54, | ||
'n_light': 27, | ||
'n_scenelight': 3, | ||
'n_facelight': 27 | ||
} | ||
|
||
self.E_albedo = Resnet50Encoder(outsize=self.param_dict['n_tex'], version='v2').to(self.device) | ||
checkpoint = torch.load(self._cfg['trust_albedo_encoder_path'], map_location=self.device) | ||
self.E_albedo.load_state_dict(checkpoint["E_albedo"]) | ||
|
||
self.E_scene_light = Resnet50Encoder(outsize=self.param_dict['n_scenelight']).to(self.device) | ||
checkpoint = torch.load(self._cfg['trust_scene_light_encoder_path'], map_location=self.device) | ||
self.E_scene_light.load_state_dict(checkpoint["E_scene_light"]) | ||
|
||
self.E_face_light = Resnet50Encoder(outsize=self.param_dict['n_facelight']).to(self.device) | ||
checkpoint = torch.load(self._cfg['trust_face_light_encoder_path'], map_location=self.device) | ||
self.E_face_light.load_state_dict(checkpoint["E_face_light"]) | ||
|
||
# decoding | ||
self.D_flame_tex = FlameTex(model_path=self._cfg['trust_albedo_decoder_path'], | ||
n_tex=54).to(self.device) # texture layer | ||
|
||
def _fuse_light(self, E_scene_light_pred, E_face_light_pred): | ||
|
||
normalized_sh_params = F.normalize(E_face_light_pred, p=1, dim=1) | ||
lightcode = E_scene_light_pred.unsqueeze(1).expand(-1, 9, -1) * normalized_sh_params | ||
|
||
return lightcode, E_scene_light_pred, normalized_sh_params | ||
|
||
def _encode(self, imgs, scene_imgs): | ||
''' | ||
:param images: | ||
:param scene_images: | ||
:param face_lighting: | ||
:param scene_lighting: | ||
:return: | ||
''' | ||
|
||
B, C, H, W = imgs.size() | ||
E_scene_light_pred = self.E_scene_light(scene_imgs) # B x 3 | ||
#E_face_light_pred = self.E_face_light(imgs).reshape(B, 9, 3) | ||
E_scene_light_pred = E_scene_light_pred[..., None, None].repeat(1, 1, H, W) | ||
|
||
imgs_cond = torch.cat((E_scene_light_pred, imgs), dim=1) | ||
tex_code = self.E_albedo(imgs_cond) | ||
|
||
return tex_code | ||
|
||
def _decode(self, tex_code): | ||
albedo = self.D_flame_tex(tex_code) | ||
return albedo | ||
|
||
def forward(self, imgs, scene_imgs): | ||
|
||
if imgs.dtype == torch.uint8: | ||
imgs = imgs.float() | ||
|
||
if imgs.max() >= 1.0: | ||
imgs = imgs.div(255.0) | ||
|
||
if scene_imgs.dtype == torch.uint8: | ||
scene_imgs = scene_imgs.float() | ||
|
||
if scene_imgs.max() >= 1.0: | ||
scene_imgs = scene_imgs.div(255.0) | ||
|
||
tex_code = self._encode(imgs, scene_imgs) | ||
albedo = self._decode(tex_code) | ||
|
||
|
||
out = {'albedo': albedo} | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
import pytest | ||
from pathlib import Path | ||
from torchvision.utils import save_image | ||
from medusa.albedo import TRUST | ||
from medusa.crop import RandomSquareCropModel, BboxCropModel | ||
from medusa.data import get_example_image | ||
|
||
|
||
@pytest.mark.parametrize("n_faces", [1, 2, 3, 4]) | ||
def test_trust(n_faces): | ||
|
||
face_crop_model = BboxCropModel(output_size=(224, 224), scale=1.6, scale_orig=1) | ||
scene_crop_model = RandomSquareCropModel(output_size=(224, 224)) | ||
|
||
img = get_example_image(n_faces=n_faces) | ||
face_crop = face_crop_model(img)['imgs_crop'] | ||
scene_crop = scene_crop_model(img)['imgs_crop'] | ||
scene_crop = scene_crop.repeat(face_crop.shape[0], 1, 1, 1) | ||
|
||
model = TRUST() | ||
with torch.inference_mode(): | ||
out = model(face_crop, scene_crop) | ||
|
||
dir_out = Path(__file__).parent / "test_viz/albedo" | ||
save_image(out['albedo'].float(), dir_out / f'albedo_n-{n_faces}.png', normalize=True) | ||
save_image(face_crop.float(), dir_out / f'crop-face_n-{n_faces}.png', normalize=True) | ||
save_image(scene_crop.float(), dir_out / f'crop-scene_n-{n_faces}.png', normalize=True) |
Empty file.