Skip to content

Commit

Permalink
Update API models to use envelope format consistently (kubeflow#381)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy authored Sep 13, 2024
1 parent cc6455f commit 95e6b7f
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 52 deletions.
10 changes: 5 additions & 5 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ curl -i localhost:4000/api/v1/model_registry
```
```
# GET /v1/model_registry/{model_registry_id}/registered_models
curl -i localhost:4000/api/v1/model_registry/model_registry_1/registered_models
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models
```
```
#POST /v1/model_registry/{model_registry_id}/registered_models
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model_registry/registered_models" \
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \
-H "Content-Type: application/json" \
-d '{
-d '{ "data": {
"customProperties": {
"my-label9": {
"metadataType": "MetadataStringValue",
Expand All @@ -91,9 +91,9 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model_registry/regi
"name": "bella",
"owner": "eder",
"state": "LIVE"
}'
}}'
```
```
# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}
curl -i localhost:4000/api/v1/model_registry/model_registry/registered_models/1
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1
```
6 changes: 5 additions & 1 deletion clients/ui/bff/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ type ErrorResponse struct {
Message string `json:"message"`
}

type ErrorEnvelope struct {
Error *integrations.HTTPError `json:"error"`
}

func (app *App) LogError(r *http.Request, err error) {
var (
method = r.Method
Expand All @@ -40,7 +44,7 @@ func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err e

func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) {

env := Envelope{"error": error}
env := ErrorEnvelope{Error: error}

err := app.WriteJSON(w, error.StatusCode, env, nil)

Expand Down
13 changes: 10 additions & 3 deletions clients/ui/bff/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
"strings"
)

type Envelope map[string]interface{}
type Envelope[D any, M any] struct {
Data D `json:"data,omitempty"`
Metadata M `json:"metadata,omitempty"`
}

type TypedEnvelope[T any] map[string]T
type None *struct{}

func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error {

Expand All @@ -29,7 +32,11 @@ func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers h

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
w.Write(js)
_, err = w.Write(js)

if err != nil {
return err
}

return nil
}
Expand Down
7 changes: 5 additions & 2 deletions clients/ui/bff/api/model_registry_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package api

import (
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/data"
"net/http"
)

type ModelRegistryListEnvelope Envelope[[]data.ModelRegistryModel, None]

func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {

registries, err := app.models.ModelRegistry.FetchAllModelRegistries(app.kubernetesClient)
Expand All @@ -13,8 +16,8 @@ func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, ps
return
}

modelRegistryRes := Envelope{
"model_registry": registries,
modelRegistryRes := ModelRegistryListEnvelope{
Data: registries,
}

err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil)
Expand Down
18 changes: 5 additions & 13 deletions clients/ui/bff/api/model_registry_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,20 @@ func TestModelRegistryHandler(t *testing.T) {
defer rs.Body.Close()
body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var modelRegistryRes Envelope
err = json.Unmarshal(body, &modelRegistryRes)
var actual ModelRegistryListEnvelope
err = json.Unmarshal(body, &actual)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rr.Code)

// Convert the unmarshalled data to the expected type
actualModelRegistry := make([]data.ModelRegistryModel, 0)
for _, v := range modelRegistryRes["model_registry"].([]interface{}) {
model := v.(map[string]interface{})
actualModelRegistry = append(actualModelRegistry, data.ModelRegistryModel{Name: model["name"].(string), Description: model["description"].(string), DisplayName: model["displayName"].(string)})
}
modelRegistryRes["model_registry"] = actualModelRegistry

var expected = Envelope{
"model_registry": []data.ModelRegistryModel{
var expected = ModelRegistryListEnvelope{
Data: []data.ModelRegistryModel{
{Name: "model-registry", Description: "Model registry description", DisplayName: "Model Registry"},
{Name: "model-registry-dora", Description: "Model registry dora description", DisplayName: "Model Registry Dora"},
{Name: "model-registry-bella", Description: "Model registry bella description", DisplayName: "Model Registry Bella"},
},
}

assert.Equal(t, expected, modelRegistryRes)
assert.Equal(t, expected, actual)

}
26 changes: 17 additions & 9 deletions clients/ui/bff/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"net/http"
)

type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None]
type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None]

func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (ederign) implement pagination
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
Expand All @@ -25,8 +28,8 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req
return
}

modelRegistryRes := Envelope{
"registered_model_list": modelList,
modelRegistryRes := RegisteredModelListEnvelope{
Data: modelList,
}

err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil)
Expand All @@ -42,18 +45,20 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ
return
}

var model openapi.RegisteredModel
if err := json.NewDecoder(r.Body).Decode(&model); err != nil {
var envelope RegisteredModelEnvelope
if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error()))
return
}

if err := validation.ValidateRegisteredModel(model); err != nil {
data := *envelope.Data

if err := validation.ValidateRegisteredModel(data); err != nil {
app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error()))
return
}

jsonData, err := json.Marshal(model)
jsonData, err := json.Marshal(data)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error marshaling model to JSON: %w", err))
return
Expand All @@ -75,8 +80,11 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ
return
}

response := RegisteredModelEnvelope{
Data: createdModel,
}
w.Header().Set("Location", fmt.Sprintf("%s/%s", RegisteredModelsPath, *createdModel.Id))
err = app.WriteJSON(w, http.StatusCreated, createdModel, nil)
err = app.WriteJSON(w, http.StatusCreated, response, nil)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON"))
return
Expand All @@ -101,8 +109,8 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request
return
}

result := Envelope{
"registered_model": model,
result := RegisteredModelEnvelope{
Data: model,
}

err = app.WriteJSON(w, http.StatusOK, result, nil)
Expand Down
44 changes: 25 additions & 19 deletions clients/ui/bff/api/registered_models_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestGetRegisteredModelHandler(t *testing.T) {
}

req, err := http.NewRequest(http.MethodGet,
"/api/v1/model-registry/model-registry/registered_models/1", nil)
"/api/v1/model_registry/model-registry/registered_models/1", nil)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
Expand All @@ -37,19 +37,21 @@ func TestGetRegisteredModelHandler(t *testing.T) {

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelRes TypedEnvelope[openapi.RegisteredModel]
var registeredModelRes RegisteredModelEnvelope
err = json.Unmarshal(body, &registeredModelRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rr.Code)

var expected = TypedEnvelope[openapi.RegisteredModel]{
"registered_model": mocks.GetRegisteredModelMocks()[0],
mockModel := mocks.GetRegisteredModelMocks()[0]

var expected = RegisteredModelEnvelope{
Data: &mockModel,
}

//TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values
// this issue is in the test only
assert.Equal(t, expected["registered_model"].Name, registeredModelRes["registered_model"].Name)
assert.Equal(t, expected.Data.Name, registeredModelRes.Data.Name)
}

func TestGetAllRegisteredModelsHandler(t *testing.T) {
Expand All @@ -61,7 +63,7 @@ func TestGetAllRegisteredModelsHandler(t *testing.T) {
}

req, err := http.NewRequest(http.MethodGet,
"/api/v1/model-registry/model-registry/registered_models", nil)
"/api/v1/model_registry/model-registry/registered_models", nil)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
Expand All @@ -76,20 +78,22 @@ func TestGetAllRegisteredModelsHandler(t *testing.T) {

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelsListRes TypedEnvelope[openapi.RegisteredModelList]
var registeredModelsListRes RegisteredModelListEnvelope
err = json.Unmarshal(body, &registeredModelsListRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rr.Code)

var expected = TypedEnvelope[openapi.RegisteredModelList]{
"registered_model_list": mocks.GetRegisteredModelListMock(),
modelList := mocks.GetRegisteredModelListMock()

var expected = RegisteredModelListEnvelope{
Data: &modelList,
}

assert.Equal(t, expected["registered_model_list"].Size, registeredModelsListRes["registered_model_list"].Size)
assert.Equal(t, expected["registered_model_list"].PageSize, registeredModelsListRes["registered_model_list"].PageSize)
assert.Equal(t, expected["registered_model_list"].NextPageToken, registeredModelsListRes["registered_model_list"].NextPageToken)
assert.Equal(t, len(expected["registered_model_list"].Items), len(registeredModelsListRes["registered_model_list"].Items))
assert.Equal(t, expected.Data.Size, registeredModelsListRes.Data.Size)
assert.Equal(t, expected.Data.PageSize, registeredModelsListRes.Data.PageSize)
assert.Equal(t, expected.Data.NextPageToken, registeredModelsListRes.Data.NextPageToken)
assert.Equal(t, len(expected.Data.Items), len(registeredModelsListRes.Data.Items))
}

func TestCreateRegisteredModelHandler(t *testing.T) {
Expand All @@ -100,14 +104,16 @@ func TestCreateRegisteredModelHandler(t *testing.T) {
modelRegistryClient: mockMRClient,
}

newModel := openapi.NewRegisteredModelCreate("Model One")
newModelJSON, err := newModel.MarshalJSON()
newModel := openapi.NewRegisteredModel("Model One")
newEnvelope := RegisteredModelEnvelope{Data: newModel}

newModelJSON, err := json.Marshal(newEnvelope)
assert.NoError(t, err)

reqBody := bytes.NewReader(newModelJSON)

req, err := http.NewRequest(http.MethodPost,
"/api/v1/model-registry/model-registry/registered_models", reqBody)
"/api/v1/model_registry/model-registry/registered_models", reqBody)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
Expand All @@ -122,14 +128,14 @@ func TestCreateRegisteredModelHandler(t *testing.T) {

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelRes openapi.RegisteredModel
err = json.Unmarshal(body, &registeredModelRes)
var actual RegisteredModelEnvelope
err = json.Unmarshal(body, &actual)
assert.NoError(t, err)

assert.Equal(t, http.StatusCreated, rr.Code)

var expected = mocks.GetRegisteredModelMocks()[0]

assert.Equal(t, expected.Name, registeredModelRes.Name)
assert.Equal(t, expected.Name, actual.Data.Name)
assert.NotEmpty(t, rs.Header.Get("location"))
}

0 comments on commit 95e6b7f

Please sign in to comment.