Skip to content

Commit

Permalink
Dynamic typing is working!
Browse files Browse the repository at this point in the history
  • Loading branch information
guill committed Aug 29, 2024
1 parent 95c8d14 commit 02b13c7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
15 changes: 14 additions & 1 deletion comfy_execution/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def get_node_definition(self, node_id):
self.definitions[node_id] = definition
return self.definitions[node_id]

def get_constant_type(self, value):
    """Map a literal (non-link) node input value to its socket type string.

    Returns:
        "BOOL" for booleans, "INT,FLOAT" for other numeric literals,
        "STRING" for strings, and None for values that have no constant
        socket type (lists, dicts, None, etc.) — callers skip None entries.
    """
    # NOTE: bool must be tested BEFORE (int, float). In Python, bool is a
    # subclass of int, so isinstance(True, (int, float)) is True and the
    # BOOL branch would otherwise be unreachable, misclassifying True/False
    # as "INT,FLOAT".
    if isinstance(value, bool):
        return "BOOL"
    if isinstance(value, (int, float)):
        return "INT,FLOAT"
    if isinstance(value, str):
        return "STRING"
    return None

def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
node = self.dynprompt.get_node(node_id)
input_types: Dict[str, str] = {}
Expand All @@ -86,6 +96,10 @@ def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, Lis
input_types[input_name] = self.definitions[from_node_id]["output"][from_socket]
else:
input_types[input_name] = "*"
else:
constant_type = self.get_constant_type(input_data)
if constant_type is not None:
input_types[input_name] = constant_type
output_types: Dict[str, List[str]] = {}
for index in range(len(self.definitions[node_id]["output_name"])):
output_name = self.definitions[node_id]["output_name"][index]
Expand All @@ -103,7 +117,6 @@ def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, Lis
return input_types, output_types

def resolve_dynamic_definitions(self, node_id_set: Set[str]):
print("Resolving dynamic definitions", node_id_set)
entangled = {}
# Pre-fill with class info. Also, build a lookup table for output nodes
for node_id in node_id_set:
Expand Down
5 changes: 5 additions & 0 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def mark_missing():
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]
if h[x] == "NODE_DEFINITION":
input_data_all[x] = [node_def]
return input_data_all, missing_keys

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
Expand Down Expand Up @@ -169,6 +171,8 @@ def merge_result_data(results, node_def):
# check which outputs need concatenating
output = []
output_is_list = node_def['output_is_list']
if len(output_is_list) < len(results[0]):
output_is_list = output_is_list + [False] * (len(results[0]) - len(output_is_list))

# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
Expand Down Expand Up @@ -461,6 +465,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):

with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
dynamic_prompt.node_definitions.resolve_dynamic_definitions(set(dynamic_prompt.all_node_ids()))
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
Expand Down

0 comments on commit 02b13c7

Please sign in to comment.