Skip to content

Commit

Permalink
Allow kwargs in VALIDATE_INPUTS functions
Browse files Browse the repository at this point in the history
When kwargs are used, validation is skipped for all inputs as if they
had been mentioned explicitly.
  • Loading branch information
guill committed Aug 8, 2024
1 parent 655548d commit 36131f0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
11 changes: 7 additions & 4 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,11 @@ def validate_inputs(prompt, item, validated):
valid = True

validate_function_inputs = []
validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
validate_function_inputs = argspec.args
validate_has_kwargs = argspec.varkw is not None
received_types = {}

for x in valid_inputs:
Expand Down Expand Up @@ -641,7 +644,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue

if x not in validate_function_inputs:
if x not in validate_function_inputs and not validate_has_kwargs:
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
Expand Down Expand Up @@ -695,11 +698,11 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue

if len(validate_function_inputs) > 0:
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
if x in validate_function_inputs or validate_has_kwargs:
input_filtered[x] = input_data_all[x]
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
Expand Down
16 changes: 16 additions & 0 deletions tests/inference/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,22 @@ def test_validation_error_edge4(self, test_type, test_value, expect_error, clien
else:
client.run(g)

@pytest.mark.parametrize("test_value1, test_value2, expect_error", [
(0.0, 0.5, False),
(0.0, 5.0, False),
(0.0, 7.0, True)
])
def test_validation_error_kwargs(self, test_value1, test_value2, expect_error, client: ComfyClient, builder: GraphBuilder):
g = builder
validation5 = g.node("TestCustomValidation5", input1=test_value1, input2=test_value2)
g.node("SaveImage", images=validation5.out(0))

if expect_error:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
else:
client.run(g)

def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
Expand Down
27 changes: 27 additions & 0 deletions tests/inference/testing_nodes/testing-pack/specific_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,31 @@ def VALIDATE_INPUTS(cls, input1, input2):

return True

class TestCustomValidation5:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"input1": ("FLOAT", {"min": 0.0, "max": 1.0}),
"input2": ("FLOAT", {"min": 0.0, "max": 1.0}),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_validation5"

CATEGORY = "Testing/Nodes"

def custom_validation5(self, input1, input2):
value = input1 * input2
return (torch.ones([1, 512, 512, 3]) * value,)

@classmethod
def VALIDATE_INPUTS(cls, **kwargs):
if kwargs['input2'] == 7.0:
return "7s are not allowed. I've never liked 7s."
return True

class TestDynamicDependencyCycle:
@classmethod
def INPUT_TYPES(cls):
Expand Down Expand Up @@ -291,6 +316,7 @@ def mixed_expansion_returns(self, input1):
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
"TestCustomValidation4": TestCustomValidation4,
"TestCustomValidation5": TestCustomValidation5,
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
"TestMixedExpansionReturns": TestMixedExpansionReturns,
}
Expand All @@ -303,6 +329,7 @@ def mixed_expansion_returns(self, input1):
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",
"TestCustomValidation4": "Custom Validation 4",
"TestCustomValidation5": "Custom Validation 5",
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
"TestMixedExpansionReturns": "Mixed Expansion Returns",
}

0 comments on commit 36131f0

Please sign in to comment.