Skip to content

Commit

Permalink
Linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
Siraj-Aizlewood committed May 13, 2024
1 parent b040c16 commit c6e9f85
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions tests/unit/llms/test_llm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ def base_llm(self):

@pytest.fixture
def mixed_function_schema(self):
return [{
"name": "test_function",
"description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
}]
return [
{
"name": "test_function",
"description": "A test function with mixed mandatory and optional parameters.",
"signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')",
}
]

@pytest.fixture
def mandatory_params(self):
Expand Down Expand Up @@ -96,12 +98,14 @@ def test_mandatory_args_only(self, base_llm, mixed_function_schema):
) # True is implied

def test_all_args_provided(self, base_llm, mixed_function_schema):
inputs = [{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
}]
inputs = [
{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
}
]
assert base_llm._is_valid_inputs(
inputs, mixed_function_schema
) # True is implied
Expand All @@ -111,13 +115,15 @@ def test_missing_mandatory_arg(self, base_llm, mixed_function_schema):
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)

def test_extra_arg_provided(self, base_llm, mixed_function_schema):
inputs = [{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
"extra": "value",
}]
inputs = [
{
"mandatory1": "value1",
"mandatory2": 42,
"optional1": "opt1",
"optional2": "opt2",
"extra": "value",
}
]
assert not base_llm._is_valid_inputs(inputs, mixed_function_schema)

def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params):
Expand Down Expand Up @@ -201,7 +207,7 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock
mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error")

# Prepare inputs and a malformed function schema
test_inputs = {"timezone": "America/New_York"}
test_inputs = {"timezone": "America/New_York"}
malformed_function_schema = {
"name": "get_time",
"description": "Finds the current time in a specific timezone.",
Expand Down

0 comments on commit c6e9f85

Please sign in to comment.