From 2e328889526233a06bf9a1ffbc21a91e6d85d02b Mon Sep 17 00:00:00 2001 From: Isabella do Amaral Date: Fri, 6 Sep 2024 10:15:24 -0300 Subject: [PATCH] api: fix ID field in orderBy Fixes: https://github.com/kubeflow/model-registry/issues/353 Signed-off-by: Isabella do Amaral --- api/openapi/model-registry.yaml | 4 +- .../src/mr_openapi/models/order_by_field.py | 2 +- clients/python/tests/test_client.py | 132 ++++++++++++++++++ pkg/openapi/model_order_by_field.go | 4 +- 4 files changed, 137 insertions(+), 5 deletions(-) diff --git a/api/openapi/model-registry.yaml b/api/openapi/model-registry.yaml index 15e8cef7c..64ce93788 100644 --- a/api/openapi/model-registry.yaml +++ b/api/openapi/model-registry.yaml @@ -1348,7 +1348,7 @@ components: enum: - CREATE_TIME - LAST_UPDATE_TIME - - Id + - ID type: string Artifact: oneOf: @@ -1661,7 +1661,7 @@ components: explode: true examples: orderBy: - value: Id + value: ID name: orderBy description: Specifies the order by criteria for listing entities. schema: diff --git a/clients/python/src/mr_openapi/models/order_by_field.py b/clients/python/src/mr_openapi/models/order_by_field.py index 8d0b2b388..5c43bba26 100644 --- a/clients/python/src/mr_openapi/models/order_by_field.py +++ b/clients/python/src/mr_openapi/models/order_by_field.py @@ -24,7 +24,7 @@ class OrderByField(str, Enum): """ CREATE_TIME = "CREATE_TIME" LAST_UPDATE_TIME = "LAST_UPDATE_TIME" - ID = "Id" + ID = "ID" @classmethod def from_json(cls, json_str: str) -> Self: diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 99223a088..03bc56068 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -203,6 +203,76 @@ def test_get_registered_models(client: ModelRegistry): assert i == models +@pytest.mark.e2e +def test_get_registered_models_order_by(client: ModelRegistry): + models = 5 + + rms = [] + for name in [f"test_model{i}" for i in range(models)]: + rms.append( + client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version="1.0.0", + ) + ) + + # id ordering should match creation order + i = 0 + for rm, by_id in zip( + rms, + client.get_registered_models().order_by_id(), + strict=True, + ): + assert rm.id == by_id.id + i += 1 + + assert i == models + + # and obviously, creation ordering should match creation ordering + i = 0 + for rm, by_creation in zip( + rms, + client.get_registered_models().order_by_creation_time(), + strict=True, + ): + assert rm.id == by_creation.id + i += 1 + + assert i == models + + # update order should match creation ordering by default + i = 0 + for rm, by_update in zip( + rms, + client.get_registered_models().order_by_update_time(), + strict=True, + ): + assert rm.id == by_update.id + i += 1 + + assert i == models + + # now update the models in reverse order + for rm in reversed(rms): + rm.description = "updated" + client.update(rm) + + # and they should match in reverse + i = 0 + for rm, by_update in zip( + reversed(rms), + client.get_registered_models().order_by_update_time(), + strict=True, + ): + assert rm.id == by_update.id + i += 1 + + assert i == models + + @pytest.mark.e2e def test_get_registered_models_and_reset(client: ModelRegistry): model_count = 6 @@ -260,6 +330,68 @@ def test_get_model_versions(client: ModelRegistry): assert i == models +@pytest.mark.e2e +def test_get_model_versions_order_by(client: ModelRegistry): + name = "test_model" + models = 5 + mvs = [] + for v in [f"1.0.{i}" for i in range(models)]: + client.register_model( + name, + "s3", + model_format_name="test_format", + model_format_version="test_version", + version=v, + ) + mvs.append(client.get_model_version(name, v)) + + i = 0 + for mv, by_id in zip( + mvs, + client.get_model_versions(name).order_by_id(), + strict=True, + ): + assert mv.id == by_id.id + i += 1 + + assert i == models + + i = 0 + for mv, by_creation in zip( + mvs, + client.get_model_versions(name).order_by_creation_time(), + strict=True, + ): + assert mv.id == by_creation.id + i += 1 + + assert i == models + + i = 0 + for mv, by_update in zip( + mvs, + client.get_model_versions(name).order_by_update_time(), + strict=True, + ): + assert mv.id == by_update.id + i += 1 + + assert i == models + + for mv in reversed(mvs): + mv.description = "updated" + client.update(mv) + + i = 0 + for mv, by_update in zip( + reversed(mvs), + client.get_model_versions(name).order_by_update_time(), + strict=True, + ): + assert mv.id == by_update.id + i += 1 + + @pytest.mark.e2e def test_get_model_versions_and_reset(client: ModelRegistry): name = "test_model" diff --git a/pkg/openapi/model_order_by_field.go b/pkg/openapi/model_order_by_field.go index 4ed6689a7..ec8b82949 100644 --- a/pkg/openapi/model_order_by_field.go +++ b/pkg/openapi/model_order_by_field.go @@ -22,14 +22,14 @@ type OrderByField string const ( ORDERBYFIELD_CREATE_TIME OrderByField = "CREATE_TIME" ORDERBYFIELD_LAST_UPDATE_TIME OrderByField = "LAST_UPDATE_TIME" - ORDERBYFIELD_ID OrderByField = "Id" + ORDERBYFIELD_ID OrderByField = "ID" ) // All allowed values of OrderByField enum var AllowedOrderByFieldEnumValues = []OrderByField{ "CREATE_TIME", "LAST_UPDATE_TIME", - "Id", + "ID", } func (v *OrderByField) UnmarshalJSON(src []byte) error {