From 85d034fc44b121ee19a3148b51b01733e27878b1 Mon Sep 17 00:00:00 2001 From: Electronic-Waste <2690692950@qq.com> Date: Mon, 29 Jul 2024 10:28:26 +0000 Subject: [PATCH] feat: add UTs for get_trial_metrics. Signed-off-by: Electronic-Waste <2690692950@qq.com> --- .../kubeflow/katib/api/katib_client_test.py | 96 ++++++++++++++++++- .../kubeflow/katib/api/report_metrics_test.py | 10 +- 2 files changed, 95 insertions(+), 11 deletions(-) diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py index 7f9f5e3fdba..19b82939c7f 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py @@ -13,6 +13,7 @@ from kubeflow.katib import V1beta1TrialParameterSpec from kubeflow.katib import V1beta1TrialTemplate from kubeflow.katib.constants import constants +import kubeflow.katib.katib_api_pb2 as katib_api_pb2 from kubernetes.client import V1ObjectMeta import pytest @@ -238,7 +239,7 @@ def create_experiment( @pytest.fixture -def katib_client(): +def katib_client_create_experiment(): with patch( "kubernetes.client.CustomObjectsApi", return_value=Mock( @@ -255,14 +256,103 @@ def katib_client(): @pytest.mark.parametrize("test_name,kwargs,expected_output", test_create_experiment_data) -def test_create_experiment(katib_client, test_name, kwargs, expected_output): +def test_create_experiment(katib_client_create_experiment, test_name, kwargs, expected_output): """ test create_experiment function of katib client """ print("\n\nExecuting test:", test_name) try: - katib_client.create_experiment(**kwargs) + katib_client_create_experiment.create_experiment(**kwargs) assert expected_output == TEST_RESULT_SUCCESS except Exception as e: assert type(e) is expected_output print("test execution complete") + + +def get_observation_log_response(*args, **kwargs): + if kwargs.get("timeout") == 0: + raise TimeoutError + elif args[0].trial_name == "invalid": + raise RuntimeError + else: + return katib_api_pb2.GetObservationLogReply( + observation_log=katib_api_pb2.ObservationLog( + metric_logs=[ + katib_api_pb2.MetricLog( + time_stamp="2024-07-29T15:09:08Z", + metric=katib_api_pb2.Metric(name="result",value="0.99") + ) + ] + ) + ) + +test_get_trial_metrics_data = [ + ( + "valid trial name", + { + "name": "example", + "namespace": "valid", + "timeout": constants.DEFAULT_TIMEOUT + }, + [ + katib_api_pb2.MetricLog( + time_stamp="2024-07-29T15:09:08Z", + metric=katib_api_pb2.Metric(name="result",value="0.99") + ) + ] + ), + ( + "invalid trial name", + { + "name": "invalid", + "namespace": "invalid", + "timeout": constants.DEFAULT_TIMEOUT + }, + RuntimeError + ), + ( + "GetObservationLog timeout error", + { + "name": "example", + "namespace": "valid", + "timeout": 0 + }, + RuntimeError + ) +] + + +@pytest.fixture +def katib_client_get_trial_metrics(): + with patch( + "kubernetes.client.CustomObjectsApi", + return_value=Mock(), + ), patch( + "kubernetes.config.load_kube_config", + return_value=Mock() + ): + client = KatibClient() + yield client + + +@pytest.fixture +def mock_get_observation_log(): + with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock: + mock_instance = mock.return_value + mock_instance.GetObservationLog.side_effect = get_observation_log_response + yield mock_instance + + +@pytest.mark.parametrize("test_name,kwargs,expected_output", test_get_trial_metrics_data) +def test_get_trial_metrics(test_name, kwargs, expected_output, katib_client_get_trial_metrics, mock_get_observation_log): + """ + test get_trial_metrics function of katib client + """ + print("\n\nExecuting test:", test_name) + try: + metrics = katib_client_get_trial_metrics.get_trial_metrics(**kwargs) + for i in range(len(metrics)): + assert metrics[i] == expected_output[i] + except Exception as e: + assert type(e) is expected_output + print("test execution complete") diff --git a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py index a1ec51b4103..129afcf4577 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py @@ -21,7 +21,6 @@ def report_observation_log_response(*args, **kwargs): "metrics": { "result": 0.99 }, - "db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS, "timeout": constants.DEFAULT_TIMEOUT }, @@ -34,7 +33,6 @@ def report_observation_log_response(*args, **kwargs): "metrics": { "result": "0.99" }, - "db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS, "timeout": constants.DEFAULT_TIMEOUT }, TEST_RESULT_SUCCESS, @@ -46,7 +44,6 @@ def report_observation_log_response(*args, **kwargs): "metrics": { "result": 1 }, - "db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS, "timeout": constants.DEFAULT_TIMEOUT }, TEST_RESULT_SUCCESS, @@ -58,7 +55,6 @@ def report_observation_log_response(*args, **kwargs): "metrics": { "result": 0.99 }, - "db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS, "timeout": 0 }, RuntimeError, @@ -70,7 +66,6 @@ def report_observation_log_response(*args, **kwargs): "metrics": { "result": "abc" }, - "db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS, "timeout": constants.DEFAULT_TIMEOUT }, ValueError, @@ -82,7 +77,6 @@ def report_observation_log_response(*args, **kwargs): "metrics": { "result": 0.99 }, - "db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS, "timeout": constants.DEFAULT_TIMEOUT }, ValueError, @@ -117,7 +111,7 @@ def mock_report_observation_log(): @pytest.mark.parametrize( - "test_name, kwargs, expected_output, mock_getenv", + "test_name,kwargs,expected_output,mock_getenv", test_report_metrics_data, indirect=["mock_getenv"] ) @@ -131,4 +125,4 @@ def test_report_metrics(test_name, kwargs, expected_output, mock_getenv, mock_ge assert expected_output == TEST_RESULT_SUCCESS except Exception as e: assert type(e) is expected_output - print("test execution complete") \ No newline at end of file + print("test execution complete")