From a063468444994dd0b83097212986351400a69ef4 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 21 Aug 2024 18:46:11 -0700 Subject: [PATCH] Implement Dynamic Typing This is a proof of concept to get feedback. Note that it requires the frontend branch of the same name. --- comfy_execution/graph.py | 200 ++++++++++++++++++++++++++++++++-- comfy_execution/node_utils.py | 173 +++++++++++++++++++++++++++++ execution.py | 80 ++++++++------ server.py | 47 +++----- 4 files changed, 423 insertions(+), 77 deletions(-) create mode 100644 comfy_execution/node_utils.py diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0b5bf189906..9bbdd66857e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,4 +1,5 @@ import nodes +from typing import Set, Tuple, Dict, List from comfy_execution.graph_utils import is_link @@ -15,6 +16,7 @@ class DynamicPrompt: def __init__(self, original_prompt): # The original prompt provided by the user self.original_prompt = original_prompt + self.node_definitions = DynamicNodeDefinitionCache(self) # Any extra pieces of the graph created during execution self.ephemeral_prompt = {} self.ephemeral_parents = {} @@ -27,6 +29,9 @@ def get_node(self, node_id): return self.original_prompt[node_id] raise NodeNotFoundError(f"Node {node_id} not found") + def get_node_definition(self, node_id): + return self.node_definitions.get_node_definition(node_id) + def has_node(self, node_id): return node_id in self.original_prompt or node_id in self.ephemeral_prompt @@ -54,8 +59,188 @@ def all_node_ids(self): def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, input_name): - valid_inputs = class_def.INPUT_TYPES() +class DynamicNodeDefinitionCache: + def __init__(self, dynprompt: DynamicPrompt): + self.dynprompt = dynprompt + self.definitions = {} + self.inputs_from_output_slot = {} + self.inputs_from_output_node = {} + + def get_node_definition(self, node_id): + if node_id not in self.definitions: + node = self.dynprompt.get_node(node_id) + if node is None: + return None + class_type = node["class_type"] + definition = node_class_info(class_type) + self.definitions[node_id] = definition + return self.definitions[node_id] + + def get_constant_type(self, value): + if isinstance(value, (int, float)): + return "INT,FLOAT" + elif isinstance(value, str): + return "STRING" + elif isinstance(value, bool): + return "BOOL" + else: + return None + + def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, List[str]]]: + node = self.dynprompt.get_node(node_id) + input_types: Dict[str, str] = {} + for input_name, input_data in node["inputs"].items(): + if is_link(input_data): + from_node_id, from_socket = input_data + if from_socket < len(self.definitions[from_node_id]["output_name"]): + input_types[input_name] = self.definitions[from_node_id]["output"][from_socket] + else: + input_types[input_name] = "*" + else: + constant_type = self.get_constant_type(input_data) + if constant_type is not None: + input_types[input_name] = constant_type + output_types: Dict[str, List[str]] = {} + for index in range(len(self.definitions[node_id]["output_name"])): + output_name = self.definitions[node_id]["output_name"][index] + if (node_id, index) not in self.inputs_from_output_slot: + continue + for (to_node_id, to_input_name) in self.inputs_from_output_slot[(node_id, index)]: + if output_name not in output_types: + output_types[output_name] = [] + if to_input_name in self.definitions[to_node_id]["input"]["required"]: + output_types[output_name].append(self.definitions[to_node_id]["input"]["required"][to_input_name][0]) + elif to_input_name in self.definitions[to_node_id]["input"]["optional"]: + output_types[output_name].append(self.definitions[to_node_id]["input"]["optional"][to_input_name][0]) + else: + output_types[output_name].append("*") + return input_types, output_types + + def resolve_dynamic_definitions(self, node_id_set: Set[str]): + entangled = {} + # Pre-fill with class info. Also, build a lookup table for output nodes + for node_id in node_id_set: + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + self.definitions[node_id] = node_class_info(class_type) + for input_name, input_data in node["inputs"].items(): + if is_link(input_data): + input_tuple = tuple(input_data) + if input_tuple not in self.inputs_from_output_slot: + self.inputs_from_output_slot[input_tuple] = [] + self.inputs_from_output_slot[input_tuple].append((node_id, input_name)) + if input_tuple[0] not in self.inputs_from_output_node: + self.inputs_from_output_node[input_tuple[0]] = [] + self.inputs_from_output_node[input_tuple[0]].append((node_id, input_name)) + _, _, extra_info = get_input_info(self.definitions[node_id], input_name) + if extra_info is not None and extra_info.get("entangleTypes", False): + from_node_id = input_data[0] + if node_id not in entangled: + entangled[node_id] = [] + if from_node_id not in entangled: + entangled[from_node_id] = [] + + entangled[node_id].append((from_node_id, input_name)) + entangled[from_node_id].append((node_id, input_name)) + + # Evaluate node info + to_resolve = node_id_set.copy() + updated = {} + while len(to_resolve) > 0: + node_id = to_resolve.pop() + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, "resolve_dynamic_types"): + entangled_types = {} + for (entangled_id, entangled_name) in entangled.get(node_id, []): + entangled_def = self.get_node_definition(entangled_id) + if entangled_def is None: + continue + input_types = {} + output_types = {} + for input_category, input_list in entangled_def["input"].items(): + for input_name, input_info in input_list.items(): + if isinstance(input_info, tuple) or input_category != "hidden": + input_types[input_name] = input_info[0] + for i in range(len(entangled_def["output"])): + output_name = entangled_def["output_name"][i] + output_types[output_name] = entangled_def["output"][i] + + if entangled_name not in entangled_types: + entangled_types[entangled_name] = [] + entangled_types[entangled_name].append({ + "node_id": entangled_id, + "input_types": input_types, + "output_types": output_types + }) + + input_types, output_types = self.get_input_output_types(node_id) + dynamic_info = class_def.resolve_dynamic_types( + input_types=input_types, + output_types=output_types, + entangled_types=entangled_types + ) + old_info = self.definitions[node_id].copy() + self.definitions[node_id].update(dynamic_info) + updated[node_id] = self.definitions[node_id] + # We changed the info, so we potentially need to resolve adjacent and entangled nodes + if old_info != self.definitions[node_id]: + for (entangled_node_id, _) in entangled.get(node_id, []): + if entangled_node_id in node_id_set: + to_resolve.add(entangled_node_id) + for i in range(len(self.definitions[node_id]["output"])): + for (output_node_id, _) in self.inputs_from_output_slot.get((node_id, i), []): + if output_node_id in node_id_set: + to_resolve.add(output_node_id) + for _, input_data in node["inputs"].items(): + if is_link(input_data): + if input_data[0] in node_id_set: + to_resolve.add(input_data[0]) + for (to_node_id, _) in self.inputs_from_output_node.get(node_id, []): + if to_node_id in node_id_set: + to_resolve.add(to_node_id) + # Because this run may have changed the number of outputs, we may need to run it again + # in order to get those outputs passed as output_types. + to_resolve.add(node_id) + return updated + +def node_class_info(node_class): + if node_class not in nodes.NODE_CLASS_MAPPINGS: + return None + obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' + info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") + info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + + if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): + info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS + + if getattr(obj_class, "DEPRECATED", False): + info['deprecated'] = True + if getattr(obj_class, "EXPERIMENTAL", False): + info['experimental'] = True + + return info + + +def get_input_info(node_info, input_name): + valid_inputs = node_info["input"] input_info = None input_category = None if "required" in valid_inputs and input_name in valid_inputs["required"]: @@ -84,9 +269,7 @@ def __init__(self, dynprompt): self.blocking = {} # Which nodes are blocked by this node def get_input_info(self, unique_id, input_name): - class_type = self.dynprompt.get_node(unique_id)["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - return get_input_info(class_def, input_name) + return get_input_info(self.dynprompt.get_node_definition(unique_id), input_name) def make_input_strong_link(self, to_node_id, to_input): inputs = self.dynprompt.get_node(to_node_id)["inputs"] @@ -197,11 +380,8 @@ def ux_friendly_pick_node(self, node_list): # for a PreviewImage to display a result as soon as it can # Some other heuristics could probably be used here to improve the UX further. def is_output(node_id): - class_type = self.dynprompt.get_node(node_id)["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: - return True - return False + node_def = self.dynprompt.get_node_definition(node_id) + return node_def['output_node'] for node_id in node_list: if is_output(node_id): diff --git a/comfy_execution/node_utils.py b/comfy_execution/node_utils.py new file mode 100644 index 00000000000..57e4a278e4f --- /dev/null +++ b/comfy_execution/node_utils.py @@ -0,0 +1,173 @@ +import re +from typing import Optional, Tuple + +# This decorator can be used to enable a "template" syntax for types in a node. +# +# Dynamic Types +# When specifying a type for an input or output, you can wrap an arbitrary string in angle brackets to indicate that it is dynamic. For example, the type "" will be the equivalent of "*" (with the commonly used hacks) with the caveat that all inputs/outputs with the same template name ("FOO" in this case) must have the same type. Use multiple different template names if you want to allow types to differ. Note that this only applies within a single instance of a node -- different nodes can have different type resolutions +# +# Wrapping Types +# Rather than using JUST a template type, you can also use a template type with a wrapping type. For example, if you have a node that takes two inputs with the types "" and "Accumulation" respectively, any output can be connected to the "" input. Once that input has a value (let's say an IMAGE), the other input will resolve as well (to Accumulation in this example). +# +# Variadic Inputs +# Sometimes, you want a node to take a dynamic number of inputs. To do this, create an input value that has a name followed by a number sign and a string (e.g. "input#COUNT"). This will cause additional inputs to be added and removed as the user attaches to those sockets. The string after the '#' can be used to ensure that you have the same number of sockets for two different inputs. For example, having inputs named "image#FOO" and "mask#BAR" will allow the number of images and the number of masks to dynamically increase independently. Having inputs named "image#FOO" and "mask#FOO" will ensure that there are the same number of images as masks. +# +# Variadic Input - Same Type +# If you want to have a variadic input with a dynamic type, you can combine the two. For example, if you have an input named "input#COUNT" with the type "", you can attach multiple inputs to that socket. Once you attach a value to one of the inputs, all of the other inputs will resolve to the same type. This is useful for nodes that take a dynamic number of inputs of the same type. +# +# Variadic Input - Different Types +# If you want to have a variadic input with a dynamic type, you can combine the two. For example, if you have an input named "input#COUNT" with the type "", each socket for the input can have a different type. (Internally, this is equivalent to making the type where 1 is the index of this input.) +# +# Restrictions +# - All dynamic inputs must have `"forceInput": True` due to frontend reasons that will hopefully be resolved before merging. + +def TemplateTypeSupport(): + def decorator(cls): + old_input_types = getattr(cls, "INPUT_TYPES") + def new_input_types(cls): + old_types = old_input_types() + new_types = { + "required": {}, + "optional": {}, + "hidden": old_types.get("hidden", {}), + } + for category in ["required", "optional"]: + if category not in old_types: + continue + for key, value in old_types[category].items(): + new_types[category][replace_variadic_suffix(key, 1)] = (template_to_type(value[0]),) + value[1:] + return new_types + setattr(cls, "INPUT_TYPES", classmethod(new_input_types)) + old_outputs = getattr(cls, "RETURN_TYPES") + setattr(cls, "RETURN_TYPES", tuple(template_to_type(x) for x in old_outputs)) + + def resolve_dynamic_types(cls, input_types, output_types, entangled_types): + original_inputs = old_input_types() + + # Step 1 - Find all variadic groups and determine their maximum used index + variadic_group_map = {} + max_group_index = {} + for category in ["required", "optional"]: + for key, value in original_inputs.get(category, {}).items(): + root, group = determine_variadic_group(key) + if root is not None and group is not None: + variadic_group_map[root] = group + for type_map in [input_types, output_types]: + for key in type_map.keys(): + root, index = determine_variadic_suffix(key) + if root is not None and index is not None: + if root in variadic_group_map: + group = variadic_group_map[root] + max_group_index[group] = max(max_group_index.get(group, 0), index) + + # Step 2 - Create variadic arguments + variadic_inputs = { + "required": {}, + "optional": {}, + } + for category in ["required", "optional"]: + for key, value in original_inputs.get(category, {}).items(): + root, group = determine_variadic_group(key) + if root is None or group is None: + # Copy it over as-is + variadic_inputs[category][key] = value + else: + for i in range(1, max_group_index.get(group, 0) + 2): + # Also replace any variadic suffixes in the type (for use with templates) + input_type = value[0] + if isinstance(input_type, str): + input_type = replace_variadic_suffix(input_type, i) + variadic_inputs[category][replace_variadic_suffix(key, i)] = (input_type,value[1]) + + # Step 3 - Resolve template arguments + resolved = {} + for category in ["required", "optional"]: + for key, value in variadic_inputs[category].items(): + if key in input_types: + tkey, tvalue = determine_template_value(value[0], input_types[key]) + if tkey is not None and tvalue is not None: + resolved[tkey] = type_intersection(resolved.get(tkey, "*"), tvalue) + for i in range(len(old_outputs)): + output_name = cls.RETURN_NAMES[i] + if output_name in output_types: + for output_type in output_types[output_name]: + tkey, tvalue = determine_template_value(old_outputs[i], output_type) + if tkey is not None and tvalue is not None: + resolved[tkey] = type_intersection(resolved.get(tkey, "*"), tvalue) + + # Step 4 - Replace templates with resolved types + final_inputs = { + "required": {}, + "optional": {}, + "hidden": original_inputs.get("hidden", {}), + } + for category in ["required", "optional"]: + for key, value in variadic_inputs[category].items(): + final_inputs[category][key] = (template_to_type(value[0], resolved),) + value[1:] + outputs = (template_to_type(x, resolved) for x in old_outputs) + return { + "input": final_inputs, + "output": tuple(outputs), + "output_name": cls.RETURN_NAMES, + "dynamic_counts": max_group_index, + } + setattr(cls, "resolve_dynamic_types", classmethod(resolve_dynamic_types)) + return cls + return decorator + +def type_intersection(a: str, b: str) -> str: + if a == "*": + return b + if b == "*": + return a + if a == b: + return a + aset = set(a.split(',')) + bset = set(b.split(',')) + intersection = aset.intersection(bset) + if len(intersection) == 0: + return "*" + return ",".join(intersection) + +naked_template_regex = re.compile(r"^<(.+)>$") +qualified_template_regex = re.compile(r"^(.+)<(.+)>$") +variadic_template_regex = re.compile(r"([^<]+)#([^>]+)") +variadic_suffix_regex = re.compile(r"([^<]+)(\d+)") + +empty_lookup = {} +def template_to_type(template, key_lookup=empty_lookup): + templ_match = naked_template_regex.match(template) + if templ_match: + return key_lookup.get(templ_match.group(1), "*") + templ_match = qualified_template_regex.match(template) + if templ_match: + resolved = key_lookup.get(templ_match.group(2), "*") + return qualified_template_regex.sub(r"\1<%s>" % resolved, template) + return template + +# Returns the 'key' and 'value' of the template (if any) +def determine_template_value(template: str, actual_type: str) -> Tuple[Optional[str], Optional[str]]: + templ_match = naked_template_regex.match(template) + if templ_match: + return templ_match.group(1), actual_type + templ_match = qualified_template_regex.match(template) + actual_match = qualified_template_regex.match(actual_type) + if templ_match and actual_match and templ_match.group(1) == actual_match.group(1): + return templ_match.group(2), actual_match.group(2) + return None, None + +def determine_variadic_group(template: str) -> Tuple[Optional[str], Optional[str]]: + variadic_match = variadic_template_regex.match(template) + if variadic_match: + return variadic_match.group(1), variadic_match.group(2) + return None, None + +def replace_variadic_suffix(template: str, index: int) -> str: + return variadic_template_regex.sub(lambda match: match.group(1) + str(index), template) + +def determine_variadic_suffix(template: str) -> Tuple[Optional[str], Optional[int]]: + variadic_match = variadic_suffix_regex.match(template) + if variadic_match: + return variadic_match.group(1), int(variadic_match.group(2)) + return None, None + diff --git a/execution.py b/execution.py index 6c386341bfe..f6af89faa5b 100644 --- a/execution.py +++ b/execution.py @@ -7,13 +7,13 @@ import traceback from enum import Enum import inspect -from typing import List, Literal, NamedTuple, Optional +from typing import List, Literal, NamedTuple, Optional, Dict, Tuple import torch import nodes import comfy.model_management -from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker, node_class_info from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from comfy.cli_args import args @@ -37,8 +37,8 @@ def get(self, node_id): return self.is_changed[node_id] node = self.dynprompt.get_node(node_id) - class_type = node["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = nodes.NODE_CLASS_MAPPINGS[node["class_type"]] + node_def = self.dynprompt.get_node_definition(node_id) if not hasattr(class_def, "IS_CHANGED"): self.is_changed[node_id] = False return self.is_changed[node_id] @@ -48,7 +48,7 @@ def get(self, node_id): return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _ = get_input_data(node["inputs"], node_def, node_id, None) try: is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] @@ -87,13 +87,13 @@ def recursive_debug_dump(self): } return result -def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): - valid_inputs = class_def.INPUT_TYPES() +def get_input_data(inputs, node_def, unique_id, outputs=None, dynprompt=None, extra_data={}): + valid_inputs = node_def['input'] input_data_all = {} missing_keys = {} for x in inputs: input_data = inputs[x] - input_type, input_category, input_info = get_input_info(class_def, x) + input_type, input_category, input_info = get_input_info(node_def, x) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -126,6 +126,8 @@ def mark_missing(): input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + if h[x] == "NODE_DEFINITION": + input_data_all[x] = [node_def] return input_data_all, missing_keys map_node_over_list = None #Don't hook this please @@ -169,12 +171,12 @@ def process_inputs(inputs, index=None): process_inputs(input_dict, i) return results -def merge_result_data(results, obj): +def merge_result_data(results, node_def): # check which outputs need concatenating output = [] - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST + output_is_list = node_def['output_is_list'] + if len(output_is_list) < len(results[0]): + output_is_list = output_is_list + [False] * (len(results[0]) - len(output_is_list)) # merge node execution results for i, is_list in zip(range(len(results[0])), output_is_list): @@ -190,13 +192,14 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): +def get_output_data(obj, node_def, input_data_all, execution_block_cb=None, pre_execute_cb=None): results = [] uis = [] subgraph_results = [] return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) has_subgraph = False + num_outputs = len(node_def['output']) for i in range(len(return_values)): r = return_values[i] if isinstance(r, dict): @@ -208,24 +211,24 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb new_graph = r['expand'] result = r.get("result", None) if isinstance(result, ExecutionBlocker): - result = tuple([result] * len(obj.RETURN_TYPES)) + result = tuple([result] * num_outputs) subgraph_results.append((new_graph, result)) elif 'result' in r: result = r.get("result", None) if isinstance(result, ExecutionBlocker): - result = tuple([result] * len(obj.RETURN_TYPES)) + result = tuple([result] * num_outputs) results.append(result) subgraph_results.append((None, result)) else: if isinstance(r, ExecutionBlocker): - r = tuple([r] * len(obj.RETURN_TYPES)) + r = tuple([r] * num_outputs) results.append(r) subgraph_results.append((None, r)) if has_subgraph: output = subgraph_results elif len(results) > 0: - output = merge_result_data(results, obj) + output = merge_result_data(results, node_def) else: output = [] ui = dict() @@ -249,6 +252,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + node_def = dynprompt.get_node_definition(unique_id) if caches.outputs.get(unique_id) is not None: if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} @@ -275,11 +279,11 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp else: resolved_output.append(r) resolved_outputs.append(tuple(resolved_output)) - output_data = merge_result_data(resolved_outputs, class_def) + output_data = merge_result_data(resolved_outputs, node_def) output_ui = [] has_subgraph = False else: - input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys = get_input_data(inputs, node_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -320,7 +324,7 @@ def execution_block_cb(block): return block def pre_execute_cb(call_index): GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + output_data, output_ui, has_subgraph = get_output_data(obj, node_def, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) if len(output_ui) > 0: caches.ui.set(unique_id, { "meta": { @@ -351,10 +355,11 @@ def pre_execute_cb(call_index): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) - # Figure out if the newly created node is an output node - class_type = node_info["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + dynprompt.node_definitions.resolve_dynamic_definitions(set(new_graph.keys())) + # Figure out if the newly created node is an output node + for node_id, node_info in new_graph.items(): + node_def = dynprompt.get_node_definition(node_id) + if node_def['output_node']: new_output_ids.append(node_id) for i in range(len(node_outputs)): if is_link(node_outputs[i]): @@ -470,6 +475,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) + dynamic_prompt.node_definitions.resolve_dynamic_definitions(set(dynamic_prompt.all_node_ids())) is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) for cache in self.caches.all: cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) @@ -528,7 +534,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): -def validate_inputs(prompt, item, validated): +def validate_inputs(dynprompt, prompt, item, validated): unique_id = item if unique_id in validated: return validated[unique_id] @@ -536,8 +542,9 @@ def validate_inputs(prompt, item, validated): inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] + node_def = dynprompt.get_node_definition(unique_id) - class_inputs = obj_class.INPUT_TYPES() + class_inputs = node_def['input'] valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] @@ -552,7 +559,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x) + type_input, input_category, extra_info = get_input_info(node_def, x) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -585,8 +592,9 @@ def validate_inputs(prompt, item, validated): continue o_id = val[0] - o_class_type = prompt[o_id]['class_type'] - r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES + o_node_def = dynprompt.get_node_definition(o_id) + r = o_node_def['output'] + assert r is not None received_type = r[val[1]] received_types[x] = received_type if 'input_types' not in validate_function_inputs and received_type != type_input: @@ -605,7 +613,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue try: - r = validate_inputs(prompt, o_id, validated) + r = validate_inputs(dynprompt, prompt, o_id, validated) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -713,7 +721,7 @@ def validate_inputs(prompt, item, validated): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _ = get_input_data(inputs, obj_class, unique_id) + input_data_all, _ = get_input_data(inputs, node_def, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -756,6 +764,8 @@ def full_type_name(klass): return module + '.' + klass.__qualname__ def validate_prompt(prompt): + dynprompt = DynamicPrompt(prompt) + dynprompt.node_definitions.resolve_dynamic_definitions(set(dynprompt.all_node_ids())) outputs = set() for x in prompt: if 'class_type' not in prompt[x]: @@ -768,8 +778,8 @@ def validate_prompt(prompt): return (False, error, [], []) class_type = prompt[x]['class_type'] - class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) - if class_ is None: + node_def = dynprompt.get_node_definition(x) + if node_def is None: error = { "type": "invalid_prompt", "message": f"Cannot execute because node {class_type} does not exist.", @@ -778,7 +788,7 @@ def validate_prompt(prompt): } return (False, error, [], []) - if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: + if node_def['output_node']: outputs.add(x) if len(outputs) == 0: @@ -798,7 +808,7 @@ def validate_prompt(prompt): valid = False reasons = [] try: - m = validate_inputs(prompt, o, validated) + m = validate_inputs(dynprompt, prompt, o, validated) valid = m[0] reasons = m[1] except Exception as ex: diff --git a/server.py b/server.py index c7bf6622d4e..2fa62449502 100644 --- a/server.py +++ b/server.py @@ -32,6 +32,7 @@ from model_filemanager import download_model, DownloadModelStatus from typing import Optional from api_server.routes.internal.internal_routes import InternalRoutes +from comfy_execution.graph import DynamicPrompt, DynamicNodeDefinitionCache, node_class_info class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -525,43 +526,13 @@ async def system_stats(request): async def get_prompt(request): return web.json_response(self.get_queue_info()) - def node_info(node_class): - obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = node_class - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' - info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") - info['category'] = 'sd' - if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: - info['output_node'] = True - else: - info['output_node'] = False - - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - - if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): - info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS - - if getattr(obj_class, "DEPRECATED", False): - info['deprecated'] = True - if getattr(obj_class, "EXPERIMENTAL", False): - info['experimental'] = True - return info - @routes.get("/object_info") async def get_object_info(request): with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: try: - out[x] = node_info(x) + out[x] = node_class_info(x) except Exception as e: logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc()) @@ -572,7 +543,7 @@ async def get_object_info_node(request): node_class = request.match_info.get("node_class", None) out = {} if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): - out[node_class] = node_info(node_class) + out[node_class] = node_class_info(node_class) return web.json_response(out) @routes.get("/history") @@ -595,6 +566,18 @@ async def get_queue(request): queue_info['queue_pending'] = current_queue[1] return web.json_response(queue_info) + @routes.post("/resolve_dynamic_types") + async def resolve_dynamic_types(request): + json_data = await request.json() + if 'prompt' not in json_data: + return web.json_response({"error": "no prompt"}, status=400) + prompt = json_data['prompt'] + dynprompt = DynamicPrompt(prompt) + definitions = DynamicNodeDefinitionCache(dynprompt) + updated = definitions.resolve_dynamic_definitions(dynprompt.all_node_ids()) + return web.json_response(updated) + + @routes.post("/prompt") async def post_prompt(request): logging.info("got prompt")