Skip to content

Commit

Permalink
feat(bff): use 'kubeflow-userid' header to authorize BFF endpoints (k…
Browse files Browse the repository at this point in the history
…ubeflow#599)

* feat(bff): use 'kubeflow-userid' header to authorize BFF endpoints

Signed-off-by: Eder Ignatowicz <[email protected]>

* fixing lint

Signed-off-by: Eder Ignatowicz <[email protected]>

---------

Signed-off-by: Eder Ignatowicz <[email protected]>
  • Loading branch information
ederign authored Dec 4, 2024
1 parent 3eb4e9c commit ea1afd2
Show file tree
Hide file tree
Showing 19 changed files with 364 additions and 115 deletions.
34 changes: 18 additions & 16 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,23 @@ make docker-build
| POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | CreateModelArtifactByModelVersion | Create a ModelArtifact entity for a specific ModelVersion |

### Sample local calls

You will need to inject your requests with a kubeflow-userid header for authorization purposes. When running the service with the mocked Kubernetes client (MOCK_K8S_CLIENT=true), the user [email protected] is preconfigured with the necessary RBAC permissions to perform these actions.
```
# GET /v1/healthcheck
curl -i localhost:4000/api/v1/healthcheck
curl -i -H "kubeflow-userid: [email protected]" localhost:4000/api/v1/healthcheck
```
```
# GET /v1/model_registry
curl -i localhost:4000/api/v1/model_registry
curl -i -H "kubeflow-userid: [email protected]" localhost:4000/api/v1/model_registry
```
```
# GET /v1/model_registry/{model_registry_id}/registered_models
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models
curl -i -H "kubeflow-userid: [email protected]" localhost:4000/api/v1/model_registry/model-registry/registered_models
```
```
#POST /v1/model_registry/{model_registry_id}/registered_models
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \
curl -i -H "kubeflow-userid: [email protected]" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \
-H "Content-Type: application/json" \
-d '{ "data": {
"customProperties": {
Expand All @@ -103,23 +105,23 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi
```
```
# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1
curl -i -H "kubeflow-userid: [email protected]" localhost:4000/api/v1/model_registry/model-registry/registered_models/1
```
```
# PATCH /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}
curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1" \
curl -i -H "kubeflow-userid: [email protected]" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1" \
-H "Content-Type: application/json" \
-d '{ "data": {
"description": "New description"
}}'
```
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}
curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1
curl -i -H "kubeflow-userid: [email protected]" http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1
```
```
# POST /api/v1/model_registry/{model_registry_id}/model_versions
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions" \
curl -i -H "kubeflow-userid: [email protected]" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions" \
-H "Content-Type: application/json" \
-d '{ "data": {
"customProperties": {
Expand All @@ -138,19 +140,19 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/mode
```
```
# PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}
curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1" \
curl -i -H "kubeflow-userid: [email protected]" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1" \
-H "Content-Type: application/json" \
-d '{ "data": {
"description": "New description 2"
}}'
```
```
# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions
curl -i -H "kubeflow-userid: [email protected]" localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions
```
```
# POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions" \
curl -i -H "kubeflow-userid: [email protected]" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions" \
-H "Content-Type: application/json" \
-d '{ "data": {
"customProperties": {
Expand All @@ -163,17 +165,17 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi
"externalId": "9928",
"name": "ModelVersion One",
"state": "LIVE",
"author": "alex"
"author": "alex",
"registeredModelId: "1"
}}'
```
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts
curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts
curl -i -H "kubeflow-userid: [email protected]" http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts
```
```
# POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \
curl -i -H "kubeflow-userid: [email protected]" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \
-H "Content-Type: application/json" \
-d '{ "data": {
"customProperties": {
Expand Down Expand Up @@ -203,9 +205,9 @@ The following query parameters are supported by "Get All" style endpoints to con
### Sample local calls
```
# Get with a page size of 5 getting a specific page.
curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&nextPageToken=CAEQARoCCAE"
curl -i -H "kubeflow-userid: [email protected]" "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&nextPageToken=CAEQARoCCAE"
```
```
# Get with a page size of 5, order by last update time in descending order.
curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&orderBy=LAST_UPDATE_TIME&sortOrder=DESC"
curl -i -H "kubeflow-userid: [email protected]" "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&orderBy=LAST_UPDATE_TIME&sortOrder=DESC"
```
2 changes: 1 addition & 1 deletion clients/ui/bff/internal/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ func (app *App) Routes() http.Handler {
router.GET(ModelRegistryListPath, app.ModelRegistryHandler)
router.PATCH(ModelRegistryPath, app.AttachRESTClient(app.UpdateModelVersionHandler))

return app.RecoverPanic(app.enableCORS(router))
return app.RecoverPanic(app.enableCORS(app.RequireAccessControl(router)))
}
11 changes: 11 additions & 0 deletions clients/ui/bff/internal/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err e
app.errorResponse(w, r, httpError)
}

func (app *App) forbiddenResponse(w http.ResponseWriter, r *http.Request, message string) {
httpError := &integrations.HTTPError{
StatusCode: http.StatusForbidden,
ErrorResponse: integrations.ErrorResponse{
Code: strconv.Itoa(http.StatusForbidden),
Message: message,
},
}
app.errorResponse(w, r, httpError)
}

func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) {

env := ErrorEnvelope{Error: error}
Expand Down
3 changes: 3 additions & 0 deletions clients/ui/bff/internal/api/healthcheck__handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func TestHealthCheckHandler(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, HealthCheckPath, nil)
assert.NoError(t, err)

req.Header.Set(kubeflowUserId, mocks.KubeflowUserIDHeaderValue)

app.HealthcheckHandler(rr, req, nil)
rs := rr.Result()

Expand All @@ -46,6 +48,7 @@ func TestHealthCheckHandler(t *testing.T) {
SystemInfo: models.SystemInfo{
Version: Version,
},
UserID: mocks.KubeflowUserIDHeaderValue,
}

assert.Equal(t, expected, healthCheckRes)
Expand Down
4 changes: 3 additions & 1 deletion clients/ui/bff/internal/api/healthcheck_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (

func (app *App) HealthcheckHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {

healthCheck, err := app.repositories.HealthCheck.HealthCheck(Version)
userID := r.Header.Get(kubeflowUserId)

healthCheck, err := app.repositories.HealthCheck.HealthCheck(Version, userID)
if err != nil {
app.serverErrorResponse(w, r, err)
return
Expand Down
59 changes: 31 additions & 28 deletions clients/ui/bff/internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
type contextKey string

const httpClientKey contextKey = "httpClientKey"
const userAccessToken = "x-forwarded-access-token"
const kubeflowUserId = "kubeflow-userid"

func (app *App) RecoverPanic(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -48,14 +48,8 @@ func (app *App) AttachRESTClient(handler func(http.ResponseWriter, *http.Request
app.serverErrorResponse(w, r, fmt.Errorf("failed to resolve model registry base URL): %v", err))
return
}
var bearerToken string
bearerToken, err = resolveBearerToken(app.kubernetesClient, r.Header)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("failed to resolve BearerToken): %v", err))
return
}

client, err := integrations.NewHTTPClient(modelRegistryBaseURL, bearerToken)
client, err := integrations.NewHTTPClient(modelRegistryBaseURL)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("failed to create Kubernetes client: %v", err))
return
Expand All @@ -65,26 +59,6 @@ func (app *App) AttachRESTClient(handler func(http.ResponseWriter, *http.Request
}
}

func resolveBearerToken(k8s integrations.KubernetesClientInterface, header http.Header) (string, error) {
var bearerToken string
//check if I'm inside cluster
if k8s.IsInCluster() {
//in cluster
bearerToken = header.Get(userAccessToken)
if bearerToken == "" {
return "", fmt.Errorf("failed to create Rest client (not able to get bearerToken on cluster)")
}
} else {
//off cluster (development)
var err error
bearerToken, err = k8s.BearerToken()
if err != nil {
return "", fmt.Errorf("failed to fetch BearerToken in development mode: %v", err)
}
}
return bearerToken, nil
}

func resolveModelRegistryURL(id string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) {
serviceDetails, err := client.GetServiceDetailsByName(id)
if err != nil {
Expand All @@ -99,3 +73,32 @@ func resolveModelRegistryURL(id string, client integrations.KubernetesClientInte
url := fmt.Sprintf("http://%s:%d/api/model_registry/v1alpha3", serviceDetails.ClusterIP, serviceDetails.HTTPPort)
return url, nil
}

func (app *App) RequireAccessControl(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// Skip SAR for health check
if r.URL.Path == HealthCheckPath {
next.ServeHTTP(w, r)
return
}

user := r.Header.Get(kubeflowUserId)
if user == "" {
app.forbiddenResponse(w, r, "missing kubeflow-userid header")
return
}

allowed, err := app.kubernetesClient.PerformSAR(user)
if err != nil {
app.forbiddenResponse(w, r, "failed to perform SAR: %v")
return
}
if !allowed {
app.forbiddenResponse(w, r, "access denied")
return
}

next.ServeHTTP(w, r)
})
}
59 changes: 54 additions & 5 deletions clients/ui/bff/internal/api/model_versions_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var _ = Describe("TestGetModelVersionHandler", func() {
By("fetching a model version")
data := mocks.GetModelVersionMocks()[0]
expected := ModelVersionEnvelope{Data: &data}
actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient)
actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient, mocks.KubeflowUserIDHeaderValue)
Expect(err).NotTo(HaveOccurred())
By("should match the expected model version")
Expect(rs.StatusCode).To(Equal(http.StatusOK))
Expand All @@ -27,7 +27,7 @@ var _ = Describe("TestGetModelVersionHandler", func() {
data := mocks.GetModelVersionMocks()[0]
expected := ModelVersionEnvelope{Data: &data}
body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")}
actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body, k8sClient)
actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body, k8sClient, mocks.KubeflowUserIDHeaderValue)
Expect(err).NotTo(HaveOccurred())

By("should match the expected model version created")
Expand All @@ -46,7 +46,7 @@ var _ = Describe("TestGetModelVersionHandler", func() {
}
body := ModelVersionUpdateEnvelope{Data: &reqData}

actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body, k8sClient)
actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body, k8sClient, mocks.KubeflowUserIDHeaderValue)
Expect(err).NotTo(HaveOccurred())

By("should match the expected model version updated")
Expand All @@ -58,7 +58,7 @@ var _ = Describe("TestGetModelVersionHandler", func() {
By("getting a model artifacts by model version")
data := mocks.GetModelArtifactListMock()
expected := ModelArtifactListEnvelope{Data: &data}
actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient)
actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient, mocks.KubeflowUserIDHeaderValue)
Expect(err).NotTo(HaveOccurred())

By("should get all expected model version artifacts")
Expand All @@ -79,7 +79,7 @@ var _ = Describe("TestGetModelVersionHandler", func() {
ArtifactType: "ARTIFACT_TYPE_ONE",
}
body := ModelArtifactEnvelope{Data: &artifact}
actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient)
actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient, mocks.KubeflowUserIDHeaderValue)
Expect(err).NotTo(HaveOccurred())

By("should get all expected model artifacts")
Expand All @@ -88,5 +88,54 @@ var _ = Describe("TestGetModelVersionHandler", func() {
Expect(rs.Header.Get("Location")).To(Equal("/api/v1/model_registry/model-registry/model_artifacts/1"))

})

It("should return 403 when not using the wrong KubeflowUserIDHeaderValue", func() {
By("making a request with an incorrect username")
wrongUserIDHeader := "[email protected]" // Incorrect username header value

// Test: GET /model_versions/1
_, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient, wrongUserIDHeader)

Expect(err).NotTo(HaveOccurred())
By("should return a 403 Forbidden response")
Expect(rs.StatusCode).To(Equal(http.StatusForbidden))

// Test: POST /model_versions/1/artifacts
artifact := openapi.ModelArtifact{
Name: openapi.PtrString("Artifact One"),
ArtifactType: "ARTIFACT_TYPE_ONE",
}
body := ModelArtifactEnvelope{Data: &artifact}
_, rs, err = setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient, wrongUserIDHeader)

Expect(err).NotTo(HaveOccurred())
By("should return a 403 Forbidden response")
Expect(rs.StatusCode).To(Equal(http.StatusForbidden))

// Test: GET /model_versions/1/artifacts
_, rs, err = setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient, wrongUserIDHeader)

Expect(err).NotTo(HaveOccurred())
By("should return a 403 Forbidden response")
Expect(rs.StatusCode).To(Equal(http.StatusForbidden))

// Test: PATCH /model_versions/1
reqData := openapi.ModelVersionUpdate{
Description: openapi.PtrString("New description"),
}
body1 := ModelVersionUpdateEnvelope{Data: &reqData}
_, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body1, k8sClient, wrongUserIDHeader)

Expect(err).NotTo(HaveOccurred())
By("should return a 403 Forbidden response")
Expect(rs.StatusCode).To(Equal(http.StatusForbidden))

// Test: POST /model_versions
body2 := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")}
_, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body2, k8sClient, wrongUserIDHeader)
Expect(err).NotTo(HaveOccurred())
By("should return a 403 Forbidden response")
Expect(rs.StatusCode).To(Equal(http.StatusForbidden))
})
})
})
Loading

0 comments on commit ea1afd2

Please sign in to comment.