Skip to content

Commit

Permalink
feat: add UTs for get_trial_metrics.
Browse files Browse the repository at this point in the history
Signed-off-by: Electronic-Waste <[email protected]>
  • Loading branch information
Electronic-Waste committed Jul 29, 2024
1 parent ad7de0a commit 85d034f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 11 deletions.
96 changes: 93 additions & 3 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from kubeflow.katib import V1beta1TrialParameterSpec
from kubeflow.katib import V1beta1TrialTemplate
from kubeflow.katib.constants import constants
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
from kubernetes.client import V1ObjectMeta
import pytest

Expand Down Expand Up @@ -238,7 +239,7 @@ def create_experiment(


@pytest.fixture
def katib_client():
def katib_client_create_experiment():
with patch(
"kubernetes.client.CustomObjectsApi",
return_value=Mock(
Expand All @@ -255,14 +256,103 @@ def katib_client():


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_create_experiment_data)
def test_create_experiment(katib_client, test_name, kwargs, expected_output):
def test_create_experiment(katib_client_create_experiment, test_name, kwargs, expected_output):
"""
test create_experiment function of katib client
"""
print("\n\nExecuting test:", test_name)
try:
katib_client.create_experiment(**kwargs)
katib_client_create_experiment.create_experiment(**kwargs)
assert expected_output == TEST_RESULT_SUCCESS
except Exception as e:
assert type(e) is expected_output
print("test execution complete")


def get_observation_log_response(*args, **kwargs):
if kwargs.get("timeout") == 0:
raise TimeoutError
elif args[0].trial_name == "invalid":
raise RuntimeError
else:
return katib_api_pb2.GetObservationLogReply(
observation_log=katib_api_pb2.ObservationLog(
metric_logs=[
katib_api_pb2.MetricLog(
time_stamp="2024-07-29T15:09:08Z",
metric=katib_api_pb2.Metric(name="result",value="0.99")
)
]
)
)

test_get_trial_metrics_data = [
(
"valid trial name",
{
"name": "example",
"namespace": "valid",
"timeout": constants.DEFAULT_TIMEOUT
},
[
katib_api_pb2.MetricLog(
time_stamp="2024-07-29T15:09:08Z",
metric=katib_api_pb2.Metric(name="result",value="0.99")
)
]
),
(
"invalid trial name",
{
"name": "invalid",
"namespace": "invalid",
"timeout": constants.DEFAULT_TIMEOUT
},
RuntimeError
),
(
"GetObservationLog timeout error",
{
"name": "example",
"namespace": "valid",
"timeout": 0
},
RuntimeError
)
]


@pytest.fixture
def katib_client_get_trial_metrics():
with patch(
"kubernetes.client.CustomObjectsApi",
return_value=Mock(),
), patch(
"kubernetes.config.load_kube_config",
return_value=Mock()
):
client = KatibClient()
yield client


@pytest.fixture
def mock_get_observation_log():
with patch("kubeflow.katib.katib_api_pb2_grpc.DBManagerStub") as mock:
mock_instance = mock.return_value
mock_instance.GetObservationLog.side_effect = get_observation_log_response
yield mock_instance


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_get_trial_metrics_data)
def test_get_trial_metrics(test_name, kwargs, expected_output, katib_client_get_trial_metrics, mock_get_observation_log):
"""
test get_trial_metrics function of katib client
"""
print("\n\nExecuting test:", test_name)
try:
metrics = katib_client_get_trial_metrics.get_trial_metrics(**kwargs)
for i in range(len(metrics)):
assert metrics[i] == expected_output[i]
except Exception as e:
assert type(e) is expected_output
print("test execution complete")
10 changes: 2 additions & 8 deletions sdk/python/v1beta1/kubeflow/katib/api/report_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def report_observation_log_response(*args, **kwargs):
"metrics": {
"result": 0.99
},
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
"timeout": constants.DEFAULT_TIMEOUT

},
Expand All @@ -34,7 +33,6 @@ def report_observation_log_response(*args, **kwargs):
"metrics": {
"result": "0.99"
},
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
"timeout": constants.DEFAULT_TIMEOUT
},
TEST_RESULT_SUCCESS,
Expand All @@ -46,7 +44,6 @@ def report_observation_log_response(*args, **kwargs):
"metrics": {
"result": 1
},
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
"timeout": constants.DEFAULT_TIMEOUT
},
TEST_RESULT_SUCCESS,
Expand All @@ -58,7 +55,6 @@ def report_observation_log_response(*args, **kwargs):
"metrics": {
"result": 0.99
},
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
"timeout": 0
},
RuntimeError,
Expand All @@ -70,7 +66,6 @@ def report_observation_log_response(*args, **kwargs):
"metrics": {
"result": "abc"
},
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
"timeout": constants.DEFAULT_TIMEOUT
},
ValueError,
Expand All @@ -82,7 +77,6 @@ def report_observation_log_response(*args, **kwargs):
"metrics": {
"result": 0.99
},
"db_manager_address": constants.DEFAULT_DB_MANAGER_ADDRESS,
"timeout": constants.DEFAULT_TIMEOUT
},
ValueError,
Expand Down Expand Up @@ -117,7 +111,7 @@ def mock_report_observation_log():


@pytest.mark.parametrize(
"test_name, kwargs, expected_output, mock_getenv",
"test_name,kwargs,expected_output,mock_getenv",
test_report_metrics_data,
indirect=["mock_getenv"]
)
Expand All @@ -131,4 +125,4 @@ def test_report_metrics(test_name, kwargs, expected_output, mock_getenv, mock_ge
assert expected_output == TEST_RESULT_SUCCESS
except Exception as e:
assert type(e) is expected_output
print("test execution complete")
print("test execution complete")

0 comments on commit 85d034f

Please sign in to comment.