diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 5ee56213..bbd39b4e 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -19,6 +19,30 @@ def __init__(self, name: str, **kwargs): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") + def _check_for_mandatory_inputs( + self, inputs: dict[str, Any], mandatory_params: List[str] + ) -> bool: + """Check for mandatory parameters in inputs""" + for name in mandatory_params: + if name not in inputs: + logger.error(f"Mandatory input {name} missing from query") + return False + return True + + def _check_for_extra_inputs( + self, inputs: dict[str, Any], all_params: List[str] + ) -> bool: + """Check for extra parameters not defined in the signature""" + input_keys = set(inputs.keys()) + param_keys = set(all_params) + if not input_keys.issubset(param_keys): + extra_keys = input_keys - param_keys + logger.error( + f"Extra inputs provided that are not in the signature: {extra_keys}" + ) + return False + return True + def _is_valid_inputs( self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]] ) -> bool: @@ -48,17 +72,33 @@ def _validate_single_function_inputs( ) -> bool: """Validate the extracted inputs against the function schema""" try: - # Extract parameter names and types from the signature string + # Extract parameter names and determine if they are optional signature = function_schema["signature"] param_info = [param.strip() for param in signature[1:-1].split(",")] - param_names = [info.split(":")[0].strip() for info in param_info] - param_types = [ - info.split(":")[1].strip().split("=")[0].strip() for info in param_info - ] - for name, type_str in zip(param_names, param_types): - if name not in inputs: - logger.error(f"Input {name} missing from query") - return False + mandatory_params = [] + all_params = [] + + for info in param_info: + parts = info.split("=") + name_type_pair = parts[0].strip() + if ":" in name_type_pair: + name, _ = name_type_pair.split(":") + else: + name = name_type_pair + all_params.append(name) + + # If there is no default value, it's a mandatory parameter + if len(parts) == 1: + mandatory_params.append(name) + + # Check for mandatory parameters + if not self._check_for_mandatory_inputs(inputs, mandatory_params): + return False + + # Check for extra parameters not defined in the signature + if not self._check_for_extra_inputs(inputs, all_params): + return False + return True except Exception as e: logger.error(f"Single input validation error: {str(e)}") @@ -124,7 +164,7 @@ def extract_function_inputs( === EXAMPLE_OUTPUT End === ### EXAMPLE End ### -Note: I will tip $500 for and accurate JSON output. You will be penalized for an inaccurate JSON output. +Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output. Provide JSON output now: """ diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 680b5d2d..3699ded0 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -4,10 +4,29 @@ class TestBaseLLM: + @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") + @pytest.fixture + def mixed_function_schema(self): + return [ + { + "name": "test_function", + "description": "A test function with mixed mandatory and optional parameters.", + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", + } + ] + + @pytest.fixture + def mandatory_params(self): + return ["param1", "param2"] + + @pytest.fixture + def all_params(self): + return ["param1", "param2", "optional1"] + def test_base_llm_initialization(self, base_llm): assert base_llm.name == "TestLLM", "Initialization of name failed" @@ -72,6 +91,59 @@ def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker): test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) + def test_mandatory_args_only(self, base_llm, mixed_function_schema): + inputs = [{"mandatory1": "value1", "mandatory2": 42}] + assert base_llm._is_valid_inputs( + inputs, mixed_function_schema + ) # True is implied + + def test_all_args_provided(self, base_llm, mixed_function_schema): + inputs = [ + { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + } + ] + assert base_llm._is_valid_inputs( + inputs, mixed_function_schema + ) # True is implied + + def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): + inputs = [{"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}] + assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) + + def test_extra_arg_provided(self, base_llm, mixed_function_schema): + inputs = [ + { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + "extra": "value", + } + ] + assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) + + def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): + inputs = {"param1": "value1", "param2": "value2"} + assert base_llm._check_for_mandatory_inputs( + inputs, mandatory_params + ) # True is implied + + def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params): + inputs = {"param1": "value1"} + assert not base_llm._check_for_mandatory_inputs(inputs, mandatory_params) + + def test_check_for_extra_inputs_no_extras(self, base_llm, all_params): + inputs = {"param1": "value1", "param2": "value2"} + assert base_llm._check_for_extra_inputs(inputs, all_params) # True is implied + + def test_check_for_extra_inputs_with_extras(self, base_llm, all_params): + inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"} + assert not base_llm._check_for_extra_inputs(inputs, all_params) + def test_is_valid_inputs_multiple_inputs(self, base_llm, mocker): # Mock the logger to capture the error messages mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") @@ -139,7 +211,7 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock malformed_function_schema = { "name": "get_time", "description": "Finds the current time in a specific timezone.", - "signature": "(timezone str)", # Malformed signature missing colon + "signiture": "(timezone: str)", # Malformed key name "output": "", } @@ -152,7 +224,7 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock assert not result, "Method should return False when an exception occurs" # Check that the appropriate error message was logged - expected_error_message = "Single input validation error: list index out of range" # Adjust based on the actual exception message + expected_error_message = "Single input validation error: 'signature'" # Adjust based on the actual exception message mocked_logger.assert_called_once_with(expected_error_message) def test_extract_parameter_info_valid(self, base_llm):