Skip to content

Commit

Permalink
Merge pull request #275 from aurelio-labs/fix_is_valid_inputs
Browse files Browse the repository at this point in the history
fix: is_valid_inputs in base.py
  • Loading branch information
jamescalam authored May 15, 2024
2 parents b781441 + 1e51c35 commit 2dd21e3
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 11 deletions.
58 changes: 49 additions & 9 deletions semantic_router/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@ def __init__(self, name: str, **kwargs):
def __call__(self, messages: List[Message]) -> Optional[str]:
raise NotImplementedError("Subclasses must implement this method")

def _check_for_mandatory_inputs(
self, inputs: dict[str, Any], mandatory_params: List[str]
) -> bool:
"""Check for mandatory parameters in inputs"""
for name in mandatory_params:
if name not in inputs:
logger.error(f"Mandatory input {name} missing from query")
return False
return True

def _check_for_extra_inputs(
self, inputs: dict[str, Any], all_params: List[str]
) -> bool:
"""Check for extra parameters not defined in the signature"""
input_keys = set(inputs.keys())
param_keys = set(all_params)
if not input_keys.issubset(param_keys):
extra_keys = input_keys - param_keys
logger.error(
f"Extra inputs provided that are not in the signature: {extra_keys}"
)
return False
return True

def _is_valid_inputs(
self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
) -> bool:
Expand Down Expand Up @@ -48,17 +72,33 @@ def _validate_single_function_inputs(
) -> bool:
"""Validate the extracted inputs against the function schema"""
try:
# Extract parameter names and types from the signature string
# Extract parameter names and determine if they are optional
signature = function_schema["signature"]
param_info = [param.strip() for param in signature[1:-1].split(",")]
param_names = [info.split(":")[0].strip() for info in param_info]
param_types = [
info.split(":")[1].strip().split("=")[0].strip() for info in param_info
]
for name, type_str in zip(param_names, param_types):
if name not in inputs:
logger.error(f"Input {name} missing from query")
return False
mandatory_params = []
all_params = []

for info in param_info:
parts = info.split("=")
name_type_pair = parts[0].strip()
if ":" in name_type_pair:
name, _ = name_type_pair.split(":")
else:
name = name_type_pair
all_params.append(name)

# If there is no default value, it's a mandatory parameter
if len(parts) == 1:
mandatory_params.append(name)

# Check for mandatory parameters
if not self._check_for_mandatory_inputs(inputs, mandatory_params):
return False

# Check for extra parameters not defined in the signature
if not self._check_for_extra_inputs(inputs, all_params):
return False

return True
except Exception as e:
logger.error(f"Single input validation error: {str(e)}")
Expand Down
76 changes: 74 additions & 2 deletions tests/unit/llms/test_llm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,29 @@


class TestBaseLLM:

@pytest.fixture
def base_llm(self):
return BaseLLM(name="TestLLM")

@pytest.fixture
def mixed_function_schema(self):
return [
{
"name": "test_function",
"description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
}
]

@pytest.fixture
def mandatory_params(self):
return ["param1", "param2"]

@pytest.fixture
def all_params(self):
return ["param1", "param2", "optional1"]

def test_base_llm_initialization(self, base_llm):
assert base_llm.name == "TestLLM", "Initialization of name failed"

Expand Down Expand Up @@ -72,6 +91,59 @@ def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker):
test_query = "What time is it in America/New_York?"
base_llm.extract_function_inputs(test_schema, test_query)

def test_mandatory_args_only(self, base_llm, mixed_function_schema):
inputs = [{"mandatory1": "value1", "mandatory2": 42}]
assert base_llm._is_valid_inputs(
inputs, mixed_function_schema
) # True is implied

def test_all_args_provided(self, base_llm, mixed_function_schema):
inputs = [
{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
}
]
assert base_llm._is_valid_inputs(
inputs, mixed_function_schema
) # True is implied

def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
inputs = [{"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}]
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)

def test_extra_arg_provided(self, base_llm, mixed_function_schema):
inputs = [
{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
"extra": "value",
}
]
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)

def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
inputs = {"param1": "value1", "param2": "value2"}
assert base_llm._check_for_mandatory_inputs(
inputs, mandatory_params
) # True is implied

def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params):
inputs = {"param1": "value1"}
assert not base_llm._check_for_mandatory_inputs(inputs, mandatory_params)

def test_check_for_extra_inputs_no_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2"}
assert base_llm._check_for_extra_inputs(inputs, all_params) # True is implied

def test_check_for_extra_inputs_with_extras(self, base_llm, all_params):
inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"}
assert not base_llm._check_for_extra_inputs(inputs, all_params)

def test_is_valid_inputs_multiple_inputs(self, base_llm, mocker):
# Mock the logger to capture the error messages
mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")
Expand Down Expand Up @@ -139,7 +211,7 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock
malformed_function_schema = {
"name": "get_time",
"description": "Finds the current time in a specific timezone.",
"signature": "(timezone str)", # Malformed signature missing colon
"signiture": "(timezone: str)", # Malformed key name
"output": "<class 'str'>",
}

Expand All @@ -152,7 +224,7 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock
assert not result, "Method should return False when an exception occurs"

# Check that the appropriate error message was logged
expected_error_message = "Single input validation error: list index out of range" # Adjust based on the actual exception message
expected_error_message = "Single input validation error: 'signature'" # Adjust based on the actual exception message
mocked_logger.assert_called_once_with(expected_error_message)

def test_extract_parameter_info_valid(self, base_llm):
Expand Down

0 comments on commit 2dd21e3

Please sign in to comment.