From bc654d37e744d38185b5870bbda3dd0536e09dea Mon Sep 17 00:00:00 2001 From: Chaoses-Ib Date: Sun, 27 Oct 2024 23:17:52 +0800 Subject: [PATCH] feat(runtime/node): add `nodes` and `get` (#17, #30, #59) --- src/comfy_script/nodes/__init__.py | 2 +- src/comfy_script/runtime/__init__.py | 5 ++- src/comfy_script/runtime/factory.py | 4 ++ src/comfy_script/runtime/node.py | 48 +++++++++++++++++++++++ src/comfy_script/runtime/nodes.py | 10 ++++- src/comfy_script/runtime/real/__init__.py | 7 +++- src/comfy_script/runtime/real/node.py | 48 +++++++++++++++++++++++ src/comfy_script/runtime/real/nodes.py | 5 ++- 8 files changed, 122 insertions(+), 7 deletions(-) create mode 100644 src/comfy_script/runtime/node.py create mode 100644 src/comfy_script/runtime/real/node.py diff --git a/src/comfy_script/nodes/__init__.py b/src/comfy_script/nodes/__init__.py index 2951d55..4a3c17d 100644 --- a/src/comfy_script/nodes/__init__.py +++ b/src/comfy_script/nodes/__init__.py @@ -1,4 +1,4 @@ -'''These nodes are the nodes provided by ComfyScript, not the nodes loaded by ComfyUI, which should be in `runtime` package.''' +'''These nodes are the nodes provided by ComfyScript, not the nodes loaded by ComfyUI, which should be in `runtime` package.''' __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] diff --git a/src/comfy_script/runtime/__init__.py b/src/comfy_script/runtime/__init__.py index c40055e..181c743 100644 --- a/src/comfy_script/runtime/__init__.py +++ b/src/comfy_script/runtime/__init__.py @@ -60,7 +60,8 @@ async def _load(comfyui: str | Client | Path = None, args: ComfyUIArgs | None = nodes_info = await client._get_nodes_info() print(f'Nodes: {len(nodes_info)}') - await nodes.load(nodes_info, vars) + node.nodes.clear() + await nodes.load(nodes_info, vars, nodes=node.nodes) # TODO: Stop watch if watch turns to False if watch: @@ -1050,6 +1051,7 @@ def __exit__(self, exc_type, exc_value, traceback): from .. import client from ..client import Client +from . import node from . import nodes from . import data from .data import * @@ -1063,5 +1065,6 @@ def __exit__(self, exc_type, exc_value, traceback): 'queue', 'Task', 'Workflow', + 'node' ] __all__.extend(data.__all__) \ No newline at end of file diff --git a/src/comfy_script/runtime/factory.py b/src/comfy_script/runtime/factory.py index cf2759c..187b398 100644 --- a/src/comfy_script/runtime/factory.py +++ b/src/comfy_script/runtime/factory.py @@ -81,6 +81,8 @@ def __init__(self, *, hidden_inputs: bool = False, max_enum_values: int = 2000, - `import_fullname_types`: WIP. ''' + self.nodes: dict[str, Any] = {} + self._vars = { id: None for k, dic in self.GLOBAL_ENUMS.items() for id in dic.values() } self._data_type_stubs = {} self._enum_values = {} @@ -520,6 +522,8 @@ def {class_id}( for enum_id, enum in enums.items(): setattr(node, enum_id, enum) self._set_type(info['name'], class_id, node) + + self.nodes[info['name']] = node def vars(self) -> dict: return self._vars diff --git a/src/comfy_script/runtime/node.py b/src/comfy_script/runtime/node.py new file mode 100644 index 0000000..55b5c99 --- /dev/null +++ b/src/comfy_script/runtime/node.py @@ -0,0 +1,48 @@ +from typing import Any + +nodes: dict[str, Any] = {} +'''A dict of loaded nodes keyed by their raw names. Compared to `comfy_script.runtime.nodes` module, `nodes` is more suitable for programmatic access. + +Example: +``` +from comfy_script.runtime import * +load() + +print(node.nodes) +# {'KSampler': , +# 'CheckpointLoaderSimple': , +# 'CLIPTextEncode': , +# ... + +# or: node.get('CheckpointLoaderSimple') +loader = node.nodes['CheckpointLoaderSimple'] +model, clip, vae = loader('v1-5-pruned-emaonly.ckpt') +``` + +With type hint: +``` +from comfy_script.runtime.nodes import CheckpointLoaderSimple + +loader: type[CheckpointLoaderSimple] = node.nodes['CheckpointLoaderSimple'] +model, clip, vae = loader('v1-5-pruned-emaonly.ckpt') +``` + +With compile-time-only type hint: +``` +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from comfy_script.runtime.nodes import CheckpointLoaderSimple + +loader: 'type[CheckpointLoaderSimple]' = node.nodes['CheckpointLoaderSimple'] +model, clip, vae = loader('v1-5-pruned-emaonly.ckpt') +``` +''' + +def get(name: str) -> Any | None: + return nodes.get(name, None) + +__all__ = [ + 'nodes', + 'get' +] \ No newline at end of file diff --git a/src/comfy_script/runtime/nodes.py b/src/comfy_script/runtime/nodes.py index a1dc882..7daf3e3 100644 --- a/src/comfy_script/runtime/nodes.py +++ b/src/comfy_script/runtime/nodes.py @@ -11,7 +11,7 @@ class VirtualRuntimeFactory(factory.RuntimeFactory): def new_node(self, info: dict, defaults: dict, output_types: list[type]): return Node(info, defaults, output_types) -async def load(nodes_info: dict, vars: dict | None) -> None: +async def load(nodes_info: dict, vars: dict | None, *, nodes: dict[str, typing.Any] | None = None) -> None: fact = VirtualRuntimeFactory() await fact.init() @@ -21,7 +21,10 @@ async def load(nodes_info: dict, vars: dict | None) -> None: except Exception as e: print(f'ComfyScript: Failed to load node {node_info["name"]}') traceback.print_exc() - + + if nodes is not None: + nodes.update(fact.nodes) + globals().update(fact.vars()) __all__.extend(fact.vars().keys()) @@ -94,6 +97,9 @@ def __call__(self, *args, **kwds): r = [r] return r + + def __repr__(self): + return f'' @classmethod def set_output_hook(cls, hook: typing.Callable[[data.NodeOutput | list[data.NodeOutput]], None]): diff --git a/src/comfy_script/runtime/real/__init__.py b/src/comfy_script/runtime/real/__init__.py index 45634ec..a357214 100644 --- a/src/comfy_script/runtime/real/__init__.py +++ b/src/comfy_script/runtime/real/__init__.py @@ -40,7 +40,8 @@ def load(comfyui: Path | str = None, args: ComfyUIArgs | None = None, vars: dict config = RealModeConfig.naked() elif config is None: config = RealModeConfig() - asyncio.run(nodes.load(nodes_info, vars, config)) + node.nodes.clear() + asyncio.run(nodes.load(nodes_info, vars, config, nodes=node.nodes)) class Workflow: # TODO: Thread-safe @@ -169,11 +170,13 @@ def naked() -> RealModeConfig: from ... import client from .. import ComfyUIArgs, start_comfyui +from . import node from . import nodes __all__ = [ 'load', 'ComfyUIArgs', 'RealModeConfig', - 'Workflow' + 'Workflow', + 'node' ] \ No newline at end of file diff --git a/src/comfy_script/runtime/real/node.py b/src/comfy_script/runtime/real/node.py new file mode 100644 index 0000000..30da3d7 --- /dev/null +++ b/src/comfy_script/runtime/real/node.py @@ -0,0 +1,48 @@ +from typing import Any + +nodes: dict[str, Any] = {} +'''A dict of loaded nodes keyed by their raw names. Compared to `comfy_script.runtime.real.nodes` module, `nodes` is more suitable for programmatic access. + +Example: +``` +from comfy_script.runtime.real import * +load() + +print(node.nodes) +# {'KSampler': , +# 'CheckpointLoaderSimple': , +# 'CLIPTextEncode': , +# ... + +# or: node.get('CheckpointLoaderSimple') +loader = node.nodes['CheckpointLoaderSimple'] +model, clip, vae = loader('v1-5-pruned-emaonly.ckpt') +``` + +With type hint: +``` +from comfy_script.runtime.real.nodes import CheckpointLoaderSimple + +loader: type[CheckpointLoaderSimple] = node.nodes['CheckpointLoaderSimple'] +model, clip, vae = loader('v1-5-pruned-emaonly.ckpt') +``` + +With compile-time-only type hint: +``` +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from comfy_script.runtime.real.nodes import CheckpointLoaderSimple + +loader: 'type[CheckpointLoaderSimple]' = node.nodes['CheckpointLoaderSimple'] +model, clip, vae = loader('v1-5-pruned-emaonly.ckpt') +``` +''' + +def get(name: str) -> Any | None: + return nodes.get(name, None) + +__all__ = [ + 'nodes', + 'get' +] \ No newline at end of file diff --git a/src/comfy_script/runtime/real/nodes.py b/src/comfy_script/runtime/real/nodes.py index e4b80fc..a59375e 100644 --- a/src/comfy_script/runtime/real/nodes.py +++ b/src/comfy_script/runtime/real/nodes.py @@ -10,7 +10,7 @@ from .. import factory from ..nodes import _positional_args_to_keyword, Node as VirtualNode -async def load(nodes_info: dict, vars: dict | None, config: real.RealModeConfig) -> None: +async def load(nodes_info: dict, vars: dict | None, config: real.RealModeConfig, *, nodes: dict[str, typing.Any] | None = None) -> None: fact = RealRuntimeFactory(config) await fact.init() @@ -21,6 +21,9 @@ async def load(nodes_info: dict, vars: dict | None, config: real.RealModeConfig) print(f'ComfyScript: Failed to load node {node_info["name"]}') traceback.print_exc() + if nodes is not None: + nodes.update(fact.nodes) + globals().update(fact.vars()) __all__.extend(fact.vars().keys())