Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements pagination support for GetAll style endpoints #429

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,24 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/mode
"state": "LIVE",
"artifactType": "TYPE_ONE"
}}'
```
```

### Pagination
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied the definitions here off the MR OpenAPI spec.

The following query parameters are supported by "Get All" style endpoints to control pagination.

| Parameter Name | Description |
|----------------|-----------------------------------------------------------------------------------------------------------|
| pageSize | Number of entities in each page |
| orderBy | Specifies the order by criteria for listing entities. Available values: CREATE_TIME, LAST_UPDATE_TIME, ID |
| sortOrder | Specifies the sort order for listing entities. Available values: ASC, DESC. Default: ASC |
| nextPageToken | Token to use to retrieve next page of results. |

### Sample local calls
```
# Get with a page size of 5 getting a specific page.
curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&nextPageToken=CAEQARoCCAE"
```
```
# Get with a page size of 5, order by last update time in descending order.
curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&orderBy=LAST_UPDATE_TIME&sortOrder=DESC"
```
2 changes: 1 addition & 1 deletion clients/ui/bff/internal/api/model_versions_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter,
return
}

data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId))
data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId), r.URL.Query())
if err != nil {
app.serverErrorResponse(w, r, err)
return
Expand Down
15 changes: 6 additions & 9 deletions clients/ui/bff/internal/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,25 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"github.com/kubeflow/model-registry/ui/bff/internal/validation"
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
)

type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None]
type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None]
type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None]

func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (ederign) implement pagination
func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client)
modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client, r.URL.Query())
if err != nil {
app.serverErrorResponse(w, r, err)
return
Expand All @@ -40,7 +38,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req
}
}

func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
Expand Down Expand Up @@ -173,14 +171,13 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ
}

func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (acreasy) implement pagination
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId))
versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query())

if err != nil {
app.serverErrorResponse(w, r, err)
Expand Down
38 changes: 38 additions & 0 deletions clients/ui/bff/internal/data/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package data

import (
"fmt"
"net/url"
)

func FilterPageValues(values url.Values) url.Values {
result := url.Values{}

if v := values.Get("pageSize"); v != "" {
result.Set("pageSize", v)
}
if v := values.Get("orderBy"); v != "" {
result.Set("orderBy", v)
}
if v := values.Get("sortOrder"); v != "" {
result.Set("sortOrder", v)
}
if v := values.Get("nextPageToken"); v != "" {
result.Set("nextPageToken", v)
}

return result
}

func UrlWithParams(url string, values url.Values) string {
queryString := values.Encode()
if queryString == "" {
return url
}
return fmt.Sprintf("%s?%s", url, queryString)
}

func UrlWithPageParams(url string, values url.Values) string {
pageValues := FilterPageValues(values)
return UrlWithParams(url, pageValues)
}
6 changes: 3 additions & 3 deletions clients/ui/bff/internal/data/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ModelVersionInterface interface {
GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error)
CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error)
UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error)
CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error)
}

Expand Down Expand Up @@ -79,14 +79,14 @@ func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface
return &model, nil
}

func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error) {
path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath)

if err != nil {
return nil, err
}

responseData, err := client.GET(path)
responseData, err := client.GET(UrlWithPageParams(path, pageValues))
if err != nil {
return nil, fmt.Errorf("error fetching model version artifacts: %w", err)
}
Expand Down
28 changes: 27 additions & 1 deletion clients/ui/bff/internal/data/model_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package data

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
Expand Down Expand Up @@ -106,7 +107,7 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) {
mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil)

actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1")
actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", nil)
assert.NoError(t, err)

assert.NotNil(t, actual)
Expand All @@ -116,6 +117,31 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) {
assert.Equal(t, len(expected.Items), len(actual.Items))
}

func TestGetModelArtifactsByModelVersionWithPageParams(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockModelArtifactList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

modelVersion := ModelVersion{}

path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath)
assert.NoError(t, err)
reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode())

mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodGet, reqUrl, mock.Anything).Return(mockData, nil)

actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", pageValues)
assert.NoError(t, err)

assert.NotNil(t, actual)
mockClient.AssertExpectations(t)
}

func TestCreateModelArtifactByModelVersion(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

Expand Down
12 changes: 6 additions & 6 deletions clients/ui/bff/internal/data/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ const registeredModelPath = "/registered_models"
const versionsPath = "/versions"

type RegisteredModelInterface interface {
GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error)
GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error)
CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error)
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error)
CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
}

type RegisteredModel struct {
RegisteredModelInterface
}

func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error) {
responseData, err := client.GET(UrlWithPageParams(registeredModelPath, pageValues))

responseData, err := client.GET(registeredModelPath)
if err != nil {
return nil, fmt.Errorf("error fetching registered models: %w", err)
}
Expand Down Expand Up @@ -94,14 +94,14 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt
return &model, nil
}

func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) {
func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) {
path, err := url.JoinPath(registeredModelPath, id, versionsPath)

if err != nil {
return nil, err
}

responseData, err := client.GET(path)
responseData, err := client.GET(UrlWithPageParams(path, pageValues))

if err != nil {
return nil, fmt.Errorf("error fetching model versions: %w", err)
Expand Down
52 changes: 50 additions & 2 deletions clients/ui/bff/internal/data/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package data

import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"testing"
Expand All @@ -25,7 +26,7 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", registeredModelPath).Return(mockData, nil)

actual, err := registeredModel.GetAllRegisteredModels(mockClient)
actual, err := registeredModel.GetAllRegisteredModels(mockClient, nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
Expand All @@ -36,6 +37,28 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllRegisteredModelsWithPageParams(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockRegisteredModelList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

reqUrl := fmt.Sprintf("%s?%s", registeredModelPath, pageValues.Encode())

registeredModel := RegisteredModel{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllRegisteredModels(mockClient, pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)
mockClient.AssertExpectations(t)
}

func TestCreateRegisteredModel(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

Expand Down Expand Up @@ -126,7 +149,7 @@ func TestGetAllModelVersions(t *testing.T) {
assert.NoError(t, err)
mockClient.On("GET", path).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1")
actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.NoError(t, err)
Expand All @@ -139,6 +162,31 @@ func TestGetAllModelVersions(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllModelVersionsWithPageParams(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockModelVersionList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

registeredModel := RegisteredModel{}

mockClient := new(mocks.MockHTTPClient)
path, err := url.JoinPath(registeredModelPath, "1", versionsPath)
assert.NoError(t, err)
reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode())

mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)

mockClient.AssertExpectations(t)
}

func TestCreateModelVersionForRegisteredModel(t *testing.T) {
gofakeit.Seed(0) //nolint:errcheck

Expand Down
10 changes: 6 additions & 4 deletions clients/ui/bff/internal/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ import (
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"github.com/stretchr/testify/mock"
"log/slog"
"net/url"
)

type ModelRegistryClientMock struct {
mock.Mock
}

func NewModelRegistryClient(logger *slog.Logger) (*ModelRegistryClientMock, error) {
func NewModelRegistryClient(_ *slog.Logger) (*ModelRegistryClientMock, error) {
return &ModelRegistryClientMock{}, nil
}

func (m *ModelRegistryClientMock) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
func (m *ModelRegistryClientMock) GetAllRegisteredModels(_ integrations.HTTPClientInterface, _ url.Values) (*openapi.RegisteredModelList, error) {
mockData := GetRegisteredModelListMock()
return &mockData, nil
}
Expand Down Expand Up @@ -50,7 +51,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(client integrations.HTTPCli
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) {
func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) {
mockData := GetModelVersionListMock()
return &mockData, nil
}
Expand All @@ -60,10 +61,11 @@ func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client in
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelArtifactList, error) {
mockData := GetModelArtifactListMock()
return &mockData, nil
}

func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) {
mockData := GetModelArtifactMocks()[0]
return &mockData, nil
Expand Down
Loading