Skip to content

Commit

Permalink
Adds ability to serve mocked data for MR API calls via cli flag
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy committed Sep 11, 2024
1 parent 67bd67d commit b3ef316
Show file tree
Hide file tree
Showing 13 changed files with 300 additions and 24 deletions.
3 changes: 2 additions & 1 deletion clients/ui/bff/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ CONTAINER_TOOL ?= docker
IMG ?= model-registry-bff:latest
PORT ?= 4000
MOCK_K8S_CLIENT ?= false
MOCK_MR_CLIENT ?= false

.PHONY: all
all: build
Expand Down Expand Up @@ -32,7 +33,7 @@ build: fmt vet test

.PHONY: run
run: fmt vet
go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT)
go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT)

.PHONY: docker-build
docker-build:
Expand Down
4 changes: 2 additions & 2 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ After building it, you can run our app with:
```shell
make run
```
If you want to use a different port or mock kubernetes client, useful for front-end development, you can run:
If you want to use a different port, mock kubernetes client or model registry client - useful for front-end development, you can run:
```shell
make run PORT=8000 MOCK_K8S_CLIENT=true
make run PORT=8000 MOCK_K8S_CLIENT=true MOCK_MR_CLIENT=true
```

# Building and Deploying
Expand Down
30 changes: 22 additions & 8 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ const (
)

type App struct {
config config.EnvConfig
logger *slog.Logger
models data.Models
kubernetesClient integrations.KubernetesClientInterface
config config.EnvConfig
logger *slog.Logger
models data.Models
kubernetesClient integrations.KubernetesClientInterface
modelRegistryClient data.ModelRegistryClientInterface
}

func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
Expand All @@ -43,10 +44,23 @@ func NewApp(cfg config.EnvConfig, logger *slog.Logger) (*App, error) {
return nil, fmt.Errorf("failed to create Kubernetes client: %w", err)
}

var mrClient data.ModelRegistryClientInterface

if cfg.MockMRClient {
mrClient, err = mocks.NewModelRegistryClient(logger)
} else {
mrClient, err = data.NewModelRegistryClient(logger)
}

if err != nil {
return nil, fmt.Errorf("failed to create ModelRegistry client: %w", err)
}

app := &App{
config: cfg,
logger: logger,
kubernetesClient: k8sClient,
config: cfg,
logger: logger,
kubernetesClient: k8sClient,
modelRegistryClient: mrClient,
}
return app, nil
}
Expand All @@ -59,7 +73,7 @@ func (app *App) Routes() http.Handler {

// HTTP client routes
router.GET(HealthCheckPath, app.HealthcheckHandler)
router.GET(RegisteredModelsPath, app.AttachRESTClient(app.GetRegisteredModelsHandler))
router.GET(RegisteredModelsPath, app.AttachRESTClient(app.GetAllRegisteredModelsHandler))
router.GET(RegisteredModelPath, app.AttachRESTClient(app.GetRegisteredModelHandler))
router.POST(RegisteredModelsPath, app.AttachRESTClient(app.CreateRegisteredModelHandler))

Expand Down
2 changes: 2 additions & 0 deletions clients/ui/bff/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

type Envelope map[string]interface{}

type TypedEnvelope[T any] map[string]T

func (app *App) WriteJSON(w http.ResponseWriter, status int, data any, headers http.Header) error {

js, err := json.MarshalIndent(data, "", "\t")
Expand Down
13 changes: 6 additions & 7 deletions clients/ui/bff/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,27 @@ import (
"fmt"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/data"
"github.com/kubeflow/model-registry/ui/bff/integrations"
"github.com/kubeflow/model-registry/ui/bff/validation"
"net/http"
)

func (app *App) GetRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (ederign) implement pagination
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

modelList, err := data.GetAllRegisteredModels(client)
modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client)
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

modelRegistryRes := Envelope{
"registered_models": modelList,
"registered_model_list": modelList,
}

err = app.WriteJSON(w, http.StatusOK, modelRegistryRes, nil)
Expand Down Expand Up @@ -60,7 +59,7 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ
return
}

createdModel, err := data.CreateRegisteredModel(client, jsonData)
createdModel, err := app.modelRegistryClient.CreateRegisteredModel(client, jsonData)
if err != nil {
var httpErr *integrations.HTTPError
if errors.As(err, &httpErr) {
Expand Down Expand Up @@ -91,13 +90,13 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request
return
}

model, err := data.GetRegisteredModel(client, ps.ByName(RegisteredModelId))
model, err := app.modelRegistryClient.GetRegisteredModel(client, ps.ByName(RegisteredModelId))
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

if _, ok := model.GetNameOk(); !ok {
if _, ok := model.GetIdOk(); !ok {
app.notFoundResponse(w, r)
return
}
Expand Down
135 changes: 135 additions & 0 deletions clients/ui/bff/api/registered_models_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package api

import (
"bytes"
"context"
"encoding/json"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/internals/mocks"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
)

func TestGetRegisteredModelHandler(t *testing.T) {
mockMRClient, _ := mocks.NewModelRegistryClient(nil)
mockClient := new(mocks.MockHTTPClient)

testApp := App{
modelRegistryClient: mockMRClient,
}

req, err := http.NewRequest(http.MethodGet,
"/api/v1/model-registry/model-registry/registered_models/1", nil)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()

testApp.GetRegisteredModelHandler(rr, req, nil)
rs := rr.Result()

defer rs.Body.Close()

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelRes TypedEnvelope[openapi.RegisteredModel]
err = json.Unmarshal(body, &registeredModelRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rr.Code)

var expected = TypedEnvelope[openapi.RegisteredModel]{
"registered_model": mocks.GetRegisteredModelMocks()[0],
}

//TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values
// this issue is in the test only
assert.Equal(t, expected["registered_model"].Name, registeredModelRes["registered_model"].Name)
}

func TestGetAllRegisteredModelsHandler(t *testing.T) {
mockMRClient, _ := mocks.NewModelRegistryClient(nil)
mockClient := new(mocks.MockHTTPClient)

testApp := App{
modelRegistryClient: mockMRClient,
}

req, err := http.NewRequest(http.MethodGet,
"/api/v1/model-registry/model-registry/registered_models", nil)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()

testApp.GetAllRegisteredModelsHandler(rr, req, nil)
rs := rr.Result()

defer rs.Body.Close()

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelsListRes TypedEnvelope[openapi.RegisteredModelList]
err = json.Unmarshal(body, &registeredModelsListRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rr.Code)

var expected = TypedEnvelope[openapi.RegisteredModelList]{
"registered_model_list": mocks.GetRegisteredModelListMock(),
}

assert.Equal(t, expected["registered_model_list"].Size, registeredModelsListRes["registered_model_list"].Size)
assert.Equal(t, expected["registered_model_list"].PageSize, registeredModelsListRes["registered_model_list"].PageSize)
assert.Equal(t, expected["registered_model_list"].NextPageToken, registeredModelsListRes["registered_model_list"].NextPageToken)
assert.Equal(t, len(expected["registered_model_list"].Items), len(registeredModelsListRes["registered_model_list"].Items))
}

func TestCreateRegisteredModelHandler(t *testing.T) {
mockMRClient, _ := mocks.NewModelRegistryClient(nil)
mockClient := new(mocks.MockHTTPClient)

testApp := App{
modelRegistryClient: mockMRClient,
}

newModel := openapi.NewRegisteredModelCreate("Model One")
newModelJSON, err := newModel.MarshalJSON()
assert.NoError(t, err)

reqBody := bytes.NewReader(newModelJSON)

req, err := http.NewRequest(http.MethodPost,
"/api/v1/model-registry/model-registry/registered_models", reqBody)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()

testApp.CreateRegisteredModelHandler(rr, req, nil)
rs := rr.Result()

defer rs.Body.Close()

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var registeredModelRes openapi.RegisteredModel
err = json.Unmarshal(body, &registeredModelRes)
assert.NoError(t, err)

assert.Equal(t, http.StatusCreated, rr.Code)

var expected = mocks.GetRegisteredModelMocks()[0]

assert.Equal(t, expected.Name, registeredModelRes.Name)
assert.NotEmpty(t, rs.Header.Get("location"))
}
1 change: 1 addition & 0 deletions clients/ui/bff/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func main() {
var cfg config.EnvConfig
flag.IntVar(&cfg.Port, "port", getEnvAsInt("PORT", 4000), "API server port")
flag.BoolVar(&cfg.MockK8Client, "mock-k8s-client", false, "Use mock Kubernetes client")
flag.BoolVar(&cfg.MockMRClient, "mock-mr-client", false, "Use mock Model Registry client")
flag.Parse()

logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
Expand Down
1 change: 1 addition & 0 deletions clients/ui/bff/config/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ package config
type EnvConfig struct {
Port int
MockK8Client bool
MockMRClient bool
}
18 changes: 18 additions & 0 deletions clients/ui/bff/data/model_registry_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package data

import (
"log/slog"
)

type ModelRegistryClientInterface interface {
RegisteredModelInterface
}

type ModelRegistryClient struct {
logger *slog.Logger
RegisteredModel
}

func NewModelRegistryClient(logger *slog.Logger) (ModelRegistryClientInterface, error) {
return &ModelRegistryClient{logger: logger}, nil
}
16 changes: 13 additions & 3 deletions clients/ui/bff/data/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@ import (

const registerModelPath = "/registered_models"

func GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
type RegisteredModelInterface interface {
GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error)
CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error)
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
}

type RegisteredModel struct {
RegisteredModelInterface
}

func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {

responseData, err := client.GET(registerModelPath)
if err != nil {
Expand All @@ -26,7 +36,7 @@ func GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.R
return &modelList, nil
}

func CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) {
func (m RegisteredModel) CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) {
responseData, err := client.POST(registerModelPath, bytes.NewBuffer(jsonData))

if err != nil {
Expand All @@ -41,7 +51,7 @@ func CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []b
return &model, nil
}

func GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) {
func (m RegisteredModel) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) {
path, err := url.JoinPath(registerModelPath, id)
if err != nil {
return nil, err
Expand Down
12 changes: 9 additions & 3 deletions clients/ui/bff/data/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockData, err := json.Marshal(expected)
assert.NoError(t, err)

mrClient := ModelRegistryClient{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", registerModelPath).Return(mockData, nil)

actual, err := GetAllRegisteredModels(mockClient)
actual, err := mrClient.GetAllRegisteredModels(mockClient)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
Expand All @@ -39,13 +41,15 @@ func TestCreateRegisteredModel(t *testing.T) {
mockData, err := json.Marshal(expected)
assert.NoError(t, err)

mrClient := ModelRegistryClient{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("POST", registerModelPath, mock.Anything).Return(mockData, nil)

jsonInput, err := json.Marshal(expected)
assert.NoError(t, err)

actual, err := CreateRegisteredModel(mockClient, jsonInput)
actual, err := mrClient.CreateRegisteredModel(mockClient, jsonInput)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.Name, actual.Name)
Expand All @@ -62,10 +66,12 @@ func TestGetRegisteredModel(t *testing.T) {
mockData, err := json.Marshal(expected)
assert.NoError(t, err)

mrClient := ModelRegistryClient{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", registerModelPath+"/"+expected.GetId()).Return(mockData, nil)

actual, err := GetRegisteredModel(mockClient, expected.GetId())
actual, err := mrClient.GetRegisteredModel(mockClient, expected.GetId())
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.Name, actual.Name)
Expand Down
Loading

0 comments on commit b3ef316

Please sign in to comment.