diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index dc8db578..02645396 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -54,14 +54,21 @@ make docker-build ### Endpoints -| URL Pattern | Handler | Action | -|--------------------------------------------------------------------------------------|-------------------------------|----------------------------------------------| -| GET /v1/healthcheck | HealthcheckHandler | Show application information. | -| GET /v1/model_registry | ModelRegistryHandler | Get all model registries, | -| GET /v1/model_registry/{model_registry_id}/registered_models | GetAllRegisteredModelsHandler | Gets a list of all RegisteredModel entities. | -| POST /v1/model_registry/{model_registry_id}/registered_models | CreateRegisteredModelHandler | Create a RegisteredModel entity. | -| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | GetRegisteredModelHandler | Get a RegisteredModel entity by ID | -| PATCH /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | UpdateRegisteredModelHandler | Update a RegisteredModel entity by ID | +| URL Pattern | Handler | Action | +|----------------------------------------------------------------------------------------------|----------------------------------------------|-------------------------------------------------------------| +| GET /v1/healthcheck | HealthcheckHandler | Show application information. | +| GET /v1/model_registry | ModelRegistryHandler | Get all model registries, | +| GET /v1/model_registry/{model_registry_id}/registered_models | GetAllRegisteredModelsHandler | Gets a list of all RegisteredModel entities. | +| POST /v1/model_registry/{model_registry_id}/registered_models | CreateRegisteredModelHandler | Create a RegisteredModel entity. | +| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | GetRegisteredModelHandler | Get a RegisteredModel entity by ID | +| PATCH /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | UpdateRegisteredModelHandler | Update a RegisteredModel entity by ID | +| GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | GetModelVersionHandler | Get a ModelVersion by ID | +| POST /api/v1/model_registry/{model_registry_id}/model_versions | CreateModelVersionHandler | Create a ModelVersion entity | +| PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | UpdateModelVersionHandler | Update a ModelVersion entity by ID | +| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | GetAllModelVersionsForRegisteredModelHandler | Get all ModelVersion entities by RegisteredModel ID | +| POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | CreateModelVersionForRegisteredModelHandler | Create a ModelVersion entity for a specific RegisteredModel | +| GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | GetAllModelArtifactsByModelVersionHandler | Get all ModelArtifact entities by ModelVersion ID | +| POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | CreateModelArtifactByModelVersion | Create a ModelArtifact entity for a specific ModelVersion | ### Sample local calls ``` @@ -115,4 +122,90 @@ curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/reg "owner": "eder", "state": "LIVE" }}' +``` +``` +# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} +curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1 +``` +``` +# POST /api/v1/model_registry/{model_registry_id}/model_versions +curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions" \ + -H "Content-Type: application/json" \ + -d '{ "data": { + "customProperties": { + "my-label9": { + "metadataType": "MetadataStringValue", + "string_value": "val" + } + }, + "description": "Version description", + "externalId": "9927", + "name": "ModelVersion One", + "state": "LIVE", + "author": "alex", + "registeredModelId": "1" +}}' +``` +``` +# PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} +curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1" \ + -H "Content-Type: application/json" \ + -d '{ "data": { + "customProperties": { + "my-label9": { + "metadataType": "MetadataStringValue", + "string_value": "val" + } + }, + "description": "New description", + "externalId": "9927", + "name": "ModelVersion One", + "state": "LIVE", + "author": "alex", + "registeredModelId": "1" +}}' +``` +``` +# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions +curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions +``` +``` +# POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions +curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions" \ + -H "Content-Type: application/json" \ + -d '{ "data": { + "customProperties": { + "my-label9": { + "metadataType": "MetadataStringValue", + "string_value": "val" + } + }, + "description": "New description", + "externalId": "9927", + "name": "ModelVersion One", + "state": "LIVE", + "author": "alex" +}}' +``` +``` +# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts +curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts +``` +``` +# POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts +curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \ + -H "Content-Type: application/json" \ + -d '{ "data": { + "customProperties": { + "my-label9": { + "metadataType": "MetadataStringValue", + "string_value": "val" + } + }, + "description": "New description", + "externalId": "9927", + "name": "ModelArtifact One", + "state": "LIVE", + "artifactType": "TYPE_ONE" +}}' ``` \ No newline at end of file diff --git a/clients/ui/bff/api/app.go b/clients/ui/bff/api/app.go index a87350d9..36b0f6b9 100644 --- a/clients/ui/bff/api/app.go +++ b/clients/ui/bff/api/app.go @@ -13,14 +13,23 @@ import ( ) const ( - Version = "1.0.0" - PathPrefix = "/api/v1" - ModelRegistryId = "model_registry_id" - RegisteredModelId = "registered_model_id" - HealthCheckPath = PathPrefix + "/healthcheck" - ModelRegistry = PathPrefix + "/model_registry" - RegisteredModelListPath = ModelRegistry + "/:" + ModelRegistryId + "/registered_models" - RegisteredModelPath = RegisteredModelListPath + "/:" + RegisteredModelId + Version = "1.0.0" + PathPrefix = "/api/v1" + ModelRegistryId = "model_registry_id" + RegisteredModelId = "registered_model_id" + ModelVersionId = "model_version_id" + ModelArtifactId = "model_artifact_id" + HealthCheckPath = PathPrefix + "/healthcheck" + ModelRegistryListPath = PathPrefix + "/model_registry" + ModelRegistryPath = ModelRegistryListPath + "/:" + ModelRegistryId + RegisteredModelListPath = ModelRegistryPath + "/registered_models" + RegisteredModelPath = RegisteredModelListPath + "/:" + RegisteredModelId + RegisteredModelVersionsPath = RegisteredModelPath + "/versions" + ModelVersionListPath = ModelRegistryPath + "/model_versions" + ModelVersionPath = ModelVersionListPath + "/:" + ModelVersionId + ModelVersionArtifactListPath = ModelVersionPath + "/artifacts" + ModelArtifactListPath = ModelRegistryPath + "/model_artifacts" + ModelArtifactPath = ModelArtifactListPath + "/:" + ModelArtifactId ) type App struct { @@ -54,7 +63,7 @@ func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) { } if err != nil { - return nil, fmt.Errorf("failed to create ModelRegistry client: %w", err) + return nil, fmt.Errorf("failed to create ModelRegistryListPath client: %w", err) } app := &App{ @@ -78,9 +87,18 @@ func (app *App) Routes() http.Handler { router.GET(RegisteredModelPath, app.AttachRESTClient(app.GetRegisteredModelHandler)) router.POST(RegisteredModelListPath, app.AttachRESTClient(app.CreateRegisteredModelHandler)) router.PATCH(RegisteredModelPath, app.AttachRESTClient(app.UpdateRegisteredModelHandler)) + router.GET(RegisteredModelVersionsPath, app.AttachRESTClient(app.GetAllModelVersionsForRegisteredModelHandler)) + router.POST(RegisteredModelVersionsPath, app.AttachRESTClient(app.CreateModelVersionForRegisteredModelHandler)) + + router.GET(ModelVersionPath, app.AttachRESTClient(app.GetModelVersionHandler)) + router.POST(ModelVersionListPath, app.AttachRESTClient(app.CreateModelVersionHandler)) + router.PATCH(ModelVersionPath, app.AttachRESTClient(app.UpdateModelVersionHandler)) + router.GET(ModelVersionArtifactListPath, app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler)) + router.POST(ModelVersionArtifactListPath, app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler)) // Kubernetes client routes - router.GET(ModelRegistry, app.ModelRegistryHandler) + router.GET(ModelRegistryListPath, app.ModelRegistryHandler) + router.PATCH(ModelRegistryPath, app.AttachRESTClient(app.UpdateModelVersionHandler)) return app.RecoverPanic(app.enableCORS(router)) } diff --git a/clients/ui/bff/api/helpers.go b/clients/ui/bff/api/helpers.go index a129e335..6cb6329b 100644 --- a/clients/ui/bff/api/helpers.go +++ b/clients/ui/bff/api/helpers.go @@ -94,3 +94,15 @@ func (app *App) ReadJSON(w http.ResponseWriter, r *http.Request, dst any) error return nil } + +func ParseURLTemplate(tmpl string, params map[string]string) string { + args := make([]string, len(params)*2) + + for k, v := range params { + args = append(args, ":"+k, v) + } + + r := strings.NewReplacer(args...) + + return r.Replace(tmpl) +} diff --git a/clients/ui/bff/api/helpers_test.go b/clients/ui/bff/api/helpers_test.go new file mode 100644 index 00000000..4cf02821 --- /dev/null +++ b/clients/ui/bff/api/helpers_test.go @@ -0,0 +1,21 @@ +package api + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseURLTemplate(t *testing.T) { + expected := "/v1/model_registry/demo-registry/registered_models/111-222-333/versions" + tmpl := "/v1/model_registry/:model_registry_id/registered_models/:registered_model_id/versions" + params := map[string]string{"model_registry_id": "demo-registry", "registered_model_id": "111-222-333"} + + actual := ParseURLTemplate(tmpl, params) + + assert.Equal(t, expected, actual) +} + +func TestParseURLTemplateWhenEmpty(t *testing.T) { + actual := ParseURLTemplate("", nil) + assert.Empty(t, actual) +} diff --git a/clients/ui/bff/api/model_registry_handler_test.go b/clients/ui/bff/api/model_registry_handler_test.go index 0b4949d5..355f4b6e 100644 --- a/clients/ui/bff/api/model_registry_handler_test.go +++ b/clients/ui/bff/api/model_registry_handler_test.go @@ -18,7 +18,7 @@ func TestModelRegistryHandler(t *testing.T) { kubernetesClient: mockK8sClient, } - req, err := http.NewRequest(http.MethodGet, ModelRegistry, nil) + req, err := http.NewRequest(http.MethodGet, ModelRegistryListPath, nil) assert.NoError(t, err) rr := httptest.NewRecorder() diff --git a/clients/ui/bff/api/model_versions_handler.go b/clients/ui/bff/api/model_versions_handler.go new file mode 100644 index 00000000..a68c2576 --- /dev/null +++ b/clients/ui/bff/api/model_versions_handler.go @@ -0,0 +1,232 @@ +package api + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/integrations" + "github.com/kubeflow/model-registry/ui/bff/validation" + "net/http" +) + +type ModelVersionEnvelope Envelope[*openapi.ModelVersion, None] +type ModelVersionListEnvelope Envelope[*openapi.ModelVersionList, None] +type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None] +type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None] + +func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + model, err := app.modelRegistryClient.GetModelVersion(client, ps.ByName(ModelVersionId)) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + if _, ok := model.GetIdOk(); !ok { + app.notFoundResponse(w, r) + return + } + + result := ModelVersionEnvelope{ + Data: model, + } + + err = app.WriteJSON(w, http.StatusOK, result, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *App) CreateModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + var envelope ModelVersionEnvelope + if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error())) + return + } + + data := *envelope.Data + + if err := validation.ValidateModelVersion(data); err != nil { + app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error())) + return + } + + jsonData, err := json.Marshal(data) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error marshaling ModelVersion to JSON: %w", err)) + return + } + + createdVersion, err := app.modelRegistryClient.CreateModelVersion(client, jsonData) + if err != nil { + var httpErr *integrations.HTTPError + if errors.As(err, &httpErr) { + app.errorResponse(w, r, httpErr) + } else { + app.serverErrorResponse(w, r, err) + } + return + } + + if createdVersion == nil { + app.serverErrorResponse(w, r, fmt.Errorf("created ModelVersion is nil")) + return + } + + response := ModelVersionEnvelope{ + Data: createdVersion, + } + + w.Header().Set("Location", r.URL.JoinPath(*createdVersion.Id).String()) + err = app.WriteJSON(w, http.StatusCreated, response, nil) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON")) + return + } +} + +func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + var envelope ModelVersionEnvelope + if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error())) + return + } + + data := *envelope.Data + + if err := validation.ValidateModelVersion(data); err != nil { + app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error())) + return + } + + jsonData, err := json.Marshal(data) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error marshaling ModelVersion to JSON: %w", err)) + return + } + + patchedModel, err := app.modelRegistryClient.UpdateModelVersion(client, ps.ByName(ModelVersionId), jsonData) + if err != nil { + var httpErr *integrations.HTTPError + if errors.As(err, &httpErr) { + app.errorResponse(w, r, httpErr) + } else { + app.serverErrorResponse(w, r, err) + } + return + } + + if patchedModel == nil { + app.serverErrorResponse(w, r, fmt.Errorf("patched ModelVersion is nil")) + return + } + + responseBody := ModelVersionEnvelope{ + Data: patchedModel, + } + + err = app.WriteJSON(w, http.StatusOK, responseBody, nil) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON")) + return + } +} + +func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId)) + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + result := ModelArtifactListEnvelope{ + Data: data, + } + + err = app.WriteJSON(w, http.StatusOK, result, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *App) CreateModelArtifactByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + var envelope ModelArtifactEnvelope + if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error())) + return + } + + data := *envelope.Data + + if err := validation.ValidateModelArtifact(data); err != nil { + app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error())) + return + } + + jsonData, err := json.Marshal(data) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error marshaling ModelVersion to JSON: %w", err)) + return + } + + createdArtifact, err := app.modelRegistryClient.CreateModelArtifactByModelVersion(client, ps.ByName(ModelVersionId), jsonData) + if err != nil { + var httpErr *integrations.HTTPError + if errors.As(err, &httpErr) { + app.errorResponse(w, r, httpErr) + } else { + app.serverErrorResponse(w, r, err) + } + return + } + + if createdArtifact == nil { + app.serverErrorResponse(w, r, fmt.Errorf("created ModelArtifact is nil")) + return + } + + response := ModelArtifactEnvelope{ + Data: createdArtifact, + } + + w.Header().Set("Location", ParseURLTemplate(ModelArtifactPath, map[string]string{ + ModelRegistryId: ps.ByName(ModelRegistryId), + ModelArtifactId: createdArtifact.GetId(), + })) + err = app.WriteJSON(w, http.StatusCreated, response, nil) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON")) + return + } +} diff --git a/clients/ui/bff/api/model_versions_handler_test.go b/clients/ui/bff/api/model_versions_handler_test.go new file mode 100644 index 00000000..27050e81 --- /dev/null +++ b/clients/ui/bff/api/model_versions_handler_test.go @@ -0,0 +1,79 @@ +package api + +import ( + "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internals/mocks" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestGetModelVersionHandler(t *testing.T) { + data := mocks.GetModelVersionMocks()[0] + expected := ModelVersionEnvelope{Data: &data} + + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Name, actual.Data.Name) +} + +func TestCreateModelVersionHandler(t *testing.T) { + data := mocks.GetModelVersionMocks()[0] + expected := ModelVersionEnvelope{Data: &data} + + body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")} + + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body) + assert.NoError(t, err) + + assert.Equal(t, http.StatusCreated, rs.StatusCode) + assert.Equal(t, expected.Data.Name, actual.Data.Name) + assert.Equal(t, rs.Header.Get("Location"), "/api/v1/model_registry/model-registry/model_versions/1") +} + +func TestUpdateModelVersionHandler(t *testing.T) { + data := mocks.GetModelVersionMocks()[0] + expected := ModelVersionEnvelope{Data: &data} + + body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")} + + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Name, actual.Data.Name) +} + +func TestGetAllModelArtifactsByModelVersionHandler(t *testing.T) { + data := mocks.GetModelArtifactListMock() + expected := ModelArtifactListEnvelope{Data: &data} + + actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Size, actual.Data.Size) + assert.Equal(t, expected.Data.PageSize, actual.Data.PageSize) + assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken) + assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items)) +} + +func TestCreateModelArtifactByModelVersionHandler(t *testing.T) { + data := mocks.GetModelArtifactMocks()[0] + expected := ModelArtifactEnvelope{Data: &data} + + artifact := openapi.ModelArtifact{ + Name: openapi.PtrString("Artifact One"), + ArtifactType: "ARTIFACT_TYPE_ONE", + } + body := ModelArtifactEnvelope{Data: &artifact} + + actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body) + assert.NoError(t, err) + + assert.Equal(t, http.StatusCreated, rs.StatusCode) + assert.Equal(t, expected.Data.GetArtifactType(), actual.Data.GetArtifactType()) + assert.Equal(t, rs.Header.Get("Location"), "/api/v1/model_registry/model-registry/model_artifacts/1") +} diff --git a/clients/ui/bff/api/registered_models_handler.go b/clients/ui/bff/api/registered_models_handler.go index bafe14e3..6a27215b 100644 --- a/clients/ui/bff/api/registered_models_handler.go +++ b/clients/ui/bff/api/registered_models_handler.go @@ -172,3 +172,79 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ return } } + +func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + //TODO (acreasy) implement pagination + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId)) + + if err != nil { + app.serverErrorResponse(w, r, err) + return + } + + responseBody := ModelVersionListEnvelope{ + Data: versionList, + } + + err = app.WriteJSON(w, http.StatusOK, responseBody, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} + +func (app *App) CreateModelVersionForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + var envelope ModelVersionEnvelope + if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil { + app.badRequestResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error())) + return + } + + data := *envelope.Data + + if err := validation.ValidateModelVersion(data); err != nil { + app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error())) + } + + jsonData, err := json.Marshal(data) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error marshaling model to JSON: %w", err)) + } + + createdVersion, err := app.modelRegistryClient.CreateModelVersionForRegisteredModel(client, ps.ByName(RegisteredModelId), jsonData) + if err != nil { + var httpErr *integrations.HTTPError + if errors.As(err, &httpErr) { + app.errorResponse(w, r, httpErr) + } else { + app.serverErrorResponse(w, r, err) + } + return + } + + if createdVersion == nil { + app.serverErrorResponse(w, r, fmt.Errorf("created model version is nil")) + return + } + + responseBody := ModelVersionEnvelope{ + Data: createdVersion, + } + + w.Header().Set("Location", ParseURLTemplate(ModelVersionPath, map[string]string{ModelRegistryId: ps.ByName(ModelRegistryId), ModelVersionId: createdVersion.GetId()})) + err = app.WriteJSON(w, http.StatusCreated, responseBody, nil) + if err != nil { + app.serverErrorResponse(w, r, err) + } +} diff --git a/clients/ui/bff/api/registered_models_handler_test.go b/clients/ui/bff/api/registered_models_handler_test.go index f7a9005b..125be1cd 100644 --- a/clients/ui/bff/api/registered_models_handler_test.go +++ b/clients/ui/bff/api/registered_models_handler_test.go @@ -1,185 +1,90 @@ package api import ( - "bytes" - "context" - "encoding/json" "github.com/kubeflow/model-registry/pkg/openapi" "github.com/kubeflow/model-registry/ui/bff/internals/mocks" "github.com/stretchr/testify/assert" - "io" "net/http" - "net/http/httptest" "testing" ) func TestGetRegisteredModelHandler(t *testing.T) { - mockMRClient, _ := mocks.NewModelRegistryClient(nil) - mockClient := new(mocks.MockHTTPClient) + data := mocks.GetRegisteredModelMocks()[0] + expected := RegisteredModelEnvelope{Data: &data} - testApp := App{ - modelRegistryClient: mockMRClient, - } - - req, err := http.NewRequest(http.MethodGet, - "/api/v1/model_registry/model-registry/registered_models/1", nil) - assert.NoError(t, err) - - ctx := context.WithValue(req.Context(), httpClientKey, mockClient) - req = req.WithContext(ctx) - - rr := httptest.NewRecorder() - - testApp.GetRegisteredModelHandler(rr, req, nil) - rs := rr.Result() - - defer rs.Body.Close() - - body, err := io.ReadAll(rs.Body) - assert.NoError(t, err) - var registeredModelRes RegisteredModelEnvelope - err = json.Unmarshal(body, ®isteredModelRes) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1", nil) assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rr.Code) - - mockModel := mocks.GetRegisteredModelMocks()[0] - - var expected = RegisteredModelEnvelope{ - Data: &mockModel, - } - //TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values // this issue is in the test only - assert.Equal(t, expected.Data.Name, registeredModelRes.Data.Name) + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Name, actual.Data.Name) } func TestGetAllRegisteredModelsHandler(t *testing.T) { - mockMRClient, _ := mocks.NewModelRegistryClient(nil) - mockClient := new(mocks.MockHTTPClient) + data := mocks.GetRegisteredModelListMock() + expected := RegisteredModelListEnvelope{Data: &data} - testApp := App{ - modelRegistryClient: mockMRClient, - } - - req, err := http.NewRequest(http.MethodGet, - "/api/v1/model_registry/model-registry/registered_models", nil) - assert.NoError(t, err) - - ctx := context.WithValue(req.Context(), httpClientKey, mockClient) - req = req.WithContext(ctx) - - rr := httptest.NewRecorder() - - testApp.GetAllRegisteredModelsHandler(rr, req, nil) - rs := rr.Result() - - defer rs.Body.Close() - - body, err := io.ReadAll(rs.Body) - assert.NoError(t, err) - var registeredModelsListRes RegisteredModelListEnvelope - err = json.Unmarshal(body, ®isteredModelsListRes) + actual, rs, err := setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models", nil) assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rr.Code) - - modelList := mocks.GetRegisteredModelListMock() - - var expected = RegisteredModelListEnvelope{ - Data: &modelList, - } - - assert.Equal(t, expected.Data.Size, registeredModelsListRes.Data.Size) - assert.Equal(t, expected.Data.PageSize, registeredModelsListRes.Data.PageSize) - assert.Equal(t, expected.Data.NextPageToken, registeredModelsListRes.Data.NextPageToken) - assert.Equal(t, len(expected.Data.Items), len(registeredModelsListRes.Data.Items)) + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Size, actual.Data.Size) + assert.Equal(t, expected.Data.PageSize, actual.Data.PageSize) + assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken) + assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items)) } func TestCreateRegisteredModelHandler(t *testing.T) { - mockMRClient, _ := mocks.NewModelRegistryClient(nil) - mockClient := new(mocks.MockHTTPClient) - - testApp := App{ - modelRegistryClient: mockMRClient, - } + data := mocks.GetRegisteredModelMocks()[0] + expected := RegisteredModelEnvelope{Data: &data} - newModel := openapi.NewRegisteredModel("Model One") - newEnvelope := RegisteredModelEnvelope{Data: newModel} + body := RegisteredModelEnvelope{Data: openapi.NewRegisteredModel("Model One")} - newModelJSON, err := json.Marshal(newEnvelope) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models", body) assert.NoError(t, err) - reqBody := bytes.NewReader(newModelJSON) - - req, err := http.NewRequest(http.MethodPost, - "/api/v1/model_registry/model-registry/registered_models", reqBody) - assert.NoError(t, err) - - ctx := context.WithValue(req.Context(), httpClientKey, mockClient) - req = req.WithContext(ctx) - - rr := httptest.NewRecorder() - - testApp.CreateRegisteredModelHandler(rr, req, nil) - rs := rr.Result() - - defer rs.Body.Close() - - body, err := io.ReadAll(rs.Body) - assert.NoError(t, err) - var actual RegisteredModelEnvelope - err = json.Unmarshal(body, &actual) - assert.NoError(t, err) - - assert.Equal(t, http.StatusCreated, rr.Code) - - var expected = mocks.GetRegisteredModelMocks()[0] - - assert.Equal(t, expected.Name, actual.Data.Name) - assert.NotEmpty(t, rs.Header.Get("location")) + assert.Equal(t, http.StatusCreated, rs.StatusCode) + assert.Equal(t, expected.Data.Name, actual.Data.Name) + assert.Equal(t, rs.Header.Get("location"), "/api/v1/model_registry/model-registry/registered_models/1") } func TestUpdateRegisteredModelHandler(t *testing.T) { - mockMRClient, _ := mocks.NewModelRegistryClient(nil) - mockClient := new(mocks.MockHTTPClient) - - testApp := App{ - modelRegistryClient: mockMRClient, - } + data := mocks.GetRegisteredModelMocks()[0] + expected := RegisteredModelEnvelope{Data: &data} - newModel := openapi.NewRegisteredModel("Model One") - newEnvelope := RegisteredModelEnvelope{Data: newModel} + body := RegisteredModelEnvelope{Data: openapi.NewRegisteredModel("Model One")} - newEnvelopeJSON, err := json.Marshal(newEnvelope) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1", body) assert.NoError(t, err) - reqBody := bytes.NewReader(newEnvelopeJSON) - - req, err := http.NewRequest(http.MethodPatch, - "/api/v1/model_registry/model-registry/registered_models/1", reqBody) - assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Name, actual.Data.Name) +} - ctx := context.WithValue(req.Context(), httpClientKey, mockClient) - req = req.WithContext(ctx) +func TestGetAllModelVersionsForRegisteredModelHandler(t *testing.T) { + data := mocks.GetModelVersionListMock() + expected := ModelVersionListEnvelope{Data: &data} - rr := httptest.NewRecorder() + actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions", nil) + assert.NoError(t, err) - testApp.UpdateRegisteredModelHandler(rr, req, nil) - rs := rr.Result() + assert.Equal(t, http.StatusOK, rs.StatusCode) + assert.Equal(t, expected.Data.Size, actual.Data.Size) + assert.Equal(t, expected.Data.PageSize, actual.Data.PageSize) + assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken) + assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items)) +} - defer rs.Body.Close() +func TestCreateModelVersionForRegisteredModelHandler(t *testing.T) { + data := mocks.GetModelVersionMocks()[0] + expected := ModelVersionEnvelope{Data: &data} - body, err := io.ReadAll(rs.Body) - assert.NoError(t, err) - var actual RegisteredModelEnvelope - err = json.Unmarshal(body, &actual) + body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Version Fifty", "")} + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions", body) assert.NoError(t, err) - assert.Equal(t, http.StatusOK, rr.Code) - - expectedModel := mocks.GetRegisteredModelMocks()[0] - expected := RegisteredModelEnvelope{Data: &expectedModel} - + assert.Equal(t, http.StatusCreated, rs.StatusCode) assert.Equal(t, expected.Data.Name, actual.Data.Name) + assert.Equal(t, rs.Header.Get("Location"), "/api/v1/model_registry/model-registry/model_versions/1") } diff --git a/clients/ui/bff/api/test_utils.go b/clients/ui/bff/api/test_utils.go new file mode 100644 index 00000000..aea27d03 --- /dev/null +++ b/clients/ui/bff/api/test_utils.go @@ -0,0 +1,73 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "github.com/kubeflow/model-registry/ui/bff/internals/mocks" + "io" + "net/http" + "net/http/httptest" +) + +func setupApiTest[T any](method string, url string, body interface{}) (T, *http.Response, error) { + mockMRClient, err := mocks.NewModelRegistryClient(nil) + if err != nil { + return *new(T), nil, err + } + mockK8sClient, err := mocks.NewKubernetesClient(nil) + if err != nil { + return *new(T), nil, err + } + + mockClient := new(mocks.MockHTTPClient) + + testApp := App{ + modelRegistryClient: mockMRClient, + kubernetesClient: mockK8sClient, + } + + var req *http.Request + if body != nil { + r, err := json.Marshal(body) + if err != nil { + return *new(T), nil, err + } + bytes.NewReader(r) + req, err = http.NewRequest(method, url, bytes.NewReader(r)) + if err != nil { + return *new(T), nil, err + } + } else { + req, err = http.NewRequest(method, url, nil) + if err != nil { + return *new(T), nil, err + } + } + + ctx := context.WithValue(req.Context(), httpClientKey, mockClient) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + + testApp.Routes().ServeHTTP(rr, req) + + rs := rr.Result() + defer rs.Body.Close() + respBody, err := io.ReadAll(rs.Body) + if err != nil { + return *new(T), nil, err + } + + var entity T + err = json.Unmarshal(respBody, &entity) + if err != nil { + if err == io.EOF { + // There's no body to parse. + return *new(T), rs, nil + } + return *new(T), nil, err + } + + return entity, rs, nil +} diff --git a/clients/ui/bff/data/model_registry_client.go b/clients/ui/bff/data/model_registry_client.go index ccc4ff9e..0f0426a6 100644 --- a/clients/ui/bff/data/model_registry_client.go +++ b/clients/ui/bff/data/model_registry_client.go @@ -6,11 +6,13 @@ import ( type ModelRegistryClientInterface interface { RegisteredModelInterface + ModelVersionInterface } type ModelRegistryClient struct { logger *slog.Logger RegisteredModel + ModelVersion } func NewModelRegistryClient(logger *slog.Logger) (ModelRegistryClientInterface, error) { diff --git a/clients/ui/bff/data/model_version.go b/clients/ui/bff/data/model_version.go new file mode 100644 index 00000000..41057d0a --- /dev/null +++ b/clients/ui/bff/data/model_version.go @@ -0,0 +1,119 @@ +package data + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/integrations" + "net/url" +) + +const modelVersionPath = "/model_versions" +const artifactsByModelVersionPath = "/artifacts" + +type ModelVersionInterface interface { + GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) + CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) + UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) + GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) + CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) +} + +type ModelVersion struct { + ModelVersionInterface +} + +func (v ModelVersion) GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) { + path, err := url.JoinPath(modelVersionPath, id) + if err != nil { + return nil, err + } + + response, err := client.GET(path) + + if err != nil { + return nil, fmt.Errorf("error fetching model version: %w", err) + } + + var model openapi.ModelVersion + if err := json.Unmarshal(response, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} + +func (v ModelVersion) CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) { + responseData, err := client.POST(modelVersionPath, bytes.NewBuffer(jsonData)) + + if err != nil { + return nil, fmt.Errorf("error posting registered model: %w", err) + } + + var model openapi.ModelVersion + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} + +func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) { + path, err := url.JoinPath(modelVersionPath, id) + + if err != nil { + return nil, err + } + + responseData, err := client.PATCH(path, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error patching ModelVersion: %w", err) + } + + var model openapi.ModelVersion + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} + +func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { + path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath) + + if err != nil { + return nil, err + } + + responseData, err := client.GET(path) + if err != nil { + return nil, fmt.Errorf("error fetching model version artifacts: %w", err) + } + + var model openapi.ModelArtifactList + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} + +func (v ModelVersion) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) { + path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath) + if err != nil { + return nil, err + } + + responseData, err := client.POST(path, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error posting model artifact: %w", err) + } + + var model openapi.ModelArtifact + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} diff --git a/clients/ui/bff/data/model_version_test.go b/clients/ui/bff/data/model_version_test.go new file mode 100644 index 00000000..a17f0d1f --- /dev/null +++ b/clients/ui/bff/data/model_version_test.go @@ -0,0 +1,142 @@ +package data + +import ( + "encoding/json" + "github.com/brianvoe/gofakeit/v7" + "github.com/kubeflow/model-registry/ui/bff/internals/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "net/http" + "net/url" + "testing" +) + +func TestGetModelVersion(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelVersion() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, expected.GetId()) + assert.NoError(t, err) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On("GET", path).Return(mockData, nil) + + actual, err := modelVersion.GetModelVersion(mockClient, expected.GetId()) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, *expected.Author, *actual.Author) + + mockClient.AssertExpectations(t) +} + +func TestCreateModelVersion(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelVersion() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + mockClient := new(mocks.MockHTTPClient) + mockClient.On("POST", modelVersionPath, mock.Anything).Return(mockData, nil) + + jsonInput, err := json.Marshal(expected) + assert.NoError(t, err) + + actual, err := modelVersion.CreateModelVersion(mockClient, jsonInput) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, *expected.Author, *actual.Author) + + mockClient.AssertExpectations(t) +} + +func TestUpdateModelVersion(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelVersion() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, expected.GetId()) + assert.NoError(t, err) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodPatch, path, mock.Anything).Return(mockData, nil) + + jsonInput, err := json.Marshal(expected) + assert.NoError(t, err) + + actual, err := modelVersion.UpdateModelVersion(mockClient, expected.GetId(), jsonInput) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, *expected.Author, *actual.Author) + + mockClient.AssertExpectations(t) +} + +func TestGetModelArtifactsByModelVersion(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelArtifactList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath) + assert.NoError(t, err) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil) + + actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1") + assert.NoError(t, err) + + assert.NotNil(t, actual) + assert.Equal(t, expected.Size, actual.Size) + assert.Equal(t, expected.NextPageToken, actual.NextPageToken) + assert.Equal(t, expected.PageSize, actual.PageSize) + assert.Equal(t, len(expected.Items), len(actual.Items)) +} + +func TestCreateModelArtifactByModelVersion(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelArtifact() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath) + assert.NoError(t, err) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodPost, path, mock.Anything).Return(mockData, nil) + + jsonInnput, err := json.Marshal(expected) + assert.NoError(t, err) + + actual, err := modelVersion.CreateModelArtifactByModelVersion(mockClient, "1", jsonInnput) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, expected.ArtifactType, actual.ArtifactType) +} diff --git a/clients/ui/bff/data/registered_model.go b/clients/ui/bff/data/registered_model.go index a6cad7aa..cfefbfed 100644 --- a/clients/ui/bff/data/registered_model.go +++ b/clients/ui/bff/data/registered_model.go @@ -10,12 +10,15 @@ import ( ) const registeredModelPath = "/registered_models" +const versionsPath = "/versions" type RegisteredModelInterface interface { GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error) + GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) + CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) } type RegisteredModel struct { @@ -90,3 +93,44 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt return &model, nil } + +func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) { + path, err := url.JoinPath(registeredModelPath, id, versionsPath) + + if err != nil { + return nil, err + } + + responseData, err := client.GET(path) + + if err != nil { + return nil, fmt.Errorf("error fetching model versions: %w", err) + } + + var model openapi.ModelVersionList + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} + +func (m RegisteredModel) CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) { + path, err := url.JoinPath(registeredModelPath, id, versionsPath) + + if err != nil { + return nil, err + } + + responseData, err := client.POST(path, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error posting model version: %w", err) + } + + var model openapi.ModelVersion + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} diff --git a/clients/ui/bff/data/registered_model_test.go b/clients/ui/bff/data/registered_model_test.go index 2f1e86cb..871e1bc0 100644 --- a/clients/ui/bff/data/registered_model_test.go +++ b/clients/ui/bff/data/registered_model_test.go @@ -19,12 +19,12 @@ func TestGetAllRegisteredModels(t *testing.T) { mockData, err := json.Marshal(expected) assert.NoError(t, err) - mrClient := ModelRegistryClient{} + registeredModel := RegisteredModel{} mockClient := new(mocks.MockHTTPClient) mockClient.On("GET", registeredModelPath).Return(mockData, nil) - actual, err := mrClient.GetAllRegisteredModels(mockClient) + actual, err := registeredModel.GetAllRegisteredModels(mockClient) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.NextPageToken, actual.NextPageToken) @@ -38,12 +38,12 @@ func TestGetAllRegisteredModels(t *testing.T) { func TestCreateRegisteredModel(t *testing.T) { gofakeit.Seed(0) - expected := mocks.GenerateRegisteredModel() + expected := mocks.GenerateMockRegisteredModel() mockData, err := json.Marshal(expected) assert.NoError(t, err) - mrClient := ModelRegistryClient{} + registeredModel := RegisteredModel{} mockClient := new(mocks.MockHTTPClient) mockClient.On("POST", registeredModelPath, mock.Anything).Return(mockData, nil) @@ -51,7 +51,7 @@ func TestCreateRegisteredModel(t *testing.T) { jsonInput, err := json.Marshal(expected) assert.NoError(t, err) - actual, err := mrClient.CreateRegisteredModel(mockClient, jsonInput) + actual, err := registeredModel.CreateRegisteredModel(mockClient, jsonInput) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.Name, actual.Name) @@ -63,17 +63,17 @@ func TestCreateRegisteredModel(t *testing.T) { func TestGetRegisteredModel(t *testing.T) { gofakeit.Seed(0) - expected := mocks.GenerateRegisteredModel() + expected := mocks.GenerateMockRegisteredModel() mockData, err := json.Marshal(expected) assert.NoError(t, err) - mrClient := ModelRegistryClient{} + registeredModel := RegisteredModel{} mockClient := new(mocks.MockHTTPClient) mockClient.On("GET", registeredModelPath+"/"+expected.GetId()).Return(mockData, nil) - actual, err := mrClient.GetRegisteredModel(mockClient, expected.GetId()) + actual, err := registeredModel.GetRegisteredModel(mockClient, expected.GetId()) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.Name, actual.Name) @@ -85,12 +85,12 @@ func TestGetRegisteredModel(t *testing.T) { func TestUpdateRegisteredModel(t *testing.T) { gofakeit.Seed(0) - expected := mocks.GenerateRegisteredModel() + expected := mocks.GenerateMockRegisteredModel() mockData, err := json.Marshal(expected) assert.NoError(t, err) - mrClient := ModelRegistryClient{} + registeredModel := RegisteredModel{} path, err := url.JoinPath(registeredModelPath, expected.GetId()) assert.NoError(t, err) @@ -101,7 +101,7 @@ func TestUpdateRegisteredModel(t *testing.T) { jsonInput, err := json.Marshal(expected) assert.NoError(t, err) - actual, err := mrClient.UpdateRegisteredModel(mockClient, expected.GetId(), jsonInput) + actual, err := registeredModel.UpdateRegisteredModel(mockClient, expected.GetId(), jsonInput) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.Name, actual.Name) @@ -109,3 +109,59 @@ func TestUpdateRegisteredModel(t *testing.T) { mockClient.AssertExpectations(t) } + +func TestGetAllModelVersions(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelVersionList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + registeredModel := RegisteredModel{} + + mockClient := new(mocks.MockHTTPClient) + path, err := url.JoinPath(registeredModelPath, "1", versionsPath) + assert.NoError(t, err) + mockClient.On("GET", path).Return(mockData, nil) + + actual, err := registeredModel.GetAllModelVersions(mockClient, "1") + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.NextPageToken, actual.NextPageToken) + assert.Equal(t, expected.PageSize, actual.PageSize) + assert.Equal(t, expected.Size, actual.Size) + assert.Equal(t, len(expected.Items), len(actual.Items)) + + mockClient.AssertExpectations(t) +} + +func TestCreateModelVersionForRegisteredModel(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelVersion() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + registeredModel := RegisteredModel{} + + path, err := url.JoinPath(registeredModelPath, "1", versionsPath) + assert.NoError(t, err) + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodPost, path, mock.Anything).Return(mockData, nil) + + jsonInput, err := json.Marshal(expected) + assert.NoError(t, err) + + actual, err := registeredModel.CreateModelVersionForRegisteredModel(mockClient, "1", jsonInput) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, *expected.Author, *actual.Author) + assert.Equal(t, expected.RegisteredModelId, actual.RegisteredModelId) + + mockClient.AssertExpectations(t) +} diff --git a/clients/ui/bff/internals/mocks/model_registry_client_mock.go b/clients/ui/bff/internals/mocks/model_registry_client_mock.go index 58af851e..375f5dbf 100644 --- a/clients/ui/bff/internals/mocks/model_registry_client_mock.go +++ b/clients/ui/bff/internals/mocks/model_registry_client_mock.go @@ -34,3 +34,37 @@ func (m *ModelRegistryClientMock) UpdateRegisteredModel(client integrations.HTTP mockData := GetRegisteredModelMocks()[0] return &mockData, nil } + +func (m *ModelRegistryClientMock) GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) { + mockData := GetModelVersionMocks()[0] + return &mockData, nil +} + +func (m *ModelRegistryClientMock) CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) { + mockData := GetModelVersionMocks()[0] + return &mockData, nil +} + +func (m *ModelRegistryClientMock) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) { + mockData := GetModelVersionMocks()[0] + return &mockData, nil +} + +func (m *ModelRegistryClientMock) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) { + mockData := GetModelVersionListMock() + return &mockData, nil +} + +func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) { + mockData := GetModelVersionMocks()[0] + return &mockData, nil +} + +func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { + mockData := GetModelArtifactListMock() + return &mockData, nil +} +func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) { + mockData := GetModelArtifactMocks()[0] + return &mockData, nil +} diff --git a/clients/ui/bff/internals/mocks/static_data_mock.go b/clients/ui/bff/internals/mocks/static_data_mock.go index 73221003..d534fde5 100644 --- a/clients/ui/bff/internals/mocks/static_data_mock.go +++ b/clients/ui/bff/internals/mocks/static_data_mock.go @@ -6,14 +6,7 @@ import ( func GetRegisteredModelMocks() []openapi.RegisteredModel { model1 := openapi.RegisteredModel{ - CustomProperties: &map[string]openapi.MetadataValue{ - "my-label9": { - MetadataStringValue: &openapi.MetadataStringValue{ - StringValue: "property9", - MetadataType: "string", - }, - }, - }, + CustomProperties: newCustomProperties(), Name: "Model One", Description: stringToPointer("This model does things and stuff"), ExternalId: stringToPointer("934589798"), @@ -25,14 +18,7 @@ func GetRegisteredModelMocks() []openapi.RegisteredModel { } model2 := openapi.RegisteredModel{ - CustomProperties: &map[string]openapi.MetadataValue{ - "my-label9": { - MetadataStringValue: &openapi.MetadataStringValue{ - StringValue: "property9", - MetadataType: "string", - }, - }, - }, + CustomProperties: newCustomProperties(), Name: "Model Two", Description: stringToPointer("This model does things and stuff"), ExternalId: stringToPointer("345235987"), @@ -56,3 +42,106 @@ func GetRegisteredModelListMock() openapi.RegisteredModelList { Items: models, } } + +func GetModelVersionMocks() []openapi.ModelVersion { + model1 := openapi.ModelVersion{ + CustomProperties: newCustomProperties(), + Name: "Version One", + Description: stringToPointer("This version improves stuff and things"), + ExternalId: stringToPointer("934589798"), + Id: stringToPointer("1"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + RegisteredModelId: "1", + Author: stringToPointer("Sherlock Holmes"), + State: stateToPointer(openapi.MODELVERSIONSTATE_LIVE), + } + + model2 := openapi.ModelVersion{ + CustomProperties: newCustomProperties(), + Name: "Version Two", + Description: stringToPointer("This version improves stuff and things"), + ExternalId: stringToPointer("934589799"), + Id: stringToPointer("2"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + RegisteredModelId: "2", + Author: stringToPointer("Sherlock Holmes"), + State: stateToPointer(openapi.MODELVERSIONSTATE_LIVE), + } + + return []openapi.ModelVersion{model1, model2} +} + +func GetModelVersionListMock() openapi.ModelVersionList { + versions := GetModelVersionMocks() + + return openapi.ModelVersionList{ + NextPageToken: "abcdefgh", + PageSize: 2, + Items: versions, + Size: 2, + } +} + +func GetModelArtifactMocks() []openapi.ModelArtifact { + artifact1 := openapi.ModelArtifact{ + ArtifactType: "TYPE_ONE", + CustomProperties: newCustomProperties(), + Description: stringToPointer("This artifact can do more than you would expect"), + ExternalId: stringToPointer("1000001"), + Uri: stringToPointer("http://localhost/artifacts/1"), + State: stateToPointer(openapi.ARTIFACTSTATE_LIVE), + Name: stringToPointer("Artifact One"), + Id: stringToPointer("1"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + ModelFormatName: stringToPointer("ONNX"), + StorageKey: stringToPointer("key1"), + StoragePath: stringToPointer("/artifacts/1"), + ModelFormatVersion: stringToPointer("1.0.0"), + ServiceAccountName: stringToPointer("service-1"), + } + + artifact2 := openapi.ModelArtifact{ + ArtifactType: "TYPE_TWO", + CustomProperties: newCustomProperties(), + Description: stringToPointer("This artifact can do more than you would expect, but less than you would hope"), + ExternalId: stringToPointer("1000002"), + Uri: stringToPointer("http://localhost/artifacts/2"), + State: stateToPointer(openapi.ARTIFACTSTATE_PENDING), + Name: stringToPointer("Artifact Two"), + Id: stringToPointer("2"), + CreateTimeSinceEpoch: stringToPointer("1725282249921"), + LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"), + ModelFormatName: stringToPointer("TensorFlow"), + StorageKey: stringToPointer("key2"), + StoragePath: stringToPointer("/artifacts/2"), + ModelFormatVersion: stringToPointer("1.0.0"), + ServiceAccountName: stringToPointer("service-2"), + } + + return []openapi.ModelArtifact{artifact1, artifact2} +} + +func GetModelArtifactListMock() openapi.ModelArtifactList { + return openapi.ModelArtifactList{ + NextPageToken: "abcdefgh", + PageSize: 2, + Items: GetModelArtifactMocks(), + Size: 2, + } +} + +func newCustomProperties() *map[string]openapi.MetadataValue { + result := map[string]openapi.MetadataValue{ + "my-label9": { + MetadataStringValue: &openapi.MetadataStringValue{ + StringValue: "property9", + MetadataType: "string", + }, + }, + } + + return &result +} diff --git a/clients/ui/bff/internals/mocks/types_mock.go b/clients/ui/bff/internals/mocks/types_mock.go index 7ab35211..f1d81a9a 100644 --- a/clients/ui/bff/internals/mocks/types_mock.go +++ b/clients/ui/bff/internals/mocks/types_mock.go @@ -9,7 +9,7 @@ import ( func GenerateMockRegisteredModelList() openapi.RegisteredModelList { var models []openapi.RegisteredModel for i := 0; i < 2; i++ { - model := GenerateRegisteredModel() + model := GenerateMockRegisteredModel() models = append(models, model) } @@ -21,7 +21,7 @@ func GenerateMockRegisteredModelList() openapi.RegisteredModelList { } } -func GenerateRegisteredModel() openapi.RegisteredModel { +func GenerateMockRegisteredModel() openapi.RegisteredModel { model := openapi.RegisteredModel{ CustomProperties: &map[string]openapi.MetadataValue{ "example_key": { @@ -35,15 +35,113 @@ func GenerateRegisteredModel() openapi.RegisteredModel { ExternalId: stringToPointer(gofakeit.UUID()), Name: gofakeit.Name(), Id: stringToPointer(gofakeit.UUID()), - CreateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())), - LastUpdateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())), + CreateTimeSinceEpoch: randomEpochTime(), + LastUpdateTimeSinceEpoch: randomEpochTime(), Owner: stringToPointer(gofakeit.Name()), State: stateToPointer(openapi.RegisteredModelState(gofakeit.RandomString([]string{string(openapi.REGISTEREDMODELSTATE_LIVE), string(openapi.REGISTEREDMODELSTATE_ARCHIVED)}))), } return model } -func stateToPointer(s openapi.RegisteredModelState) *openapi.RegisteredModelState { +func GenerateMockModelVersion() openapi.ModelVersion { + model := openapi.ModelVersion{ + CustomProperties: &map[string]openapi.MetadataValue{ + "example_key": { + MetadataStringValue: &openapi.MetadataStringValue{ + StringValue: gofakeit.Sentence(3), + MetadataType: "string", + }, + }, + }, + Description: stringToPointer(gofakeit.Sentence(5)), + ExternalId: stringToPointer(gofakeit.UUID()), + Name: gofakeit.Name(), + Id: stringToPointer(gofakeit.UUID()), + CreateTimeSinceEpoch: randomEpochTime(), + LastUpdateTimeSinceEpoch: randomEpochTime(), + Author: stringToPointer(gofakeit.Name()), + State: stateToPointer(openapi.ModelVersionState(gofakeit.RandomString([]string{string(openapi.MODELVERSIONSTATE_LIVE), string(openapi.MODELVERSIONSTATE_ARCHIVED)}))), + } + return model +} + +func GenerateMockModelVersionList() openapi.ModelVersionList { + var versions []openapi.ModelVersion + + for i := 0; i < 2; i++ { + version := GenerateMockModelVersion() + versions = append(versions, version) + } + + return openapi.ModelVersionList{ + NextPageToken: gofakeit.UUID(), + PageSize: int32(gofakeit.Number(1, 20)), + Size: int32(len(versions)), + Items: versions, + } +} + +func GenerateMockModelArtifact() openapi.ModelArtifact { + artifact := openapi.ModelArtifact{ + ArtifactType: gofakeit.Word(), + CustomProperties: &map[string]openapi.MetadataValue{ + "example_key": { + MetadataStringValue: &openapi.MetadataStringValue{ + StringValue: gofakeit.Sentence(3), + MetadataType: "string", + }, + }, + }, + Description: stringToPointer(gofakeit.Sentence(5)), + ExternalId: stringToPointer(gofakeit.UUID()), + Uri: stringToPointer(gofakeit.URL()), + State: randomArtifactState(), + Name: stringToPointer(gofakeit.Name()), + Id: stringToPointer(gofakeit.UUID()), + CreateTimeSinceEpoch: randomEpochTime(), + LastUpdateTimeSinceEpoch: randomEpochTime(), + ModelFormatName: stringToPointer(gofakeit.Name()), + StorageKey: stringToPointer(gofakeit.Word()), + StoragePath: stringToPointer("/" + gofakeit.Word() + "/" + gofakeit.Word()), + ModelFormatVersion: stringToPointer(gofakeit.AppVersion()), + ServiceAccountName: stringToPointer(gofakeit.Username()), + } + return artifact +} + +func GenerateMockModelArtifactList() openapi.ModelArtifactList { + var artifacts []openapi.ModelArtifact + + for i := 0; i < 2; i++ { + artifact := GenerateMockModelArtifact() + artifacts = append(artifacts, artifact) + } + + return openapi.ModelArtifactList{ + NextPageToken: gofakeit.UUID(), + PageSize: int32(gofakeit.Number(1, 20)), + Size: int32(len(artifacts)), + Items: artifacts, + } +} + +func randomEpochTime() *string { + return stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())) +} + +func randomArtifactState() *openapi.ArtifactState { + return stateToPointer(openapi.ArtifactState(gofakeit.RandomString([]string{ + string(openapi.ARTIFACTSTATE_LIVE), + string(openapi.ARTIFACTSTATE_DELETED), + string(openapi.ARTIFACTSTATE_ABANDONED), + string(openapi.ARTIFACTSTATE_MARKED_FOR_DELETION), + string(openapi.ARTIFACTSTATE_PENDING), + string(openapi.ARTIFACTSTATE_REFERENCE), + string(openapi.ARTIFACTSTATE_UNKNOWN), + }))) +} + +func stateToPointer[T any](s T) *T { return &s } diff --git a/clients/ui/bff/validation/test_helpers.go b/clients/ui/bff/validation/test_helpers.go new file mode 100644 index 00000000..29f5ab37 --- /dev/null +++ b/clients/ui/bff/validation/test_helpers.go @@ -0,0 +1,25 @@ +package validation + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +type testSpec[T any] struct { + name string + input T + wantErr bool +} + +func validateTestSpecs[T any](t *testing.T, specs []testSpec[T], validator func(input T) error) { + for _, tt := range specs { + t.Run(tt.name, func(t *testing.T) { + err := validator(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/clients/ui/bff/validation/validation.go b/clients/ui/bff/validation/validation.go index 52b78073..2c988bfc 100644 --- a/clients/ui/bff/validation/validation.go +++ b/clients/ui/bff/validation/validation.go @@ -12,3 +12,19 @@ func ValidateRegisteredModel(input openapi.RegisteredModel) error { // Add more field validations as required return nil } + +func ValidateModelVersion(input openapi.ModelVersion) error { + if input.Name == "" { + return errors.New("name cannot be empty") + } + // Add more field validations as required + return nil +} + +func ValidateModelArtifact(input openapi.ModelArtifact) error { + if input.GetName() == "" { + return errors.New("name cannot be empty") + } + // Add more field validations as required + return nil +} diff --git a/clients/ui/bff/validation/validation_test.go b/clients/ui/bff/validation/validation_test.go index 338427e1..b79f4bba 100644 --- a/clients/ui/bff/validation/validation_test.go +++ b/clients/ui/bff/validation/validation_test.go @@ -2,16 +2,11 @@ package validation import ( "github.com/kubeflow/model-registry/pkg/openapi" - "github.com/stretchr/testify/assert" "testing" ) func TestValidateRegisteredModel(t *testing.T) { - tests := []struct { - name string - input openapi.RegisteredModel - wantErr bool - }{ + specs := []testSpec[openapi.RegisteredModel]{ { name: "Empty name", input: openapi.RegisteredModel{Name: ""}, @@ -24,14 +19,39 @@ func TestValidateRegisteredModel(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ValidateRegisteredModel(tt.input) - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) + validateTestSpecs(t, specs, ValidateRegisteredModel) +} + +func TestValidateModelVersion(t *testing.T) { + specs := []testSpec[openapi.ModelVersion]{ + { + name: "Empty name", + input: openapi.ModelVersion{Name: ""}, + wantErr: true, + }, + { + name: "Valid name", + input: openapi.ModelVersion{Name: "ValidName"}, + wantErr: false, + }, } + + validateTestSpecs(t, specs, ValidateModelVersion) +} + +func TestValidateModel(t *testing.T) { + specs := []testSpec[openapi.ModelArtifact]{ + { + name: "Empty name", + input: openapi.ModelArtifact{Name: openapi.PtrString("")}, + wantErr: true, + }, + { + name: "Valid name", + input: openapi.ModelArtifact{Name: openapi.PtrString("ValidName")}, + wantErr: false, + }, + } + + validateTestSpecs(t, specs, ValidateModelArtifact) }