Skip to content

Commit

Permalink
add context_overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
chaojie committed Jan 8, 2024
1 parent f0873a8 commit 6bba33e
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 31 deletions.
101 changes: 91 additions & 10 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,50 @@ def INPUT_TYPES(cls):
"prompt": ("STRING", {"multiline": True, "default":"a rose swaying in the wind"}),
"camera": ("STRING", {"multiline": True, "default":"[[1,0,0,0,0,1,0,0,0,0,1,0.2]]"}),
"traj": ("STRING", {"multiline": True, "default":"[[117, 102]]"}),
"infer_mode": (MODE, {"default":"control both camera and object motion"})
"infer_mode": (MODE, {"default":"control both camera and object motion"}),
"context_overlap": ("INT", {"default": 0, "min": 0, "max": 32}),
}
}

RETURN_TYPES = ("CONDITIONING", "CONDITIONING","TRAJ_LIST","RT_LIST","TRAJ_FEATURES","RT","NOISE_SHAPE")
RETURN_NAMES = ("positive", "negative","traj_list","rt_list","traj","rt","noise_shape")
RETURN_TYPES = ("CONDITIONING", "CONDITIONING","TRAJ_LIST","RT_LIST","TRAJ_FEATURES","RT","NOISE_SHAPE","INT")
RETURN_NAMES = ("positive", "negative","traj_list","rt_list","traj","rt","noise_shape","context_overlap")
FUNCTION = "load_cond"
CATEGORY = "motionctrl"

def load_cond(self, model, prompt, camera, traj,infer_mode):
def load_cond(self, model, prompt, camera, traj,infer_mode,context_overlap):
comfy_path = os.path.dirname(folder_paths.__file__)
camera_align_file = os.path.join(comfy_path, 'custom_nodes/ComfyUI-MotionCtrl/camera.json')
traj_align_file = os.path.join(comfy_path, 'custom_nodes/ComfyUI-MotionCtrl/traj.json')
frame_length=model.temporal_length

camera_align=json.loads(camera)
for i in range(frame_length):
if len(camera_align)<=i:
camera_align.append(camera_align[len(camera_align)-1])
camera=json.dumps(camera_align)
traj_align=json.loads(traj)
for i in range(frame_length):
if len(traj_align)<=i:
traj_align.append(traj_align[len(traj_align)-1])
traj=json.dumps(traj_align)

if context_overlap>0:
if os.path.exists(camera_align_file):
with open(camera_align_file, 'r') as file:
pre_camera_align=json.load(file)
camera_align=pre_camera_align[:context_overlap]+camera_align[:-context_overlap]

if os.path.exists(traj_align_file):
with open(traj_align_file, 'r') as file:
pre_traj_align=json.load(file)
traj_align=pre_traj_align[:context_overlap]+traj_align[:-context_overlap]

with open(camera_align_file, 'w') as file:
json.dump(camera_align, file)

with open(traj_align_file, 'w') as file:
json.dump(traj_align, file)

prompts = prompt
RT = process_camera(camera,frame_length).reshape(-1,12)
RT_list = process_camera_list(camera,frame_length)
Expand Down Expand Up @@ -319,7 +352,7 @@ def load_cond(self, model, prompt, camera, traj,infer_mode):
un_motion = None
uc = {"features_adapter": un_motion, "uc": uc}

return (cond,uc,traj,RT_list,traj_features,RT,noise_shape)
return (cond,uc,traj,RT_list,traj_features,RT,noise_shape,context_overlap)



Expand All @@ -340,7 +373,8 @@ def INPUT_TYPES(cls):
"rt": ("RT",),
"steps": ("INT", {"default": 50}),
"seed": ("INT", {"default": 1234}),
"noise_shape":("NOISE_SHAPE",)
"noise_shape":("NOISE_SHAPE",),
"context_overlap": ("INT", {"default": 0, "min": 0, "max": 32}),
},
"optional": {
"traj_tool": ("STRING",{"multiline": False, "default": "https://chaojie.github.io/ComfyUI-MotionCtrl/tools/draw.html"}),
Expand All @@ -353,9 +387,10 @@ def INPUT_TYPES(cls):
FUNCTION = "run_inference"
CATEGORY = "motionctrl"

def run_inference(self,model,clip,vae,ddim_sampler,positive, negative,traj_list,rt_list,traj,rt,steps,seed,noise_shape,traj_tool="https://chaojie.github.io/ComfyUI-MotionCtrl/tools/draw.html",draw_traj_dot=False,draw_camera_dot=False):
def run_inference(self,model,clip,vae,ddim_sampler,positive, negative,traj_list,rt_list,traj,rt,steps,seed,noise_shape,context_overlap,traj_tool="https://chaojie.github.io/ComfyUI-MotionCtrl/tools/draw.html",draw_traj_dot=False,draw_camera_dot=False):
frame_length=model.temporal_length

device = model.betas.device
print(f'frame_length{frame_length}')
#noise_shape = [1, 4, 16, 32, 32]
unconditional_guidance_scale = 7.5
unconditional_guidance_scale_temporal = None
Expand All @@ -374,9 +409,37 @@ def run_inference(self,model,clip,vae,ddim_sampler,positive, negative,traj_list,

batch_images=[]
batch_variants = []
intermediates = {}

x0=None
x_T=None
pre_x0=None
pre_x_T=None

comfy_path = os.path.dirname(folder_paths.__file__)
pred_x0_path = os.path.join(comfy_path, 'custom_nodes/ComfyUI-MotionCtrl/pred_x0.pt')
x_inter_path = os.path.join(comfy_path, 'custom_nodes/ComfyUI-MotionCtrl/x_inter.pt')

if context_overlap>0:
if os.path.exists(pred_x0_path):
pre_x0=torch.load(pred_x0_path)
if os.path.exists(x_inter_path):
pre_x_T=torch.load(x_inter_path)

pre_x0_np=pre_x0[-1].detach().cpu().numpy()
pre_x_T_np=pre_x_T[-1].detach().cpu().numpy()

randt=torch.randn([noise_shape[0],noise_shape[1],frame_length-context_overlap,noise_shape[3],noise_shape[4]], device=device)
randt_np=randt.detach().cpu().numpy()

pre_x0_np_overlap = np.concatenate((pre_x0_np[:,:,-context_overlap:], randt_np), axis=2)
x0=torch.tensor(pre_x0_np_overlap, device=device)
pre_x_T_np_overlap = np.concatenate((pre_x_T_np[:,:,-context_overlap:], randt_np), axis=2)
x_T=torch.tensor(pre_x_T_np_overlap, device=device)

for _ in range(n_samples):
if ddim_sampler is not None:
samples, _ = ddim_sampler.sample(S=ddim_steps,
samples, intermediates = ddim_sampler.sample(S=ddim_steps,
conditioning=positive,
batch_size=noise_shape[0],
shape=noise_shape[1:],
Expand All @@ -388,16 +451,34 @@ def run_inference(self,model,clip,vae,ddim_sampler,positive, negative,traj_list,
conditional_guidance_scale_temporal=unconditional_guidance_scale_temporal,
features_adapter=traj,
pose_emb=rt,
cond_T=cond_T
cond_T=cond_T,
x0=x0,
x_T=x_T
)
#print(f'{samples}')
## reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants.append(batch_images)
'''
batch_images = model.decode_first_stage(intermediates['pred_x0'][0])
batch_variants.append(batch_images)
batch_images = model.decode_first_stage(intermediates['pred_x0'][1])
batch_variants.append(batch_images)
batch_images = model.decode_first_stage(intermediates['pred_x0'][2])
batch_variants.append(batch_images)
batch_images = model.decode_first_stage(intermediates['x_inter'][0])
batch_variants.append(batch_images)
batch_images = model.decode_first_stage(intermediates['x_inter'][1])
batch_variants.append(batch_images)
batch_images = model.decode_first_stage(intermediates['x_inter'][2])
batch_variants.append(batch_images)
'''
## variants, batch, c, t, h, w
batch_variants = torch.stack(batch_variants, dim=1)
batch_variants = batch_variants[0]

torch.save(intermediates['x_inter'], x_inter_path)
torch.save(intermediates['pred_x0'], pred_x0_path)
ret = save_results(batch_variants, fps=10,traj=traj_list,draw_traj_dot=draw_traj_dot,cameras=rt_list,draw_camera_dot=draw_camera_dot)
#print(ret)
return ret
Expand Down
2 changes: 1 addition & 1 deletion turbo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,5 @@ def index():
return render_template('index.html')

if __name__ == '__main__':
socketio.run(app, host='0.0.0.0', port=5017, debug=True)
socketio.run(app, host='0.0.0.0', port=5017, debug=True, allow_unsafe_werkzeug=True)
#app.run(host='0.0.0.0', port=5017)
54 changes: 34 additions & 20 deletions turbo/workflow_api_motionctrl_turbo.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,28 @@
},
"class_type": "Load Motionctrl Checkpoint"
},
"57": {
"60": {
"inputs": {
"prompt": "a rose swaying in the wind",
"camera": "[[1,0,0,0,0,1,0,0,0,0,1,0.2]]",
"traj": "[[117, 102]]",
"infer_mode": "control camera poses",
"context_overlap": 4,
"model": [
"56",
0
]
},
"class_type": "Motionctrl Cond"
},
"61": {
"inputs": {
"steps": 12,
"seed": 1860,
"steps": 20,
"seed": 1647,
"context_overlap": [
"60",
7
],
"traj_tool": "https://chaojie.github.io/ComfyUI-MotionCtrl/tools/draw.html",
"draw_traj_dot": false,
"draw_camera_dot": false,
Expand Down Expand Up @@ -37,41 +55,37 @@
"60",
1
],
"traj": [
"traj_list": [
"60",
2
],
"rt": [
"rt_list": [
"60",
3
],
"noise_shape": [
"traj": [
"60",
4
],
"rt": [
"60",
5
],
"noise_shape": [
"60",
6
]
},
"class_type": "Motionctrl Sample Simple"
},
"59": {
"62": {
"inputs": {
"filename_prefix": "motionctrl/motionctrl",
"images": [
"57",
"61",
0
]
},
"class_type": "SaveImage"
},
"60": {
"inputs": {
"prompt": "a rose swaying in the wind",
"camera": "[[1,0,0,0,0,1,0,0,0,0,1,0.2]]",
"traj": "[[117, 102]]",
"model": [
"56",
0
]
},
"class_type": "Motionctrl Cond"
}
}

0 comments on commit 6bba33e

Please sign in to comment.