Skip to content

Commit

Permalink
fix(runtime/virtual): node output may be None
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaoses-Ib committed Oct 26, 2024
1 parent b617fa9 commit 032bb26
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 42 deletions.
79 changes: 41 additions & 38 deletions src/comfy_script/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,9 @@ def __init__(self):
self._watch_thread = None
self._queue_empty_callback = None
self._queue_remaining_callbacks = [self._when_empty_callback]
self._watch_display_node = None
self._watch_display_node_preview = None
self._watch_display_task = None
self._watch_display_node = False
self._watch_display_node_preview = False
self._watch_display_task = False
self.queue_remaining = 0

async def _get_history(self, prompt_id: str) -> dict | None:
Expand Down Expand Up @@ -516,7 +516,7 @@ async def _watch(self):
if self.queue_remaining == 0:
for task in self._tasks.values():
print(f'ComfyScript: The queue is empty but {task} has not been executed')
await task._set_result_threadsafe(None, {})
await task._set_results_threadsafe({})
self._tasks.clear()

for callback in self._queue_remaining_callbacks:
Expand All @@ -527,7 +527,7 @@ async def _watch(self):
outputs = {}
if history is not None:
outputs = history['outputs']
await task._set_result_threadsafe(None, outputs, self._watch_display_task)
await task._set_results_threadsafe(outputs, self._watch_display_task)
if self._watch_display_task:
print(f'Queue remaining: {self.queue_remaining}')
elif msg['type'] == 'executed':
Expand Down Expand Up @@ -781,7 +781,7 @@ def __init__(self, prompt_id: str, number: int, id: data.IdManager):
self.prompt_id = prompt_id
self.number = number
self._id = id
self._new_outputs = {}
self._new_outputs: dict[str, dict | None] = {}
self._fut = asyncio.Future()
self._node_preview_callbacks: list[Callable[[Task, str, Image.Image]]] = []

Expand All @@ -800,38 +800,41 @@ def _set_node_preview(self, node_id: str, preview: Image.Image, display: bool):

display(preview, clear=True)

async def _set_result_threadsafe(self, node_id: str | None, output: dict, display_result: bool = False) -> None:
if node_id is not None:
self._new_outputs[node_id] = output
if display_result:
from IPython.display import display
async def _set_result_threadsafe(self, node_id: str, output: dict | None, display_result: bool = False) -> None:
self._new_outputs[node_id] = output
if display_result:
from IPython.display import display

display(clear=True)
display(clear=True)
result = data.Result.from_output(output)
if isinstance(result, data.ImageBatchResult):
await Images(result)._display()
else:
display(result)

async def _set_results_threadsafe(self, outputs: dict[str, dict | None], display_result: bool = False) -> None:
# ComfyUI will skip node outputs None in outputs
outputs = self._new_outputs | outputs

self.get_loop().call_soon_threadsafe(self._fut.set_result, outputs)
if display_result:
from IPython.display import display

image_batches = []
others = []
# TODO: Sort by the parsed id
for _id, output in sorted(outputs.items(), key=lambda item: item[0]):
result = data.Result.from_output(output)
if isinstance(result, data.ImageBatchResult):
await Images(result)._display()
image_batches.append(result)
else:
display(result)
else:
self.get_loop().call_soon_threadsafe(self._fut.set_result, output)
if display_result:
from IPython.display import display

image_batches = []
others = []
# TODO: Sort by the parsed id
for _id, output in sorted(output.items(), key=lambda item: item[0]):
result = data.Result.from_output(output)
if isinstance(result, data.ImageBatchResult):
image_batches.append(result)
else:
others.append(result)
if image_batches or others:
display(clear=True)
if image_batches:
await Images(*image_batches)._display()
if others:
display(*others)
others.append(result)
if image_batches or others:
display(clear=True)
if image_batches:
await Images(*image_batches)._display()
if others:
display(*others)

async def _wait(self) -> list[data.Result]:
'''`Task` can be directly awaited like `await task`. This method is for internal use only.'''
Expand All @@ -850,13 +853,13 @@ async def result(self, output: data.NodeOutput) -> data.Result | None:
if id is None:
return None

output = self._new_outputs.get(id)
if output is not None:
if id in self._new_outputs:
output: dict | None = self._new_outputs[id]
return data.Result.from_output(output)

outputs: dict = await self._fut
output = outputs.get(id)
if output is not None:
if id in outputs:
output: dict | None = outputs[id]
return data.Result.from_output(output)
return None

Expand Down
25 changes: 21 additions & 4 deletions src/comfy_script/runtime/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _get_outputs_prompt_and_id(outputs: Iterable[NodeOutput]) -> (dict, IdManage
return prompt, id

class Result:
def __init__(self, output: dict):
def __init__(self, output: dict | None):
self._output = output

def __repr__(self) -> str:
Expand All @@ -134,11 +134,28 @@ def __str__(self) -> str:
return f'{self.__class__.__name__}({self._output.__str__()})'

@classmethod
def from_output(cls, output: dict) -> Result:
if 'images' in output:
return ImageBatchResult(output)
def from_output(cls, output: dict | None) -> Result:
if isinstance(output, dict):
if 'images' in output:
return ImageBatchResult(output)
elif output is None:
return EmptyResult(output)
return Result(output)

class EmptyResult(Result):
'''
An empty result from an output node that outputs nothing.
Example:
```
# Derfuu_Nodes
StringDebugPrint('123', '456').wait()
```
'''

def _ipython_display_(self):
pass

from .Images import ImageBatchResult, Images

__all__ = [
Expand Down

0 comments on commit 032bb26

Please sign in to comment.