Skip to content

Commit

Permalink
refactor: 📦 renaming variables - start
Browse files Browse the repository at this point in the history
  • Loading branch information
melMass committed Jul 9, 2024
1 parent b316df6 commit e49535d
Showing 1 changed file with 66 additions and 67 deletions.
133 changes: 66 additions & 67 deletions liveportrait/live_portrait_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,20 @@ def _get_source_frame(self, source_np, idx, total_frames, method):
min(int(ratio * (source_np.shape[0] - 1)), source_np.shape[0] - 1)
]

def execute(
self, source_np, driving_images_np, mismatch_method="repeat", reference_frame=0
):
inference_cfg = self.live_portrait_wrapper.cfg
def execute(self, source_np, driving_images_np, mismatch_method="repeat"):
cfg = self.live_portrait_wrapper.cfg

I_p_lst = []
I_p_paste_lst = []
warped_imgs = []
cropped_imgs = []
driving_lmk_lst = []
R_d_0, x_d_0_info = None, None

rot_0, kp_0_info = None, None

total_frames = driving_images_np.shape[0]

pbar = comfy.utils.ProgressBar(total_frames)

if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
if cfg.flag_eye_retargeting or cfg.flag_lip_retargeting:
driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_images_np)

for i in range(total_frames):
Expand All @@ -89,20 +88,20 @@ def execute(
crop_info["img_crop_256x256"],
)

if inference_cfg.flag_do_crop:
if cfg.flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
else:
I_s = self.live_portrait_wrapper.prepare_source(source_frame_rgb)

x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = x_s_info["kp"]
kp_src_info = self.live_portrait_wrapper.get_kp_info(I_s)
x_c_s = kp_src_info["kp"]
R_s = get_rotation_matrix(
x_s_info["pitch"], x_s_info["yaw"], x_s_info["roll"]
kp_src_info["pitch"], kp_src_info["yaw"], kp_src_info["roll"]
)
f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
kp_src = self.live_portrait_wrapper.transform_keypoint(kp_src_info)

if inference_cfg.flag_lip_zero:
if cfg.flag_lip_zero:
c_d_lip_before_animation = [0.0]
combined_lip_ratio_tensor_before_animation = (
self.live_portrait_wrapper.calc_combined_lip_ratio(
Expand All @@ -112,13 +111,13 @@ def execute(
# TODO: expose lip_zero_threshold
if (
combined_lip_ratio_tensor_before_animation[0][0]
< inference_cfg.lip_zero_threshold
< cfg.lip_zero_threshold
):
inference_cfg.flag_lip_zero = False
cfg.flag_lip_zero = False
else:
lip_delta_before_animation = (
self.live_portrait_wrapper.retarget_lip(
x_s, combined_lip_ratio_tensor_before_animation
kp_src, combined_lip_ratio_tensor_before_animation
)
)

Expand All @@ -128,7 +127,7 @@ def execute(
[driving_frame_256]
)[0]

if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
if cfg.flag_eye_retargeting or cfg.flag_lip_retargeting:
# driving_lmk_lst = self.cropper.get_retargeting_lmk_info([driving_frame])
input_eye_ratio_lst, input_lip_ratio_lst = (
self.live_portrait_wrapper.calc_retargeting_ratio(
Expand All @@ -142,89 +141,89 @@ def execute(
)

if i == 0:
R_d_0 = R_d
x_d_0_info = x_d_info

if inference_cfg.flag_relative:
R_new = (R_d @ R_d_0.permute(0, 2, 1)) @ R_s
delta_new = x_s_info["exp"] + (x_d_info["exp"] - x_d_0_info["exp"])
scale_new = x_s_info["scale"] * (
x_d_info["scale"] / x_d_0_info["scale"]
rot_0 = R_d
kp_0_info = x_d_info

if cfg.flag_relative:
R_new = (R_d @ rot_0.permute(0, 2, 1)) @ R_s
delta_new = kp_src_info["exp"] + (x_d_info["exp"] - kp_0_info["exp"])
scale_new = kp_src_info["scale"] * (
x_d_info["scale"] / kp_0_info["scale"]
)
t_new = x_s_info["t"] + (x_d_info["t"] - x_d_0_info["t"])
t_new = kp_src_info["t"] + (x_d_info["t"] - kp_0_info["t"])
else:
R_new = R_d
delta_new = x_d_info["exp"]
scale_new = x_s_info["scale"]
scale_new = kp_src_info["scale"]
t_new = x_d_info["t"]

t_new[..., 2].fill_(0) # zero tz
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
if (
not inference_cfg.flag_stitching
and not inference_cfg.flag_eye_retargeting
and not inference_cfg.flag_lip_retargeting
not cfg.flag_stitching
and not cfg.flag_eye_retargeting
and not cfg.flag_lip_retargeting
):
# without stitching or retargeting
if inference_cfg.flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
if cfg.flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(
-1, kp_src.shape[1], 3
)
else:
pass
elif (
inference_cfg.flag_stitching
and not inference_cfg.flag_eye_retargeting
and not inference_cfg.flag_lip_retargeting
cfg.flag_stitching
and not cfg.flag_eye_retargeting
and not cfg.flag_lip_retargeting
):
# with stitching and without retargeting
if inference_cfg.flag_lip_zero:
if cfg.flag_lip_zero:
x_d_i_new = self.live_portrait_wrapper.stitching(
x_s, x_d_i_new
) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
kp_src, x_d_i_new
) + lip_delta_before_animation.reshape(-1, kp_src.shape[1], 3)
else:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
x_d_i_new = self.live_portrait_wrapper.stitching(kp_src, x_d_i_new)
else:
eyes_delta, lip_delta = None, None
if inference_cfg.flag_eye_retargeting:
if cfg.flag_eye_retargeting:
c_d_eyes_i = input_eye_ratio_lst[i]
combined_eye_ratio_tensor = (
self.live_portrait_wrapper.calc_combined_eye_ratio(
c_d_eyes_i, source_lmk
)
)
combined_eye_ratio_tensor = (
combined_eye_ratio_tensor
* inference_cfg.eyes_retargeting_multiplier
combined_eye_ratio_tensor * cfg.eyes_retargeting_multiplier
)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_wrapper.retarget_eye(
x_s, combined_eye_ratio_tensor
kp_src, combined_eye_ratio_tensor
)
if inference_cfg.flag_lip_retargeting:
if cfg.flag_lip_retargeting:
c_d_lip_i = input_lip_ratio_lst[i]
combined_lip_ratio_tensor = (
self.live_portrait_wrapper.calc_combined_lip_ratio(
c_d_lip_i, source_lmk
)
)
combined_lip_ratio_tensor = (
combined_lip_ratio_tensor
* inference_cfg.lip_retargeting_multiplier
combined_lip_ratio_tensor * cfg.lip_retargeting_multiplier
)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_wrapper.retarget_lip(
x_s, combined_lip_ratio_tensor
kp_src, combined_lip_ratio_tensor
)

if inference_cfg.flag_relative: # use x_s
if cfg.flag_relative: # use x_s
x_d_i_new = (
x_s
kp_src
+ (
eyes_delta.reshape(-1, x_s.shape[1], 3)
eyes_delta.reshape(-1, kp_src.shape[1], 3)
if eyes_delta is not None
else 0
)
+ (
lip_delta.reshape(-1, x_s.shape[1], 3)
lip_delta.reshape(-1, kp_src.shape[1], 3)
if lip_delta is not None
else 0
)
Expand All @@ -233,28 +232,28 @@ def execute(
x_d_i_new = (
x_d_i_new
+ (
eyes_delta.reshape(-1, x_s.shape[1], 3)
eyes_delta.reshape(-1, kp_src.shape[1], 3)
if eyes_delta is not None
else 0
)
+ (
lip_delta.reshape(-1, x_s.shape[1], 3)
lip_delta.reshape(-1, kp_src.shape[1], 3)
if lip_delta is not None
else 0
)
)

if inference_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
if cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(kp_src, x_d_i_new)

out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s, kp_src, x_d_i_new)

if inference_cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
if cfg.flag_stitching:
x_d_i_new = self.live_portrait_wrapper.stitching(kp_src, x_d_i_new)

out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
out = self.live_portrait_wrapper.warp_decode(f_s, kp_src, x_d_i_new)
I_p_i = self.live_portrait_wrapper.parse_output(out["out"])[0]
I_p_lst.append(I_p_i)
warped_imgs.append(I_p_i)

# Transform and blend
I_p_i_to_ori = _transform_img(
Expand All @@ -263,14 +262,14 @@ def execute(
dsize=(source_frame_rgb.shape[1], source_frame_rgb.shape[0]),
)

if inference_cfg.flag_pasteback:
if inference_cfg.mask_crop is None:
inference_cfg.mask_crop = cv2.imread(
if cfg.flag_pasteback:
if cfg.mask_crop is None:
cfg.mask_crop = cv2.imread(
make_abs_path("./utils/resources/mask_template.png"),
cv2.IMREAD_COLOR,
)
mask_ori = _transform_img(
inference_cfg.mask_crop,
cfg.mask_crop,
crop_info["M_c2o"],
dsize=(source_frame_rgb.shape[1], source_frame_rgb.shape[0]),
)
Expand All @@ -281,7 +280,7 @@ def execute(
else:
I_p_i_to_ori_blend = I_p_i_to_ori

I_p_paste_lst.append(I_p_i_to_ori_blend)
cropped_imgs.append(I_p_i_to_ori_blend)
pbar.update(1)

return I_p_lst, I_p_paste_lst
return warped_imgs, cropped_imgs

0 comments on commit e49535d

Please sign in to comment.