Skip to content

Commit

Permalink
Fix ui output for duplicated nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
guill committed Jun 17, 2024
1 parent afa4c7b commit 8d17f3c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 67 deletions.
44 changes: 13 additions & 31 deletions comfy/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,6 @@ def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_map
order_mapping[ancestor_id] = len(ancestors) - 1
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)

def include_node_id_in_input(self):
return True

class BasicCache:
def __init__(self, key_class):
self.key_class = key_class
Expand All @@ -151,24 +144,30 @@ def all_node_ids(self):
node_ids = node_ids.union(subcache.all_node_ids())
return node_ids

def clean_unused(self):
assert self.initialized
def _clean_cache(self):
preserve_keys = set(self.cache_key_set.get_used_keys())
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
to_remove = []
for key in self.cache:
if key not in preserve_keys:
to_remove.append(key)
for key in to_remove:
del self.cache[key]

def _clean_subcaches(self):
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())

to_remove = []
for key in self.subcaches:
if key not in preserve_subcaches:
to_remove.append(key)
for key in to_remove:
del self.subcaches[key]

def clean_unused(self):
    """Evict cache entries and subcaches not used by the current execution.

    Must only be called once the cache has been initialized with a key set.
    """
    assert self.initialized
    self._clean_cache()
    self._clean_subcaches()

def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
Expand Down Expand Up @@ -246,15 +245,6 @@ def ensure_subcache_for(self, node_id, children_ids):
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)

def all_active_values(self):
active_nodes = self.all_node_ids()
result = []
for node_id in active_nodes:
value = self.get(node_id)
if value is not None:
result.append(value)
return result

class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
Expand All @@ -279,6 +269,7 @@ def clean_unused(self):
del self.used_generation[key]
if key in self.children:
del self.children[key]
self._clean_subcaches()

def get(self, node_id):
self._mark_used(node_id)
Expand All @@ -294,6 +285,9 @@ def set(self, node_id, value):
return self._set_immediate(node_id, value)

def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
super()._ensure_subcache(node_id, children_ids)

self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id)
Expand All @@ -303,15 +297,3 @@ def ensure_subcache_for(self, node_id, children_ids):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self

def all_active_values(self):
explored = set()
to_explore = set(self.cache_key_set.get_used_keys())
while len(to_explore) > 0:
cache_key = to_explore.pop()
if cache_key not in explored:
self.used_generation[cache_key] = self.generation
explored.add(cache_key)
if cache_key in self.children:
to_explore.update(self.children[cache_key])
return [self.cache[key] for key in explored if key in self.cache]

16 changes: 9 additions & 7 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import comfy.graph_utils
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy.graph_utils import is_link, GraphBuilder
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args

class ExecutionResult(Enum):
Expand Down Expand Up @@ -69,13 +69,13 @@ def __init__(self, lru_size=None):
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)

# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)

def recursive_debug_dump(self):
Expand Down Expand Up @@ -486,10 +486,12 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):

ui_outputs = {}
meta_outputs = {}
for ui_info in self.caches.ui.all_active_values():
node_id = ui_info["meta"]["node_id"]
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
Expand Down
66 changes: 37 additions & 29 deletions tests/inference/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,26 @@ class TestExecution:
#
# Initialize server and client
#
@fixture(scope="class", autouse=True)
def _server(self, args_pytest):
@fixture(scope="class", autouse=True, params=[
    # (use_lru, lru_size) — exercise the classic cache, an unbounded LRU,
    # and a bounded LRU so every caching backend runs the same test suite.
    (False, 0),
    (True, 0),
    (True, 100),
])
def _server(self, args_pytest, request):
    """Start one server process per cache configuration for the test class."""
    # Start server (exactly once — build the arg list, then spawn).
    pargs = [
        'python','main.py',
        '--output-directory', args_pytest["output_dir"],
        '--listen', args_pytest["listen"],
        '--port', str(args_pytest["port"]),
        '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
    ]
    use_lru, lru_size = request.param
    if use_lru:
        pargs += ['--cache-lru', str(lru_size)]
    print("Running server with args:", pargs)
    p = subprocess.Popen(pargs)
    yield
    # Teardown: kill the server and release any GPU memory the test leaked.
    p.kill()
    torch.cuda.empty_cache()
Expand Down Expand Up @@ -159,15 +169,9 @@ def client(self, shared_client, request):
shared_client.set_test_name(f"execution[{request.node.name}]")
yield shared_client

def clear_cache(self, client: ComfyClient):
g = GraphBuilder(prefix="foo")
random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1)
g.node("PreviewImage", images=random.out(0))
client.run(g)

@fixture
def builder(self, request):
    """Per-test GraphBuilder whose node-id prefix is the test's own name.

    Using request.node.name keeps node ids unique across tests so cached
    results from one test can never collide with another test's nodes.
    """
    yield GraphBuilder(prefix=request.node.name)

def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
g = builder
Expand All @@ -187,7 +191,6 @@ def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
assert result.did_run(lazy_mix)

def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
self.clear_cache(client)
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
Expand All @@ -196,14 +199,12 @@ def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
g.node("SaveImage", images=lazy_mix.out(0))

result1 = client.run(g)
client.run(g)
result2 = client.run(g)
for node_id, node in g.nodes.items():
assert result1.did_run(node), f"Node {node_id} didn't run"
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"

def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
self.clear_cache(client)
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
Expand All @@ -212,15 +213,11 @@ def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
g.node("SaveImage", images=lazy_mix.out(0))

result1 = client.run(g)
client.run(g)
mask.inputs['value'] = 0.4
result2 = client.run(g)
for node_id, node in g.nodes.items():
assert result1.did_run(node), f"Node {node_id} didn't run"
assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached"
assert result2.did_run(mask), "Mask should have been re-run"
assert result2.did_run(lazy_mix), "Lazy mix should have been re-run"

def test_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
Expand Down Expand Up @@ -365,7 +362,6 @@ def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
assert result4.did_run(is_changed), "is_changed should not have been cached"

def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
self.clear_cache(client)
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
Expand All @@ -378,8 +374,6 @@ def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
result_image = result.get_images(output)[0]
expected = 255 // 4
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
assert result.did_run(input1)
assert result.did_run(input2)

def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
g = builder
Expand Down Expand Up @@ -418,3 +412,17 @@ def test_mixed_expansion_returns(self, client: ComfyClient, builder: GraphBuilde
assert len(images_literal) == 3, "Should have 2 images"
for i in range(3):
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"

def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
    """Two preview nodes fed from one source must each report their own UI output."""
    g = builder
    source = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

    preview_a = g.node("PreviewImage", images=source.out(0))
    preview_b = g.node("PreviewImage", images=source.out(0))

    result = client.run(g)
    images_a = result.get_images(preview_a)
    images_b = result.get_images(preview_b)
    assert len(images_a) == 1, "Should have 1 image"
    assert len(images_b) == 1, "Should have 1 image"

0 comments on commit 8d17f3c

Please sign in to comment.