Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update test cases for AAP-32080 #1327

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,26 +399,6 @@ def test_playbook_gen_error(self):
model_client.generate_playbook(request, text="Install Wordpress", create_outline=True)
self.assertInLog("Caught retryable error after 1 tries.", log)

def test_playbook_gen_model_id(self):
manstis marked this conversation as resolved.
Show resolved Hide resolved
self.assertion_count = 0
request = Mock()
model_client = WCAClient(inference_url="http://example.com/")
model_client.get_api_key = Mock(return_value="some-key")
model_client.get_token = Mock(return_value={"access_token": "a-token"})
model_client.session = Mock()

def get_my_model_id(user, organization_id, model_id):
self.assertEqual(model_id, "mymodel")
self.assertion_count += 1
return model_id

model_client.get_model_id = get_my_model_id

model_client.generate_playbook(
request, text="Install Wordpress", create_outline=True, model_id="mymodel"
)
self.assertGreater(self.assertion_count, 0)

@assert_call_count_metrics(metric=wca_explain_playbook_hist)
def test_playbook_exp(self):
request = Mock()
Expand Down Expand Up @@ -478,26 +458,6 @@ def test_playbook_exp_error(self):
model_client.explain_playbook(request, content="Some playbook")
self.assertInLog("Caught retryable error after 1 tries.", log)

def test_playbook_exp_model_id(self):
manstis marked this conversation as resolved.
Show resolved Hide resolved
request = Mock()
model_client = WCAClient(inference_url="http://example.com/")
model_client.get_api_key = Mock(return_value="some-key")
model_client.get_token = Mock(return_value={"access_token": "a-token"})
model_client.session = Mock()

self.assertion_count = 0

def get_my_model_id(user, organization_id, model_id):
self.assertEqual(model_id, "mymodel")
self.assertion_count += 1
return model_id

model_client.get_model_id = get_my_model_id

model_client.explain_playbook(request, content="Some playbook", model_id="mymodel")
self.assertGreater(self.assertion_count, 0)
self.assertion_count = 0

@assert_call_count_metrics(metric=wca_codegen_playbook_hist)
def test_playbook_gen_no_org(self):
request = Mock()
Expand Down
1 change: 1 addition & 0 deletions ansible_ai_connect/ai/api/model_client/wca_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def get_model_id(
organization_id: Optional[int] = None,
requested_model_id: str = "",
) -> str:
logger.debug(f"requested_model_id={requested_model_id}")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new log line used only for testing.

if not organization_id and user.organization:
# TODO: the organization_id parameter should be removed
organization_id = user.organization.id # type: ignore[reportAttributeAccessIssue]
Expand Down
78 changes: 40 additions & 38 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,6 +2453,7 @@ def test_ok_with_model_id(self):
r = self.client.post(reverse("explanations"), payload, format="json")
segment_events = self.extractSegmentEventsFromLog(log)
self.assertEqual(segment_events[0]["properties"]["playbook_length"], 197)
self.assertEqual(segment_events[0]["properties"]["modelName"], "mymodel")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertIsNotNone(r.data["content"])
self.assertEqual(r.data["format"], "markdown")
Expand Down Expand Up @@ -2677,7 +2678,8 @@ def stub_wca_client(
model_client = WCAClient(inference_url="https://wca_api_url")
model_client.session.post = Mock(return_value=response)
model_client.get_api_key = mock_api_key
model_client.get_model_id = mock_model_id
if mock_model_id:
model_client.get_model_id = mock_model_id
model_client.get_token = Mock(return_value={"access_token": "abc"})
return model_client

Expand All @@ -2691,11 +2693,12 @@ def assert_test(
with self.assertLogs(logger="root", level="DEBUG") as log:
r = self.client.post(reverse("explanations"), self.payload, format="json")
self.assertEqual(r.status_code, expected_status_code)
if expected_status_code != HTTPStatus.OK:
if expected_exception() is not None:
self.assert_error_detail(
r, expected_exception().default_code, expected_exception().default_detail
)
self.assertInLog(expected_log_message, log)
return r

def test_bad_wca_request(self):
model_client = self.stub_wca_client(
Expand Down Expand Up @@ -2838,32 +2841,30 @@ def test_wca_instance_deleted(self):
)

def test_wca_request_with_model_id_given(self):
self.assertion_count = 0
self.payload["model"] = "mymodel"

def get_model_id(user, organization_id, model_id):
self.assertEqual(model_id, "mymodel")
self.assertion_count += 1
return model_id

model_client = self.stub_wca_client(200)
model_client.get_model_id = get_model_id
model_client = self.stub_wca_client(
200,
mock_model_id=None,
)
model_client.explain_playbook = lambda *args: {
"content": "string",
"format": "string",
"explanationId": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
}

self.assert_test(
model_client,
HTTPStatus.OK,
None,
None,
)
self.assertGreater(self.assertion_count, 0)
with self.assertLogs(
logger="ansible_ai_connect.ai.api.model_client.wca_client", level="DEBUG"
) as log:
self.assert_test(
model_client,
HTTPStatus.OK,
lambda: None,
None,
)
self.assertInLog("requested_model_id=mymodel", log)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit overkill IMO since it doesn't really improve the quality of the code and it's testing a behaviour that is obvious.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before AAP-32080, Explanations/Generation views did not specify requested_model_id when they call wca_client's get_model_id command. I agree it looks obvious, but this line ensures the requested_model_id is set as expected.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @TamiTakamiya whilst the test may be an overkill it does detect and protect against a known issue.

However, I do struggle with these tests and the changes you made to support them. This class tests the view and it should be enough to ensure we call ModelMeshClient's method with the correct model_id. I'd then have tests for WCAClient that check the correct model_id is used e.g. https://github.com/ansible/ansible-ai-connect-service/blob/main/ansible_ai_connect/ai/api/model_client/tests/test_wca_client.py#L1259-L1264

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we've probably spent more time discussing these tests than they warrant.

I just think they smell weird.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@goneri What do you think? I am open to remove this line.



@override_settings(ANSIBLE_AI_MODEL_MESH_API_TYPE="dummy")
@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
class TestGenerationView(WisdomAppsBackendMocking, WisdomServiceAPITestCaseBase):

response_data = """yaml
Expand Down Expand Up @@ -2929,7 +2930,10 @@ def test_ok_with_model_id(self):
"model": model,
}
self.client.force_authenticate(user=self.user)
r = self.client.post(reverse("generations"), payload, format="json")
with self.assertLogs(logger="root", level="DEBUG") as log:
r = self.client.post(reverse("generations"), payload, format="json")
segment_events = self.extractSegmentEventsFromLog(log)
self.assertEqual(segment_events[0]["properties"]["modelName"], "mymodel")
self.assertEqual(r.status_code, HTTPStatus.OK)
self.assertIsNotNone(r.data["playbook"])
self.assertEqual(r.data["format"], "plaintext")
Expand Down Expand Up @@ -3191,7 +3195,8 @@ def stub_wca_client(
model_client = WCAClient(inference_url="https://wca_api_url")
model_client.session.post = Mock(return_value=response)
model_client.get_api_key = mock_api_key
model_client.get_model_id = mock_model_id
if mock_model_id:
model_client.get_model_id = mock_model_id
model_client.get_token = Mock(return_value={"access_token": "abc"})
return model_client

Expand All @@ -3205,7 +3210,7 @@ def assert_test(
with self.assertLogs(logger="root", level="DEBUG") as log:
r = self.client.post(reverse("generations"), self.payload, format="json")
self.assertEqual(r.status_code, expected_status_code)
if expected_status_code != 200:
if expected_exception() is not None:
self.assert_error_detail(
r, expected_exception().default_code, expected_exception().default_detail
)
Expand Down Expand Up @@ -3353,26 +3358,23 @@ def test_wca_instance_deleted(self):
)

def test_wca_request_with_model_id_given(self):
self.assertion_count = 0
self.payload["model"] = "mymodel"
model_client = self.stub_wca_client(
200,
mock_model_id=None,
)
model_client.generate_playbook = lambda *args: ("playbook", "outline", "warning")

def get_model_id(user, organization_id, model_id):
self.assertEqual(model_id, "mymodel")
self.assertion_count += 1
return model_id

model_client.get_model_id = get_model_id
model_client.generate_playbook = lambda *args: ("playbook", "outline", "warnings")
self.assert_test(
model_client,
HTTPStatus.OK,
None,
None,
)
self.assertGreater(self.assertion_count, 0)
with self.assertLogs(
logger="ansible_ai_connect.ai.api.model_client.wca_client", level="DEBUG"
) as log:
self.assert_test(
model_client,
HTTPStatus.OK,
lambda: None,
None,
)
self.assertInLog("requested_model_id=mymodel", log)

def test_warnings(self):
model_client = self.stub_wca_client(
Expand All @@ -3382,7 +3384,7 @@ def test_warnings(self):
'"warnings": [{"id": "id-1", "message": '
'"Something went wrong", "details": "Some details"}]}',
)
r = self.assert_test(model_client, HTTPStatus.OK, None, None)
r = self.assert_test(model_client, HTTPStatus.OK, lambda: None, None)
self.assertTrue("warnings" in r.data)
warnings = r.data["warnings"]
self.assertEqual(1, len(warnings))
Expand Down
Loading