Skip to content

Commit

Permalink
More work on dynamic typing
Browse files Browse the repository at this point in the history
  • Loading branch information
guill committed Aug 26, 2024
1 parent 493023f commit 95c8d14
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 138 deletions.
140 changes: 132 additions & 8 deletions comfy_execution/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import nodes
from typing import Set, Tuple, Dict, List

from comfy_execution.graph_utils import is_link

Expand All @@ -15,6 +16,7 @@ class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
self.node_definitions = DynamicNodeDefinitionCache(self)
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
Expand All @@ -27,6 +29,9 @@ def get_node(self, node_id):
return self.original_prompt[node_id]
raise NodeNotFoundError(f"Node {node_id} not found")

def get_node_definition(self, node_id):
return self.node_definitions.get_node_definition(node_id)

def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt

Expand Down Expand Up @@ -54,7 +59,130 @@ def all_node_ids(self):
def get_original_prompt(self):
return self.original_prompt

class DynamicNodeDefinitionCache:
def __init__(self, dynprompt: DynamicPrompt):
self.dynprompt = dynprompt
self.definitions = {}
self.inputs_from_output_slot = {}
self.inputs_from_output_node = {}

def get_node_definition(self, node_id):
if node_id not in self.definitions:
node = self.dynprompt.get_node(node_id)
if node is None:
return None
class_type = node["class_type"]
definition = node_class_info(class_type)
self.definitions[node_id] = definition
return self.definitions[node_id]

def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
node = self.dynprompt.get_node(node_id)
input_types: Dict[str, str] = {}
for input_name, input_data in node["inputs"].items():
if is_link(input_data):
from_node_id, from_socket = input_data
if from_socket < len(self.definitions[from_node_id]["output_name"]):
input_types[input_name] = self.definitions[from_node_id]["output"][from_socket]
else:
input_types[input_name] = "*"
output_types: Dict[str, List[str]] = {}
for index in range(len(self.definitions[node_id]["output_name"])):
output_name = self.definitions[node_id]["output_name"][index]
if (node_id, index) not in self.inputs_from_output_slot:
continue
for (to_node_id, to_input_name) in self.inputs_from_output_slot[(node_id, index)]:
if output_name not in output_types:
output_types[output_name] = []
if to_input_name in self.definitions[to_node_id]["input"]["required"]:
output_types[output_name].append(self.definitions[to_node_id]["input"]["required"][to_input_name][0])
elif to_input_name in self.definitions[to_node_id]["input"]["optional"]:
output_types[output_name].append(self.definitions[to_node_id]["input"]["optional"][to_input_name][0])
else:
output_types[output_name].append("*")
return input_types, output_types

def resolve_dynamic_definitions(self, node_id_set: Set[str]):
print("Resolving dynamic definitions", node_id_set)
entangled = {}
# Pre-fill with class info. Also, build a lookup table for output nodes
for node_id in node_id_set:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
self.definitions[node_id] = node_class_info(class_type)
for input_name, input_data in node["inputs"].items():
if is_link(input_data):
input_tuple = tuple(input_data)
if input_tuple not in self.inputs_from_output_slot:
self.inputs_from_output_slot[input_tuple] = []
self.inputs_from_output_slot[input_tuple].append((node_id, input_name))
if input_tuple[0] not in self.inputs_from_output_node:
self.inputs_from_output_node[input_tuple[0]] = []
self.inputs_from_output_node[input_tuple[0]].append((node_id, input_name))
_, _, extra_info = get_input_info(self.definitions[node_id], input_name)
if extra_info is not None and extra_info.get("entangleTypes", False):
from_node_id = input_data[0]
if node_id not in entangled:
entangled[node_id] = []
if from_node_id not in entangled:
entangled[from_node_id] = []

entangled[node_id].append((from_node_id, input_name))
entangled[from_node_id].append((node_id, input_name))

# Evaluate node info
to_resolve = node_id_set.copy()
updated = {}
while len(to_resolve) > 0:
node_id = to_resolve.pop()
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, "resolve_dynamic_types"):
entangled_types = {}
for (entangled_id, entangled_name) in entangled.get(node_id, []):
input_types, output_types = self.get_input_output_types(entangled_id)
if entangled_name not in entangled_types:
entangled_types[entangled_name] = []
entangled_types[entangled_name].append({
"node_id": entangled_id,
"input_types": input_types,
"output_types": output_types
})

input_types, output_types = self.get_input_output_types(node_id)
dynamic_info = class_def.resolve_dynamic_types(
input_types=input_types,
output_types=output_types,
entangled_types=entangled_types
)
old_info = self.definitions[node_id].copy()
self.definitions[node_id].update(dynamic_info)
# We changed the info, so we potentially need to resolve adjacent and entangled nodes
if old_info != self.definitions[node_id]:
updated[node_id] = self.definitions[node_id]
for (entangled_node_id, _) in entangled.get(node_id, []):
if entangled_node_id in node_id_set:
to_resolve.add(entangled_node_id)
for i in range(len(self.definitions[node_id]["output"])):
for (output_node_id, _) in self.inputs_from_output_slot.get((node_id, i), []):
if output_node_id in node_id_set:
to_resolve.add(output_node_id)
for _, input_data in node["inputs"].items():
if is_link(input_data):
if input_data[0] in node_id_set:
to_resolve.add(input_data[0])
for (to_node_id, _) in self.inputs_from_output_node.get(node_id, []):
if to_node_id in node_id_set:
to_resolve.add(to_node_id)
# Because this run may have changed the number of outputs, we may need to run it again
# in order to get those outputs passed as output_types.
to_resolve.add(node_id)
return updated

def node_class_info(node_class):
if node_class not in nodes.NODE_CLASS_MAPPINGS:
return None
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
Expand Down Expand Up @@ -110,9 +238,7 @@ def __init__(self, dynprompt):
self.blocking = {} # Which nodes are blocked by this node

def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(node_class_info(class_def), input_name)
return get_input_info(self.dynprompt.get_node_definition(unique_id), input_name)

def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
Expand Down Expand Up @@ -145,8 +271,7 @@ def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
node_info = node_class_info(self.dynprompt.get_node(from_node_id)["class_type"])
input_type, input_category, input_info = self.get_input_info(node_info, input_name)
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy:
self.add_strong_link(from_node_id, from_socket, unique_id)
Expand Down Expand Up @@ -209,9 +334,8 @@ def stage_node_execution(self):
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
for node_id in available:
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
node_def = self.dynprompt.get_node_definition(node_id)
if node_def['output_node']:
next_node = node_id
break
self.staged_node_id = next_node
Expand Down
Loading

0 comments on commit 95c8d14

Please sign in to comment.