From f9f78c34823dca087ca75c33199df0ad471873fc Mon Sep 17 00:00:00 2001 From: Eder Ignatowicz Date: Thu, 16 Jan 2025 17:39:41 -0500 Subject: [PATCH] feat(bff): create endpoint to list all model versions (#707) Signed-off-by: Eder Ignatowicz --- clients/ui/bff/README.md | 4 +++ clients/ui/bff/internal/api/app.go | 3 ++- .../internal/api/model_versions_handler.go | 24 +++++++++++++++++ .../api/model_versions_handler_test.go | 12 +++++++++ .../internal/api/registered_models_handler.go | 2 +- .../mocks/model_registry_client_mock.go | 7 ++++- .../internal/repositories/model_version.go | 16 ++++++++++++ .../repositories/model_version_test.go | 26 +++++++++++++++++++ .../internal/repositories/registered_model.go | 4 +-- .../repositories/registered_model_test.go | 6 ++--- 10 files changed, 96 insertions(+), 8 deletions(-) diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 3886245e1..c6adf2512 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -126,6 +126,10 @@ curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/a }}' ``` ``` +# GET /api/v1/model_registry/{model_registry_id}/model_versions +curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow" +``` +``` # GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1?namespace=kubeflow" ``` diff --git a/clients/ui/bff/internal/api/app.go b/clients/ui/bff/internal/api/app.go index d9210a1c8..129b41c8d 100644 --- a/clients/ui/bff/internal/api/app.go +++ b/clients/ui/bff/internal/api/app.go @@ -101,8 +101,9 @@ func (app *App) Routes() http.Handler { apiRouter.PATCH(RegisteredModelPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateRegisteredModelHandler)))) apiRouter.GET(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelVersionsForRegisteredModelHandler)))) apiRouter.POST(RegisteredModelVersionsPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionForRegisteredModelHandler)))) - apiRouter.GET(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient((app.GetModelVersionHandler))))) apiRouter.POST(ModelVersionListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelVersionHandler)))) + apiRouter.GET(ModelVersionListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelVersionHandler)))) + apiRouter.GET(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetModelVersionHandler)))) apiRouter.PATCH(ModelVersionPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.UpdateModelVersionHandler)))) apiRouter.GET(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler)))) apiRouter.POST(ModelVersionArtifactListPath, app.AttachNamespace(app.PerformSARonSpecificService(app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler)))) diff --git a/clients/ui/bff/internal/api/model_versions_handler.go b/clients/ui/bff/internal/api/model_versions_handler.go index a945d0492..7c74377bf 100644 --- a/clients/ui/bff/internal/api/model_versions_handler.go +++ b/clients/ui/bff/internal/api/model_versions_handler.go @@ -18,6 +18,30 @@ type ModelVersionUpdateEnvelope Envelope[*openapi.ModelVersionUpdate, None] type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None] type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None] +func (app *App) GetAllModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersions(client) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + responseBody := ModelVersionListEnvelope{ + Data: versionList, + } + + err = app.WriteJSON(w, http.StatusOK, responseBody, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } + +} + func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { diff --git a/clients/ui/bff/internal/api/model_versions_handler_test.go b/clients/ui/bff/internal/api/model_versions_handler_test.go index cce4cf022..c8cebc2a0 100644 --- a/clients/ui/bff/internal/api/model_versions_handler_test.go +++ b/clients/ui/bff/internal/api/model_versions_handler_test.go @@ -11,6 +11,18 @@ import ( var _ = Describe("TestGetModelVersionHandler", func() { Context("testing Model Version Handler", Ordered, func() { + It("should retrieve all model versions", func() { + By("fetching all model versions") + data := mocks.GetModelVersionListMock() + expected := ModelVersionListEnvelope{Data: &data} + actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions?namespace=kubeflow", nil, k8sClient, mocks.KubeflowUserIDHeaderValue, "kubeflow") + Expect(err).NotTo(HaveOccurred()) + By("should match the expected model versions") + Expect(rs.StatusCode).To(Equal(http.StatusOK)) + Expect(actual.Data.Size).To(Equal(expected.Data.Size)) + Expect(actual.Data.Items).To(Equal(expected.Data.Items)) + }) + It("should retrieve a model version", func() { By("fetching a model version") data := mocks.GetModelVersionMocks()[0] diff --git a/clients/ui/bff/internal/api/registered_models_handler.go b/clients/ui/bff/internal/api/registered_models_handler.go index 98daef1cd..7f781f69b 100644 --- a/clients/ui/bff/internal/api/registered_models_handler.go +++ b/clients/ui/bff/internal/api/registered_models_handler.go @@ -177,7 +177,7 @@ func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWrit return } - versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query()) + versionList, err := app.repositories.ModelRegistryClient.GetAllModelVersionsForRegisteredModel(client, ps.ByName(RegisteredModelId), r.URL.Query()) if err != nil { app.serverErrorResponse(w, r, err) diff --git a/clients/ui/bff/internal/mocks/model_registry_client_mock.go b/clients/ui/bff/internal/mocks/model_registry_client_mock.go index 9a35e331e..d9fcfd48a 100644 --- a/clients/ui/bff/internal/mocks/model_registry_client_mock.go +++ b/clients/ui/bff/internal/mocks/model_registry_client_mock.go @@ -41,6 +41,11 @@ func (m *ModelRegistryClientMock) UpdateRegisteredModel(_ integrations.HTTPClien return &mockData, nil } +func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface) (*openapi.ModelVersionList, error) { + mockData := GetModelVersionListMock() + return &mockData, nil +} + func (m *ModelRegistryClientMock) GetModelVersion(_ integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) { if id == "3" { mockData := GetModelVersionMocks()[2] @@ -61,7 +66,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(_ integrations.HTTPClientIn return &mockData, nil } -func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) { +func (m *ModelRegistryClientMock) GetAllModelVersionsForRegisteredModel(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) { mockData := GetModelVersionListMock() return &mockData, nil } diff --git a/clients/ui/bff/internal/repositories/model_version.go b/clients/ui/bff/internal/repositories/model_version.go index 526b187b4..b5a30278d 100644 --- a/clients/ui/bff/internal/repositories/model_version.go +++ b/clients/ui/bff/internal/repositories/model_version.go @@ -13,6 +13,7 @@ const modelVersionPath = "/model_versions" const artifactsByModelVersionPath = "/artifacts" type ModelVersionInterface interface { + GetAllModelVersions(client integrations.HTTPClientInterface) (*openapi.ModelVersionList, error) GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) @@ -24,6 +25,21 @@ type ModelVersion struct { ModelVersionInterface } +func (v ModelVersion) GetAllModelVersions(client integrations.HTTPClientInterface) (*openapi.ModelVersionList, error) { + response, err := client.GET(modelVersionPath) + + if err != nil { + return nil, fmt.Errorf("error fetching model versions: %w", err) + } + + var models openapi.ModelVersionList + if err := json.Unmarshal(response, &models); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &models, nil +} + func (v ModelVersion) GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) { path, err := url.JoinPath(modelVersionPath, id) if err != nil { diff --git a/clients/ui/bff/internal/repositories/model_version_test.go b/clients/ui/bff/internal/repositories/model_version_test.go index e5c389ff7..40d0522ff 100644 --- a/clients/ui/bff/internal/repositories/model_version_test.go +++ b/clients/ui/bff/internal/repositories/model_version_test.go @@ -38,6 +38,32 @@ func TestGetModelVersion(t *testing.T) { mockClient.AssertExpectations(t) } +func TestGetAllModelVersions(t *testing.T) { + _ = gofakeit.Seed(0) + + expected := mocks.GenerateMockModelVersionList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + mockClient := new(mocks.MockHTTPClient) + mockClient.On("GET", modelVersionPath).Return(mockData, nil) + + actual, err := modelVersion.GetAllModelVersions(mockClient) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.NextPageToken, actual.NextPageToken) + assert.Equal(t, expected.PageSize, actual.PageSize) + assert.Equal(t, expected.Size, actual.Size) + assert.Equal(t, len(expected.Items), len(actual.Items)) + + mockClient.AssertExpectations(t) +} + func TestCreateModelVersion(t *testing.T) { _ = gofakeit.Seed(0) diff --git a/clients/ui/bff/internal/repositories/registered_model.go b/clients/ui/bff/internal/repositories/registered_model.go index e93552726..ba08cfe27 100644 --- a/clients/ui/bff/internal/repositories/registered_model.go +++ b/clients/ui/bff/internal/repositories/registered_model.go @@ -17,7 +17,7 @@ type RegisteredModelInterface interface { CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error) - GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) + GetAllModelVersionsForRegisteredModel(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) } @@ -94,7 +94,7 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt return &model, nil } -func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) { +func (m RegisteredModel) GetAllModelVersionsForRegisteredModel(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) { path, err := url.JoinPath(registeredModelPath, id, versionsPath) if err != nil { diff --git a/clients/ui/bff/internal/repositories/registered_model_test.go b/clients/ui/bff/internal/repositories/registered_model_test.go index d94aed1db..78aa0d3e1 100644 --- a/clients/ui/bff/internal/repositories/registered_model_test.go +++ b/clients/ui/bff/internal/repositories/registered_model_test.go @@ -134,7 +134,7 @@ func TestUpdateRegisteredModel(t *testing.T) { mockClient.AssertExpectations(t) } -func TestGetAllModelVersions(t *testing.T) { +func TestGetAllModelVersionsByRegisteredModel(t *testing.T) { _ = gofakeit.Seed(0) expected := mocks.GenerateMockModelVersionList() @@ -149,7 +149,7 @@ func TestGetAllModelVersions(t *testing.T) { assert.NoError(t, err) mockClient.On("GET", path).Return(mockData, nil) - actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil) + actual, err := registeredModel.GetAllModelVersionsForRegisteredModel(mockClient, "1", nil) assert.NoError(t, err) assert.NotNil(t, actual) assert.NoError(t, err) @@ -180,7 +180,7 @@ func TestGetAllModelVersionsWithPageParams(t *testing.T) { mockClient.On("GET", reqUrl).Return(mockData, nil) - actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues) + actual, err := registeredModel.GetAllModelVersionsForRegisteredModel(mockClient, "1", pageValues) assert.NoError(t, err) assert.NotNil(t, actual)