Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
chaojie committed Jan 8, 2024
1 parent 6bba33e commit 919db1c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
29 changes: 13 additions & 16 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def process_traj(points_str,frame_length):

return optical_flow

def save_results(video, fps=10,traj="[]",draw_traj_dot=False,cameras=[],draw_camera_dot=False):
def save_results(video, fps=10,traj="[]",draw_traj_dot=False,cameras=[],draw_camera_dot=False,context_overlap=0):

# b,c,t,h,w
video = video.detach().cpu()
Expand Down Expand Up @@ -117,7 +117,7 @@ def save_results(video, fps=10,traj="[]",draw_traj_dot=False,cameras=[],draw_cam
#writer.append_data(img)

#writer.close()
return torch.cat(tuple(outframes), dim=0).unsqueeze(0)
return torch.cat(tuple(outframes[context_overlap:]), dim=0).unsqueeze(0)

MOTION_CAMERA_OPTIONS = ["U", "D", "L", "R", "O", "O_0.2x", "O_0.4x", "O_1.0x", "O_2.0x", "O_0.2x", "O_0.2x", "Round-RI", "Round-RI_90", "Round-RI-120", "Round-ZoomIn", "SPIN-ACW-60", "SPIN-CW-60", "I", "I_0.2x", "I_0.4x", "I_1.0x", "I_2.0x", "1424acd0007d40b5", "d971457c81bca597", "018f7907401f2fef", "088b93f15ca8745d", "b133a504fc90a2d1"]

Expand Down Expand Up @@ -419,24 +419,21 @@ def run_inference(self,model,clip,vae,ddim_sampler,positive, negative,traj_list,
comfy_path = os.path.dirname(folder_paths.__file__)
pred_x0_path = os.path.join(comfy_path, 'custom_nodes/ComfyUI-MotionCtrl/pred_x0.pt')
x_inter_path = os.path.join(comfy_path, 'custom_nodes/ComfyUI-MotionCtrl/x_inter.pt')

randt=torch.randn([noise_shape[0],noise_shape[1],frame_length-context_overlap,noise_shape[3],noise_shape[4]], device=device)
randt_np=randt.detach().cpu().numpy()

if context_overlap>0:
if os.path.exists(pred_x0_path):
pre_x0=torch.load(pred_x0_path)
pre_x0_np=pre_x0[-1].detach().cpu().numpy()
pre_x0_np_overlap = np.concatenate((pre_x0_np[:,:,-context_overlap:], randt_np), axis=2)
x0=torch.tensor(pre_x0_np_overlap, device=device)
if os.path.exists(x_inter_path):
pre_x_T=torch.load(x_inter_path)

pre_x0_np=pre_x0[-1].detach().cpu().numpy()
pre_x_T_np=pre_x_T[-1].detach().cpu().numpy()

randt=torch.randn([noise_shape[0],noise_shape[1],frame_length-context_overlap,noise_shape[3],noise_shape[4]], device=device)
randt_np=randt.detach().cpu().numpy()

pre_x0_np_overlap = np.concatenate((pre_x0_np[:,:,-context_overlap:], randt_np), axis=2)
x0=torch.tensor(pre_x0_np_overlap, device=device)
pre_x_T_np_overlap = np.concatenate((pre_x_T_np[:,:,-context_overlap:], randt_np), axis=2)
x_T=torch.tensor(pre_x_T_np_overlap, device=device)

pre_x_T_np=pre_x_T[-1].detach().cpu().numpy()
pre_x_T_np_overlap = np.concatenate((pre_x_T_np[:,:,-context_overlap:], randt_np), axis=2)
x_T=torch.tensor(pre_x_T_np_overlap, device=device)

for _ in range(n_samples):
if ddim_sampler is not None:
samples, intermediates = ddim_sampler.sample(S=ddim_steps,
Expand Down Expand Up @@ -479,7 +476,7 @@ def run_inference(self,model,clip,vae,ddim_sampler,positive, negative,traj_list,

torch.save(intermediates['x_inter'], x_inter_path)
torch.save(intermediates['pred_x0'], pred_x0_path)
ret = save_results(batch_variants, fps=10,traj=traj_list,draw_traj_dot=draw_traj_dot,cameras=rt_list,draw_camera_dot=draw_camera_dot)
ret = save_results(batch_variants, fps=10,traj=traj_list,draw_traj_dot=draw_traj_dot,cameras=rt_list,draw_camera_dot=draw_camera_dot,context_overlap=context_overlap)
#print(ret)
return ret

Expand Down
10 changes: 9 additions & 1 deletion turbo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ def handle_message(message):
@socketio.on('camera_poses')
def handle_message(camera_poses):
print(f'camera_poses:{request.sid} {camera_poses}')

cams=json.loads(camera_poses["camera_poses"])
trajs=json.loads(camera_poses["trajs"])
if len(cams)>1 and len(trajs)>1:
prompt["60"]["inputs"]["infer_mode"] = "control both camera and object motion"
elif len(trajs)>1:
prompt["60"]["inputs"]["infer_mode"] = "control object trajectory"
else:
prompt["60"]["inputs"]["infer_mode"] = "control camera poses"

prompt["60"]["inputs"]["prompt"] = camera_poses["prompt"]
prompt["60"]["inputs"]["camera"] = camera_poses["camera_poses"]
prompt["60"]["inputs"]["traj"] = camera_poses["trajs"]
Expand Down

0 comments on commit 919db1c

Please sign in to comment.