From e01e21b5ea753c398652a61a18f9ab856a77a6a4 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 24 Aug 2024 07:10:13 -0700 Subject: [PATCH] [Bug #4529] Fix graph partial validation failure Currently, if a graph partially fails validation (i.e. some outputs are valid while others have links from missing nodes), the execution loop could get an exception resulting in server lockup. This isn't actually possible to reproduce via the default UI, but is a potential issue for people using the API to construct invalid graphs. --- comfy_execution/caching.py | 9 +++++++++ tests/inference/test_execution.py | 23 +++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 6664a34287b..e67914a3fab 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -56,6 +56,8 @@ def add_keys(self, node_ids): for node_id in node_ids: if node_id in self.keys: continue + if not self.dynprompt.has_node(node_id): + continue node = self.dynprompt.get_node(node_id) self.keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"]) @@ -74,6 +76,8 @@ def add_keys(self, node_ids): for node_id in node_ids: if node_id in self.keys: continue + if not self.dynprompt.has_node(node_id): + continue node = self.dynprompt.get_node(node_id) self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) self.subcache_keys[node_id] = (node_id, node["class_type"]) @@ -87,6 +91,9 @@ def get_node_signature(self, dynprompt, node_id): return to_hashable(signature) def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + if not dynprompt.has_node(node_id): + # This node doesn't exist -- we can't cache it. + return [float("NaN")] node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -112,6 +119,8 @@ def get_ordered_ancestry(self, dynprompt, node_id): return ancestors, order_mapping def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + if not dynprompt.has_node(node_id): + return inputs = dynprompt.get_node(node_id)["inputs"] input_keys = sorted(inputs.keys()) for key in input_keys: diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 7965165fc9c..ffc0c482aed 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -357,6 +357,25 @@ def test_dynamic_cycle_error(self, client: ComfyClient, builder: GraphBuilder): assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" + def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0)) + # We have multiple outputs. The first is invalid, but the second is valid + g.node("SaveImage", images=mix1.out(0)) + g.node("SaveImage", images=mix2.out(0)) + g.remove_node("removeme") + + client.run(g) + + # Add back in the missing node to make sure the error doesn't break the server + input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) + client.run(g) + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug @@ -450,8 +469,8 @@ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) - output1 = g.node("PreviewImage", images=input1.out(0)) - output2 = g.node("PreviewImage", images=input1.out(0)) + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input1.out(0)) result = client.run(g) images1 = result.get_images(output1)