Skip to content

Commit

Permalink
Implement Dynamic Typing
Browse files Browse the repository at this point in the history
This is a proof of concept to get feedback. Note that it requires the
frontend branch of the same name.
  • Loading branch information
guill committed Oct 16, 2024
1 parent 0dbba9f commit a063468
Show file tree
Hide file tree
Showing 4 changed files with 423 additions and 77 deletions.
200 changes: 190 additions & 10 deletions comfy_execution/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import nodes
from typing import Set, Tuple, Dict, List

from comfy_execution.graph_utils import is_link

Expand All @@ -15,6 +16,7 @@ class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
self.node_definitions = DynamicNodeDefinitionCache(self)
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
Expand All @@ -27,6 +29,9 @@ def get_node(self, node_id):
return self.original_prompt[node_id]
raise NodeNotFoundError(f"Node {node_id} not found")

def get_node_definition(self, node_id):
return self.node_definitions.get_node_definition(node_id)

def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt

Expand Down Expand Up @@ -54,8 +59,188 @@ def all_node_ids(self):
def get_original_prompt(self):
return self.original_prompt

def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
class DynamicNodeDefinitionCache:
def __init__(self, dynprompt: DynamicPrompt):
self.dynprompt = dynprompt
self.definitions = {}
self.inputs_from_output_slot = {}
self.inputs_from_output_node = {}

def get_node_definition(self, node_id):
if node_id not in self.definitions:
node = self.dynprompt.get_node(node_id)
if node is None:
return None
class_type = node["class_type"]
definition = node_class_info(class_type)
self.definitions[node_id] = definition
return self.definitions[node_id]

def get_constant_type(self, value):
if isinstance(value, (int, float)):
return "INT,FLOAT"
elif isinstance(value, str):
return "STRING"
elif isinstance(value, bool):
return "BOOL"
else:
return None

def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
node = self.dynprompt.get_node(node_id)
input_types: Dict[str, str] = {}
for input_name, input_data in node["inputs"].items():
if is_link(input_data):
from_node_id, from_socket = input_data
if from_socket < len(self.definitions[from_node_id]["output_name"]):
input_types[input_name] = self.definitions[from_node_id]["output"][from_socket]
else:
input_types[input_name] = "*"
else:
constant_type = self.get_constant_type(input_data)
if constant_type is not None:
input_types[input_name] = constant_type
output_types: Dict[str, List[str]] = {}
for index in range(len(self.definitions[node_id]["output_name"])):
output_name = self.definitions[node_id]["output_name"][index]
if (node_id, index) not in self.inputs_from_output_slot:
continue
for (to_node_id, to_input_name) in self.inputs_from_output_slot[(node_id, index)]:
if output_name not in output_types:
output_types[output_name] = []
if to_input_name in self.definitions[to_node_id]["input"]["required"]:
output_types[output_name].append(self.definitions[to_node_id]["input"]["required"][to_input_name][0])
elif to_input_name in self.definitions[to_node_id]["input"]["optional"]:
output_types[output_name].append(self.definitions[to_node_id]["input"]["optional"][to_input_name][0])
else:
output_types[output_name].append("*")
return input_types, output_types

def resolve_dynamic_definitions(self, node_id_set: Set[str]):
entangled = {}
# Pre-fill with class info. Also, build a lookup table for output nodes
for node_id in node_id_set:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
self.definitions[node_id] = node_class_info(class_type)
for input_name, input_data in node["inputs"].items():
if is_link(input_data):
input_tuple = tuple(input_data)
if input_tuple not in self.inputs_from_output_slot:
self.inputs_from_output_slot[input_tuple] = []
self.inputs_from_output_slot[input_tuple].append((node_id, input_name))
if input_tuple[0] not in self.inputs_from_output_node:
self.inputs_from_output_node[input_tuple[0]] = []
self.inputs_from_output_node[input_tuple[0]].append((node_id, input_name))
_, _, extra_info = get_input_info(self.definitions[node_id], input_name)
if extra_info is not None and extra_info.get("entangleTypes", False):
from_node_id = input_data[0]
if node_id not in entangled:
entangled[node_id] = []
if from_node_id not in entangled:
entangled[from_node_id] = []

entangled[node_id].append((from_node_id, input_name))
entangled[from_node_id].append((node_id, input_name))

# Evaluate node info
to_resolve = node_id_set.copy()
updated = {}
while len(to_resolve) > 0:
node_id = to_resolve.pop()
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, "resolve_dynamic_types"):
entangled_types = {}
for (entangled_id, entangled_name) in entangled.get(node_id, []):
entangled_def = self.get_node_definition(entangled_id)
if entangled_def is None:
continue
input_types = {}
output_types = {}
for input_category, input_list in entangled_def["input"].items():
for input_name, input_info in input_list.items():
if isinstance(input_info, tuple) or input_category != "hidden":
input_types[input_name] = input_info[0]
for i in range(len(entangled_def["output"])):
output_name = entangled_def["output_name"][i]
output_types[output_name] = entangled_def["output"][i]

if entangled_name not in entangled_types:
entangled_types[entangled_name] = []
entangled_types[entangled_name].append({
"node_id": entangled_id,
"input_types": input_types,
"output_types": output_types
})

input_types, output_types = self.get_input_output_types(node_id)
dynamic_info = class_def.resolve_dynamic_types(
input_types=input_types,
output_types=output_types,
entangled_types=entangled_types
)
old_info = self.definitions[node_id].copy()
self.definitions[node_id].update(dynamic_info)
updated[node_id] = self.definitions[node_id]
# We changed the info, so we potentially need to resolve adjacent and entangled nodes
if old_info != self.definitions[node_id]:
for (entangled_node_id, _) in entangled.get(node_id, []):
if entangled_node_id in node_id_set:
to_resolve.add(entangled_node_id)
for i in range(len(self.definitions[node_id]["output"])):
for (output_node_id, _) in self.inputs_from_output_slot.get((node_id, i), []):
if output_node_id in node_id_set:
to_resolve.add(output_node_id)
for _, input_data in node["inputs"].items():
if is_link(input_data):
if input_data[0] in node_id_set:
to_resolve.add(input_data[0])
for (to_node_id, _) in self.inputs_from_output_node.get(node_id, []):
if to_node_id in node_id_set:
to_resolve.add(to_node_id)
# Because this run may have changed the number of outputs, we may need to run it again
# in order to get those outputs passed as output_types.
to_resolve.add(node_id)
return updated

def node_class_info(node_class):
if node_class not in nodes.NODE_CLASS_MAPPINGS:
return None
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False

if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY

if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS

if getattr(obj_class, "DEPRECATED", False):
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True

return info


def get_input_info(node_info, input_name):
valid_inputs = node_info["input"]
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
Expand Down Expand Up @@ -84,9 +269,7 @@ def __init__(self, dynprompt):
self.blocking = {} # Which nodes are blocked by this node

def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
return get_input_info(self.dynprompt.get_node_definition(unique_id), input_name)

def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
Expand Down Expand Up @@ -197,11 +380,8 @@ def ux_friendly_pick_node(self, node_list):
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
return False
node_def = self.dynprompt.get_node_definition(node_id)
return node_def['output_node']

for node_id in node_list:
if is_output(node_id):
Expand Down
Loading

0 comments on commit a063468

Please sign in to comment.