Skip to content

Commit

Permalink
More WIP on dynamic typing
Browse files Browse the repository at this point in the history
  • Loading branch information
guill committed Aug 24, 2024
1 parent b29521e commit 493023f
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 47 deletions.
35 changes: 31 additions & 4 deletions comfy_execution/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,34 @@ def all_node_ids(self):
def get_original_prompt(self):
return self.original_prompt

def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
def node_class_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False

if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY

if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
return info


def get_input_info(node_info, input_name):
valid_inputs = node_info["input"]
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
Expand Down Expand Up @@ -86,7 +112,7 @@ def __init__(self, dynprompt):
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
return get_input_info(node_class_info(class_def), input_name)

def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
Expand Down Expand Up @@ -119,7 +145,8 @@ def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
node_info = node_class_info(self.dynprompt.get_node(from_node_id)["class_type"])
input_type, input_category, input_info = self.get_input_info(node_info, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy:
self.add_strong_link(from_node_id, from_socket, unique_id)
Expand Down
106 changes: 92 additions & 14 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,109 @@
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional, Dict, Tuple

import torch
import nodes

import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker, node_class_info
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy.cli_args import args

def resolve_dynamic_types(prompt):
output = {}
for node_id in prompt:
node = prompt[node_id]
def get_input_output_types(dynprompt, node_id, info, output_lookup) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
node = dynprompt.get_node(node_id)
input_types: Dict[str, str] = {}
for input_name, input_data in node["inputs"].items():
if is_link(input_data):
from_node_id, from_socket = input_data
if from_socket < len(info[from_node_id]["output_name"]):
input_types[input_name] = info[from_node_id]["output"][from_socket]
else:
input_types[input_name] = "*"
output_types: Dict[str, List[str]] = {}
for index in range(len(info[node_id]["output_name"])):
output_name = info[node_id]["output_name"][index]
if (node_id, index) not in output_lookup:
continue
for (to_node_id, to_input_name) in output_lookup[(node_id, index)]:
if output_name not in output_types:
output_types[output_name] = []
if to_input_name in info[to_node_id]["input"]["required"]:
output_types[output_name].append(info[to_node_id]["input"]["required"][to_input_name][0])
elif to_input_name in info[to_node_id]["input"]["optional"]:
output_types[output_name].append(info[to_node_id]["input"]["optional"][to_input_name][0])
return input_types, output_types

def resolve_dynamic_types(dynprompt: DynamicPrompt):
info = {}
output_lookup = {}
entangled = {}
# Pre-fill with class info
for node_id in dynprompt.all_node_ids():
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
info[node_id] = node_class_info(class_type)
for input_name, input_data in node["inputs"].items():
if is_link(input_data):
input_tuple = tuple(input_data)
if input_tuple not in output_lookup:
output_lookup[input_tuple] = []
output_lookup[input_tuple].append((node_id, input_name))
_, _, extra_info = get_input_info(info[node_id], input_name)
if extra_info is not None and extra_info.get("entangleTypes", False):
from_node_id = input_data[0]
if node_id not in entangled:
entangled[node_id] = []
if from_node_id not in entangled:
entangled[from_node_id] = []

entangled[node_id].append((from_node_id, input_name))
entangled[from_node_id].append((node_id, input_name))

# Evaluate node info
to_resolve = set(info.keys())
result = {}
while len(to_resolve) > 0:
node_id = to_resolve.pop()
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, "resolve_dynamic_types"):
inputs, outputs, output_names = class_def.resolve_dynamic_types(node_id, prompt)
output[node_id] = {
"input": inputs,
"output": outputs,
"output_name": output_names,
}
return output
entangled_types = {}
for (entangled_id, entangled_name) in entangled.get(node_id, []):
input_types, output_types = get_input_output_types(dynprompt, entangled_id, info, output_lookup)
if entangled_name not in entangled_types:
entangled_types[entangled_name] = []
entangled_types[entangled_name].append({
"node_id": entangled_id,
"input_types": input_types,
"output_types": output_types
})

input_types, output_types = get_input_output_types(dynprompt, node_id, info, output_lookup)
dynamic_info = class_def.resolve_dynamic_types(
node_id=node_id,
input_types=input_types,
output_types=output_types,
entangled_types=entangled_types
)
old_info = info[node_id].copy()
info[node_id].update(dynamic_info)
# We changed the info, so we potentially need to resolve adjacent and entangled nodes
if old_info != info[node_id]:
result[node_id] = info[node_id]
for (entangled_node_id, _) in entangled.get(node_id, []):
to_resolve.add(entangled_node_id)
for i in range(len(info[node_id]["output"])):
for (output_node_id, _) in output_lookup.get((node_id, i), []):
to_resolve.add(output_node_id)
for _, input_data in node["inputs"].items():
if is_link(input_data):
to_resolve.add(input_data[0])
return result


class ExecutionResult(Enum):
Expand Down Expand Up @@ -557,7 +635,7 @@ def validate_inputs(prompt, item, validated):
received_types = {}

for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
type_input, input_category, extra_info = get_input_info(node_class_info(obj_class), x)
assert extra_info is not None
if x not in inputs:
if input_category == "required":
Expand Down
34 changes: 5 additions & 29 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from typing import Optional
from comfy_execution.graph import DynamicPrompt, node_class_info

class BinaryEventTypes:
PREVIEW_IMAGE = 1
Expand Down Expand Up @@ -419,37 +420,12 @@ async def get_queue(request):
async def get_prompt(request):
return web.json_response(self.get_queue_info())

def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False

if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY

if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
return info

@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
out[x] = node_class_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
Expand All @@ -460,7 +436,7 @@ async def get_object_info_node(request):
node_class = request.match_info.get("node_class", None)
out = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class)
out[node_class] = node_class_info(node_class)
return web.json_response(out)

@routes.get("/history")
Expand Down Expand Up @@ -489,8 +465,8 @@ async def resolve_dynamic_types(request):
if 'prompt' not in json_data:
return web.json_response({"error": "no prompt"}, status=400)
prompt = json_data['prompt']
logging.info("Resolving dynamic types", prompt)
resolved = execution.resolve_dynamic_types(prompt)
dynprompt = DynamicPrompt(prompt)
resolved = execution.resolve_dynamic_types(dynprompt)
return web.json_response(resolved)


Expand Down

0 comments on commit 493023f

Please sign in to comment.