diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 303ccae31e73..858fd9a01011 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -54,8 +54,34 @@ def all_node_ids(self): def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, input_name): - valid_inputs = class_def.INPUT_TYPES() +def node_class_info(node_class): + obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' + info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") + info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + + if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): + info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS + return info + + +def get_input_info(node_info, input_name): + valid_inputs = node_info["input"] input_info = None input_category = None if "required" in valid_inputs and input_name in valid_inputs["required"]: @@ -86,7 +112,7 @@ def __init__(self, dynprompt): def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - return get_input_info(class_def, input_name) + return get_input_info(node_class_info(class_def), input_name) def make_input_strong_link(self, to_node_id, to_input): inputs = self.dynprompt.get_node(to_node_id)["inputs"] @@ -119,7 +145,8 @@ def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): from_node_id, from_socket = value if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue - input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + node_info = node_class_info(self.dynprompt.get_node(from_node_id)["class_type"]) + input_type, input_category, input_info = self.get_input_info(node_info, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if include_lazy or not is_lazy: self.add_strong_link(from_node_id, from_socket, unique_id) diff --git a/execution.py b/execution.py index a7f10172a7e6..c1c089fa1e5e 100644 --- a/execution.py +++ b/execution.py @@ -7,31 +7,109 @@ import traceback from enum import Enum import inspect -from typing import List, Literal, NamedTuple, Optional +from typing import List, Literal, NamedTuple, Optional, Dict, Tuple import torch import nodes import comfy.model_management -from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker, node_class_info from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from comfy.cli_args import args -def resolve_dynamic_types(prompt): - output = {} - for node_id in prompt: - node = prompt[node_id] +def get_input_output_types(dynprompt, node_id, info, output_lookup) -> Tuple[Dict[str, str], Dict[str, List[str]]]: + node = dynprompt.get_node(node_id) + input_types: Dict[str, str] = {} + for input_name, input_data in node["inputs"].items(): + if is_link(input_data): + from_node_id, from_socket = input_data + if from_socket < len(info[from_node_id]["output_name"]): + input_types[input_name] = info[from_node_id]["output"][from_socket] + else: + input_types[input_name] = "*" + output_types: Dict[str, List[str]] = {} + for index in range(len(info[node_id]["output_name"])): + output_name = info[node_id]["output_name"][index] + if (node_id, index) not in output_lookup: + continue + for (to_node_id, to_input_name) in output_lookup[(node_id, index)]: + if output_name not in output_types: + output_types[output_name] = [] + if to_input_name in info[to_node_id]["input"]["required"]: + output_types[output_name].append(info[to_node_id]["input"]["required"][to_input_name][0]) + elif to_input_name in info[to_node_id]["input"]["optional"]: + output_types[output_name].append(info[to_node_id]["input"]["optional"][to_input_name][0]) + return input_types, output_types + +def resolve_dynamic_types(dynprompt: DynamicPrompt): + info = {} + output_lookup = {} + entangled = {} + # Pre-fill with class info + for node_id in dynprompt.all_node_ids(): + node = dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + info[node_id] = node_class_info(class_type) + for input_name, input_data in node["inputs"].items(): + if is_link(input_data): + input_tuple = tuple(input_data) + if input_tuple not in output_lookup: + output_lookup[input_tuple] = [] + output_lookup[input_tuple].append((node_id, input_name)) + _, _, extra_info = get_input_info(info[node_id], input_name) + if extra_info is not None and extra_info.get("entangleTypes", False): + from_node_id = input_data[0] + if node_id not in entangled: + entangled[node_id] = [] + if from_node_id not in entangled: + entangled[from_node_id] = [] + + entangled[node_id].append((from_node_id, input_name)) + entangled[from_node_id].append((node_id, input_name)) + + # Evaluate node info + to_resolve = set(info.keys()) + result = {} + while len(to_resolve) > 0: + node_id = to_resolve.pop() + node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if hasattr(class_def, "resolve_dynamic_types"): - inputs, outputs, output_names = class_def.resolve_dynamic_types(node_id, prompt) - output[node_id] = { - "input": inputs, - "output": outputs, - "output_name": output_names, - } - return output + entangled_types = {} + for (entangled_id, entangled_name) in entangled.get(node_id, []): + input_types, output_types = get_input_output_types(dynprompt, entangled_id, info, output_lookup) + if entangled_name not in entangled_types: + entangled_types[entangled_name] = [] + entangled_types[entangled_name].append({ + "node_id": entangled_id, + "input_types": input_types, + "output_types": output_types + }) + + input_types, output_types = get_input_output_types(dynprompt, node_id, info, output_lookup) + dynamic_info = class_def.resolve_dynamic_types( + node_id=node_id, + input_types=input_types, + output_types=output_types, + entangled_types=entangled_types + ) + old_info = info[node_id].copy() + info[node_id].update(dynamic_info) + # We changed the info, so we potentially need to resolve adjacent and entangled nodes + if old_info != info[node_id]: + result[node_id] = info[node_id] + for (entangled_node_id, _) in entangled.get(node_id, []): + to_resolve.add(entangled_node_id) + for i in range(len(info[node_id]["output"])): + for (output_node_id, _) in output_lookup.get((node_id, i), []): + to_resolve.add(output_node_id) + for _, input_data in node["inputs"].items(): + if is_link(input_data): + to_resolve.add(input_data[0]) + return result class ExecutionResult(Enum): @@ -557,7 +635,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x) + type_input, input_category, extra_info = get_input_info(node_class_info(obj_class), x) assert extra_info is not None if x not in inputs: if input_category == "required": diff --git a/server.py b/server.py index 703f41a343f0..e69fd842d6b5 100644 --- a/server.py +++ b/server.py @@ -29,6 +29,7 @@ from app.user_manager import UserManager from model_filemanager import download_model, DownloadModelStatus from typing import Optional +from comfy_execution.graph import DynamicPrompt, node_class_info class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -419,37 +420,12 @@ async def get_queue(request): async def get_prompt(request): return web.json_response(self.get_queue_info()) - def node_info(node_class): - obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = node_class - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' - info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") - info['category'] = 'sd' - if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: - info['output_node'] = True - else: - info['output_node'] = False - - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - - if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): - info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS - return info - @routes.get("/object_info") async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: try: - out[x] = node_info(x) + out[x] = node_class_info(x) except Exception as e: logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc()) @@ -460,7 +436,7 @@ async def get_object_info_node(request): node_class = request.match_info.get("node_class", None) out = {} if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): - out[node_class] = node_info(node_class) + out[node_class] = node_class_info(node_class) return web.json_response(out) @routes.get("/history") @@ -489,8 +465,8 @@ async def resolve_dynamic_types(request): if 'prompt' not in json_data: return web.json_response({"error": "no prompt"}, status=400) prompt = json_data['prompt'] - logging.info("Resolving dynamic types", prompt) - resolved = execution.resolve_dynamic_types(prompt) + dynprompt = DynamicPrompt(prompt) + resolved = execution.resolve_dynamic_types(dynprompt) return web.json_response(resolved)