From 02b13c7be3c601c5c2ad65213d297b91db6ad4bd Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 28 Aug 2024 22:58:43 -0700 Subject: [PATCH] Dynamic typing is working! --- comfy_execution/graph.py | 15 ++++++++++++++- execution.py | 5 +++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 2335fe164ec2..40eaf00dd6e9 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -76,6 +76,16 @@ def get_node_definition(self, node_id): self.definitions[node_id] = definition return self.definitions[node_id] + def get_constant_type(self, value): + if isinstance(value, bool): + return "BOOL" + elif isinstance(value, (int, float)): + return "INT,FLOAT" + elif isinstance(value, str): + return "STRING" + else: + return None + def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, List[str]]]: node = self.dynprompt.get_node(node_id) input_types: Dict[str, str] = {} @@ -86,6 +96,10 @@ def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, Lis input_types[input_name] = self.definitions[from_node_id]["output"][from_socket] else: input_types[input_name] = "*" + else: + constant_type = self.get_constant_type(input_data) + if constant_type is not None: + input_types[input_name] = constant_type output_types: Dict[str, List[str]] = {} for index in range(len(self.definitions[node_id]["output_name"])): output_name = self.definitions[node_id]["output_name"][index] @@ -103,7 +117,6 @@ def get_input_output_types(self, node_id) -> Tuple[Dict[str, str], Dict[str, Lis return input_types, output_types def resolve_dynamic_definitions(self, node_id_set: Set[str]): - print("Resolving dynamic definitions", node_id_set) entangled = {} # Pre-fill with class info. 
Also, build a lookup table for output nodes for node_id in node_id_set: diff --git a/execution.py b/execution.py index 59e563c90cdb..b7f0bfb0a716 100644 --- a/execution.py +++ b/execution.py @@ -124,6 +124,8 @@ def mark_missing(): input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + if h[x] == "NODE_DEFINITION": + input_data_all[x] = [node_def] return input_data_all, missing_keys def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): @@ -169,6 +171,8 @@ def merge_result_data(results, node_def): # check which outputs need concatenating output = [] output_is_list = node_def['output_is_list'] + if len(output_is_list) < len(results[0]): + output_is_list = output_is_list + [False] * (len(results[0]) - len(output_is_list)) # merge node execution results for i, is_list in zip(range(len(results[0])), output_is_list): @@ -461,6 +465,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) + dynamic_prompt.node_definitions.resolve_dynamic_definitions(set(dynamic_prompt.all_node_ids())) is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) for cache in self.caches.all: cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)