diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index 4c707bbb1..a9e0d08f6 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -71,21 +71,23 @@ make docker-build | POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | CreateModelArtifactByModelVersion | Create a ModelArtifact entity for a specific ModelVersion | ### Sample local calls + +You will need to inject your requests with a kubeflow-userid header for authorization purposes. When running the service with the mocked Kubernetes client (MOCK_K8S_CLIENT=true), the user user@example.com is preconfigured with the necessary RBAC permissions to perform these actions. ``` # GET /v1/healthcheck -curl -i localhost:4000/api/v1/healthcheck +curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/healthcheck ``` ``` # GET /v1/model_registry -curl -i localhost:4000/api/v1/model_registry +curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models -curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models +curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry/model-registry/registered_models ``` ``` #POST /v1/model_registry/{model_registry_id}/registered_models -curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -103,11 +105,11 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} -curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1 +curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry/model-registry/registered_models/1 ``` ``` # PATCH /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} -curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1" \ +curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1" \ -H "Content-Type: application/json" \ -d '{ "data": { "description": "New description" @@ -115,11 +117,11 @@ curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/reg ``` ``` # GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} -curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1 +curl -i -H "kubeflow-userid: user@example.com" http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1 ``` ``` # POST /api/v1/model_registry/{model_registry_id}/model_versions -curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -138,7 +140,7 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/mode ``` ``` # PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} -curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1" \ +curl -i -H "kubeflow-userid: user@example.com" -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1" \ -H "Content-Type: application/json" \ -d '{ "data": { "description": "New description 2" @@ -146,11 +148,11 @@ curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/mod ``` ``` # GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions -curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions +curl -i -H "kubeflow-userid: user@example.com" localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions ``` ``` # POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions -curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -163,17 +165,17 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi "externalId": "9928", "name": "ModelVersion One", "state": "LIVE", - "author": "alex" + "author": "alex", "registeredModelId: "1" }}' ``` ``` # GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts -curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts +curl -i -H "kubeflow-userid: user@example.com" http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts ``` ``` # POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts -curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \ +curl -i -H "kubeflow-userid: user@example.com" -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \ -H "Content-Type: application/json" \ -d '{ "data": { "customProperties": { @@ -203,9 +205,9 @@ The following query parameters are supported by "Get All" style endpoints to con ### Sample local calls ``` # Get with a page size of 5 getting a specific page. -curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&nextPageToken=CAEQARoCCAE" +curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&nextPageToken=CAEQARoCCAE" ``` ``` # Get with a page size of 5, order by last update time in descending order. -curl -i "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&orderBy=LAST_UPDATE_TIME&sortOrder=DESC" +curl -i -H "kubeflow-userid: user@example.com" "http://localhost:4000/api/v1/model_registry/model-registry/registered_models?pageSize=5&orderBy=LAST_UPDATE_TIME&sortOrder=DESC" ``` diff --git a/clients/ui/bff/internal/api/app.go b/clients/ui/bff/internal/api/app.go index 93b635666..916ad348d 100644 --- a/clients/ui/bff/internal/api/app.go +++ b/clients/ui/bff/internal/api/app.go @@ -106,5 +106,5 @@ func (app *App) Routes() http.Handler { router.GET(ModelRegistryListPath, app.ModelRegistryHandler) router.PATCH(ModelRegistryPath, app.AttachRESTClient(app.UpdateModelVersionHandler)) - return app.RecoverPanic(app.enableCORS(router)) + return app.RecoverPanic(app.enableCORS(app.RequireAccessControl(router))) } diff --git a/clients/ui/bff/internal/api/errors.go b/clients/ui/bff/internal/api/errors.go index 686cf7c44..089df5254 100644 --- a/clients/ui/bff/internal/api/errors.go +++ b/clients/ui/bff/internal/api/errors.go @@ -43,6 +43,17 @@ func (app *App) badRequestResponse(w http.ResponseWriter, r *http.Request, err e app.errorResponse(w, r, httpError) } +func (app *App) forbiddenResponse(w http.ResponseWriter, r *http.Request, message string) { + httpError := &integrations.HTTPError{ + StatusCode: http.StatusForbidden, + ErrorResponse: integrations.ErrorResponse{ + Code: strconv.Itoa(http.StatusForbidden), + Message: message, + }, + } + app.errorResponse(w, r, httpError) +} + func (app *App) errorResponse(w http.ResponseWriter, r *http.Request, error *integrations.HTTPError) { env := ErrorEnvelope{Error: error} diff --git a/clients/ui/bff/internal/api/healthcheck__handler_test.go b/clients/ui/bff/internal/api/healthcheck__handler_test.go index e6b93c93e..20ac52df9 100644 --- a/clients/ui/bff/internal/api/healthcheck__handler_test.go +++ b/clients/ui/bff/internal/api/healthcheck__handler_test.go @@ -27,6 +27,8 @@ func TestHealthCheckHandler(t *testing.T) { req, err := http.NewRequest(http.MethodGet, HealthCheckPath, nil) assert.NoError(t, err) + req.Header.Set(kubeflowUserId, mocks.KubeflowUserIDHeaderValue) + app.HealthcheckHandler(rr, req, nil) rs := rr.Result() @@ -46,6 +48,7 @@ func TestHealthCheckHandler(t *testing.T) { SystemInfo: models.SystemInfo{ Version: Version, }, + UserID: mocks.KubeflowUserIDHeaderValue, } assert.Equal(t, expected, healthCheckRes) diff --git a/clients/ui/bff/internal/api/healthcheck_handler.go b/clients/ui/bff/internal/api/healthcheck_handler.go index 57c6b9813..df6d4702e 100644 --- a/clients/ui/bff/internal/api/healthcheck_handler.go +++ b/clients/ui/bff/internal/api/healthcheck_handler.go @@ -7,7 +7,9 @@ import ( func (app *App) HealthcheckHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - healthCheck, err := app.repositories.HealthCheck.HealthCheck(Version) + userID := r.Header.Get(kubeflowUserId) + + healthCheck, err := app.repositories.HealthCheck.HealthCheck(Version, userID) if err != nil { app.serverErrorResponse(w, r, err) return diff --git a/clients/ui/bff/internal/api/middleware.go b/clients/ui/bff/internal/api/middleware.go index 64c5e5958..6a16fcf83 100644 --- a/clients/ui/bff/internal/api/middleware.go +++ b/clients/ui/bff/internal/api/middleware.go @@ -13,7 +13,7 @@ import ( type contextKey string const httpClientKey contextKey = "httpClientKey" -const userAccessToken = "x-forwarded-access-token" +const kubeflowUserId = "kubeflow-userid" func (app *App) RecoverPanic(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -48,14 +48,8 @@ func (app *App) AttachRESTClient(handler func(http.ResponseWriter, *http.Request app.serverErrorResponse(w, r, fmt.Errorf("failed to resolve model registry base URL): %v", err)) return } - var bearerToken string - bearerToken, err = resolveBearerToken(app.kubernetesClient, r.Header) - if err != nil { - app.serverErrorResponse(w, r, fmt.Errorf("failed to resolve BearerToken): %v", err)) - return - } - client, err := integrations.NewHTTPClient(modelRegistryBaseURL, bearerToken) + client, err := integrations.NewHTTPClient(modelRegistryBaseURL) if err != nil { app.serverErrorResponse(w, r, fmt.Errorf("failed to create Kubernetes client: %v", err)) return @@ -65,26 +59,6 @@ func (app *App) AttachRESTClient(handler func(http.ResponseWriter, *http.Request } } -func resolveBearerToken(k8s integrations.KubernetesClientInterface, header http.Header) (string, error) { - var bearerToken string - //check if I'm inside cluster - if k8s.IsInCluster() { - //in cluster - bearerToken = header.Get(userAccessToken) - if bearerToken == "" { - return "", fmt.Errorf("failed to create Rest client (not able to get bearerToken on cluster)") - } - } else { - //off cluster (development) - var err error - bearerToken, err = k8s.BearerToken() - if err != nil { - return "", fmt.Errorf("failed to fetch BearerToken in development mode: %v", err) - } - } - return bearerToken, nil -} - func resolveModelRegistryURL(id string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) { serviceDetails, err := client.GetServiceDetailsByName(id) if err != nil { @@ -99,3 +73,32 @@ func resolveModelRegistryURL(id string, client integrations.KubernetesClientInte url := fmt.Sprintf("http://%s:%d/api/model_registry/v1alpha3", serviceDetails.ClusterIP, serviceDetails.HTTPPort) return url, nil } + +func (app *App) RequireAccessControl(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + // Skip SAR for health check + if r.URL.Path == HealthCheckPath { + next.ServeHTTP(w, r) + return + } + + user := r.Header.Get(kubeflowUserId) + if user == "" { + app.forbiddenResponse(w, r, "missing kubeflow-userid header") + return + } + + allowed, err := app.kubernetesClient.PerformSAR(user) + if err != nil { + app.forbiddenResponse(w, r, "failed to perform SAR: %v") + return + } + if !allowed { + app.forbiddenResponse(w, r, "access denied") + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/clients/ui/bff/internal/api/model_versions_handler_test.go b/clients/ui/bff/internal/api/model_versions_handler_test.go index 729aa7158..1a9ef0409 100644 --- a/clients/ui/bff/internal/api/model_versions_handler_test.go +++ b/clients/ui/bff/internal/api/model_versions_handler_test.go @@ -15,7 +15,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { By("fetching a model version") data := mocks.GetModelVersionMocks()[0] expected := ModelVersionEnvelope{Data: &data} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should match the expected model version") Expect(rs.StatusCode).To(Equal(http.StatusOK)) @@ -27,7 +27,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { data := mocks.GetModelVersionMocks()[0] expected := ModelVersionEnvelope{Data: &data} body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body, k8sClient) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should match the expected model version created") @@ -46,7 +46,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { } body := ModelVersionUpdateEnvelope{Data: &reqData} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body, k8sClient) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should match the expected model version updated") @@ -58,7 +58,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { By("getting a model artifacts by model version") data := mocks.GetModelArtifactListMock() expected := ModelArtifactListEnvelope{Data: &data} - actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient) + actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should get all expected model version artifacts") @@ -79,7 +79,7 @@ var _ = Describe("TestGetModelVersionHandler", func() { ArtifactType: "ARTIFACT_TYPE_ONE", } body := ModelArtifactEnvelope{Data: &artifact} - actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient) + actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should get all expected model artifacts") @@ -88,5 +88,54 @@ var _ = Describe("TestGetModelVersionHandler", func() { Expect(rs.Header.Get("Location")).To(Equal("/api/v1/model_registry/model-registry/model_artifacts/1")) }) + + It("should return 403 when not using the wrong KubeflowUserIDHeaderValue", func() { + By("making a request with an incorrect username") + wrongUserIDHeader := "bella@dora.com" // Incorrect username header value + + // Test: GET /model_versions/1 + _, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1", nil, k8sClient, wrongUserIDHeader) + + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: POST /model_versions/1/artifacts + artifact := openapi.ModelArtifact{ + Name: openapi.PtrString("Artifact One"), + ArtifactType: "ARTIFACT_TYPE_ONE", + } + body := ModelArtifactEnvelope{Data: &artifact} + _, rs, err = setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body, k8sClient, wrongUserIDHeader) + + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: GET /model_versions/1/artifacts + _, rs, err = setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil, k8sClient, wrongUserIDHeader) + + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: PATCH /model_versions/1 + reqData := openapi.ModelVersionUpdate{ + Description: openapi.PtrString("New description"), + } + body1 := ModelVersionUpdateEnvelope{Data: &reqData} + _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/model_versions/1", body1, k8sClient, wrongUserIDHeader) + + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: POST /model_versions + body2 := ModelVersionEnvelope{Data: openapi.NewModelVersion("Model One", "1")} + _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions", body2, k8sClient, wrongUserIDHeader) + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + }) }) }) diff --git a/clients/ui/bff/internal/api/registered_models_handler_test.go b/clients/ui/bff/internal/api/registered_models_handler_test.go index 34bdbc1a9..93aa5cfc5 100644 --- a/clients/ui/bff/internal/api/registered_models_handler_test.go +++ b/clients/ui/bff/internal/api/registered_models_handler_test.go @@ -15,7 +15,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { By("fetching all model registries") data := mocks.GetRegisteredModelMocks()[0] expected := RegisteredModelEnvelope{Data: &data} - actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1", nil, k8sClient) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should match the expected model registry") //TODO assert the full structure, I couldn't get unmarshalling to work for the full customProperties values @@ -28,7 +28,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { By("fetching all registered models") data := mocks.GetRegisteredModelListMock() expected := RegisteredModelListEnvelope{Data: &data} - actual, rs, err := setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models", nil, k8sClient) + actual, rs, err := setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should match the expected model registry") Expect(rs.StatusCode).To(Equal(http.StatusOK)) @@ -43,7 +43,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { data := mocks.GetRegisteredModelMocks()[0] expected := RegisteredModelEnvelope{Data: &data} body := RegisteredModelEnvelope{Data: openapi.NewRegisteredModel("Model One")} - actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models", body, k8sClient) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models", body, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should do a successful post") @@ -60,7 +60,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { Description: openapi.PtrString("This is a new description"), } body := RegisteredModelUpdateEnvelope{Data: &reqData} - actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1", body, k8sClient) + actual, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1", body, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should do a successful patch") @@ -73,7 +73,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { data := mocks.GetModelVersionListMock() expected := ModelVersionListEnvelope{Data: &data} - actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions", nil, k8sClient) + actual, rs, err := setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions", nil, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should get all items") @@ -90,7 +90,7 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { expected := ModelVersionEnvelope{Data: &data} body := ModelVersionEnvelope{Data: openapi.NewModelVersion("Version Fifty", "")} - actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions", body, k8sClient) + actual, rs, err := setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions", body, k8sClient, mocks.KubeflowUserIDHeaderValue) Expect(err).NotTo(HaveOccurred()) By("should successfully create it") @@ -99,5 +99,52 @@ var _ = Describe("TestGetRegisteredModelHandler", func() { Expect(rs.Header.Get("Location")).To(Equal("/api/v1/model_registry/model-registry/model_versions/1")) }) + + It("should return 403 when not using the correct KubeflowUserIDHeaderValue", func() { + By("making a request with an incorrect username") + wrongUserIDHeader := "bella@dora.com" // Incorrect username header value + + // Test: GET /registered_models/1 + _, rs, err := setupApiTest[RegisteredModelEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1", nil, k8sClient, wrongUserIDHeader) + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response for GET registered model by ID") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: GET /registered_models + _, rs, err = setupApiTest[RegisteredModelListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models", nil, k8sClient, wrongUserIDHeader) + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response for GET all registered models") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: POST /registered_models + body := RegisteredModelEnvelope{Data: openapi.NewRegisteredModel("Model One")} + _, rs, err = setupApiTest[RegisteredModelEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models", body, k8sClient, wrongUserIDHeader) + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response for POST create registered model") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: PATCH /registered_models/1 + reqData := openapi.RegisteredModelUpdate{ + Description: openapi.PtrString("This is a new description"), + } + body2 := RegisteredModelUpdateEnvelope{Data: &reqData} + _, rs, err = setupApiTest[RegisteredModelEnvelope](http.MethodPatch, "/api/v1/model_registry/model-registry/registered_models/1", body2, k8sClient, wrongUserIDHeader) + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response for PATCH update registered model") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: GET /registered_models/1/versions + _, rs, err = setupApiTest[ModelVersionListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/registered_models/1/versions", nil, k8sClient, wrongUserIDHeader) + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response for GET model versions of registered model") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + + // Test: POST /registered_models/1/versions + body3 := ModelVersionEnvelope{Data: openapi.NewModelVersion("Version Fifty", "")} + _, rs, err = setupApiTest[ModelVersionEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions", body3, k8sClient, wrongUserIDHeader) + Expect(err).NotTo(HaveOccurred()) + By("should return a 403 Forbidden response for POST create model version for registered model") + Expect(rs.StatusCode).To(Equal(http.StatusForbidden)) + }) }) }) diff --git a/clients/ui/bff/internal/api/test_utils.go b/clients/ui/bff/internal/api/test_utils.go index 94bc598f3..3a2ec65a6 100644 --- a/clients/ui/bff/internal/api/test_utils.go +++ b/clients/ui/bff/internal/api/test_utils.go @@ -12,7 +12,7 @@ import ( "net/http/httptest" ) -func setupApiTest[T any](method string, url string, body interface{}, k8sClient k8s.KubernetesClientInterface) (T, *http.Response, error) { +func setupApiTest[T any](method string, url string, body interface{}, k8sClient k8s.KubernetesClientInterface, kubeflowUserIDHeader string) (T, *http.Response, error) { mockMRClient, err := mocks.NewModelRegistryClient(nil) if err != nil { return *new(T), nil, err @@ -43,6 +43,9 @@ func setupApiTest[T any](method string, url string, body interface{}, k8sClient } } + // Set the kubeflow-userid header + req.Header.Set(kubeflowUserId, kubeflowUserIDHeader) + ctx := context.WithValue(req.Context(), httpClientKey, mockClient) req = req.WithContext(ctx) diff --git a/clients/ui/bff/internal/integrations/http.go b/clients/ui/bff/internal/integrations/http.go index 6d58c63aa..c20a859bd 100644 --- a/clients/ui/bff/internal/integrations/http.go +++ b/clients/ui/bff/internal/integrations/http.go @@ -16,9 +16,8 @@ type HTTPClientInterface interface { } type HTTPClient struct { - client *http.Client - baseURL string - bearerToken string + client *http.Client + baseURL string } type ErrorResponse struct { @@ -35,14 +34,13 @@ func (e *HTTPError) Error() string { return fmt.Sprintf("HTTP %d: %s - %s", e.StatusCode, e.Code, e.Message) } -func NewHTTPClient(baseURL string, bearerToken string) (HTTPClientInterface, error) { +func NewHTTPClient(baseURL string) (HTTPClientInterface, error) { return &HTTPClient{ client: &http.Client{Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }}, - baseURL: baseURL, - bearerToken: bearerToken, + baseURL: baseURL, }, nil } @@ -53,7 +51,6 @@ func (c *HTTPClient) GET(url string) ([]byte, error) { return nil, err } - req.Header.Add("Authorization", "Bearer "+c.bearerToken) response, err := c.client.Do(req) if err != nil { return nil, err @@ -76,7 +73,6 @@ func (c *HTTPClient) POST(url string, body io.Reader) ([]byte, error) { } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Authorization", "Bearer "+c.bearerToken) response, err := c.client.Do(req) if err != nil { @@ -118,7 +114,6 @@ func (c *HTTPClient) PATCH(url string, body io.Reader) ([]byte, error) { } req.Header.Set("Content-Type", "application/json") - req.Header.Add("Authorization", "Bearer "+c.bearerToken) response, err := c.client.Do(req) if err != nil { diff --git a/clients/ui/bff/internal/integrations/k8s.go b/clients/ui/bff/internal/integrations/k8s.go index 2dd2c4c2d..98a319843 100644 --- a/clients/ui/bff/internal/integrations/k8s.go +++ b/clients/ui/bff/internal/integrations/k8s.go @@ -3,16 +3,18 @@ package integrations import ( "context" "fmt" + helper "github.com/kubeflow/model-registry/ui/bff/internal/helpers" + authv1 "k8s.io/api/authorization/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "log/slog" "os" - "time" - - helper "github.com/kubeflow/model-registry/ui/bff/internal/helpers" - corev1 "k8s.io/api/core/v1" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + "time" ) const ComponentName = "model-registry-server" @@ -24,6 +26,7 @@ type KubernetesClientInterface interface { BearerToken() (string, error) Shutdown(ctx context.Context, logger *slog.Logger) error IsInCluster() bool + PerformSAR(user string) (bool, error) } type ServiceDetails struct { @@ -35,12 +38,13 @@ type ServiceDetails struct { } type KubernetesClient struct { - Client client.Client - Mgr ctrl.Manager - Token string - Logger *slog.Logger - StopFn context.CancelFunc // Store a function to cancel the context for graceful shutdown - mgrStopped chan struct{} + ControllerRuntimeClient client.Client //Controller-runtime client: used for high-level operations with caching. + KubernetesNativeClient kubernetes.Interface //Native KubernetesNativeClient client: only for specific non-cached subresources like SAR. + Mgr ctrl.Manager + Token string + Logger *slog.Logger + StopFn context.CancelFunc // Store a function to cancel the context for graceful shutdown + mgrStopped chan struct{} } func NewKubernetesClient(logger *slog.Logger) (KubernetesClientInterface, error) { @@ -67,7 +71,6 @@ func NewKubernetesClient(logger *slog.Logger) (KubernetesClientInterface, error) }, HealthProbeBindAddress: "0", // disable health probe serving LeaderElection: false, - //Namespace: "namespace", //TODO (ederign) do we need to specify the namespace to operate in //There is also cache filters and Sync periods to assess later. }) @@ -95,15 +98,22 @@ func NewKubernetesClient(logger *slog.Logger) (KubernetesClientInterface, error) return nil, fmt.Errorf("failed to wait for cache to sync") } + //Native KubernetesNativeClient client: only for specific non-cached subresources like SAR. + k8sClient, err := kubernetes.NewForConfig(kubeconfig) + if err != nil { + logger.Error("failed to create native KubernetesNativeClient client", "error", err) + cancel() + return nil, fmt.Errorf("failed to create KubernetesNativeClient client: %w", err) + } + kc := &KubernetesClient{ - Client: mgr.GetClient(), - Mgr: mgr, - Token: kubeconfig.BearerToken, - Logger: logger, - StopFn: cancel, - mgrStopped: mgrStopped, // Store the stop channel - - //Namespace: namespace, //TODO (ederign) do we need to restrict service list by namespace? + ControllerRuntimeClient: mgr.GetClient(), + KubernetesNativeClient: k8sClient, + Mgr: mgr, + Token: kubeconfig.BearerToken, + Logger: logger, + StopFn: cancel, + mgrStopped: mgrStopped, } return kc, nil } @@ -138,10 +148,6 @@ func (kc *KubernetesClient) BearerToken() (string, error) { } func (kc *KubernetesClient) GetServiceNames() ([]string, error) { - //TODO (ederign) when we develop the front-end, implement subject access review here - // and check if the username has actually permissions to access that server - // currently on kf dashboard, the user name comes in kubeflow-userid - //TODO (ederign) we should consider and rethinking listing all services on cluster // what if we have thousand of those? // we should consider label filtering for instance @@ -151,7 +157,7 @@ func (kc *KubernetesClient) GetServiceNames() ([]string, error) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) defer cancel() - err := kc.Client.List(ctx, serviceList, &client.ListOptions{}) + err := kc.ControllerRuntimeClient.List(ctx, serviceList, &client.ListOptions{}) if err != nil { return nil, fmt.Errorf("failed to list services: %w", err) } @@ -172,17 +178,12 @@ func (kc *KubernetesClient) GetServiceNames() ([]string, error) { func (kc *KubernetesClient) GetServiceDetails() ([]ServiceDetails, error) { //TODO (ederign) review the context timeout - - //TODO (ederign) when we develop the front-end, implement subject access review here - // and check if the username has actually permissions to access that server - // currently on kf dashboard, the user name comes in kubeflow-userid - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) defer cancel() // Ensure the context is canceled to free up resources serviceList := &corev1.ServiceList{} - err := kc.Client.List(ctx, serviceList, &client.ListOptions{}) + err := kc.ControllerRuntimeClient.List(ctx, serviceList, &client.ListOptions{}) if err != nil { return nil, fmt.Errorf("failed to list services: %w", err) } @@ -242,10 +243,6 @@ func (kc *KubernetesClient) GetServiceDetails() ([]ServiceDetails, error) { } func (kc *KubernetesClient) GetServiceDetailsByName(serviceName string) (ServiceDetails, error) { - //TODO (ederign) when we develop the front-end, implement subject access review here - // and check if the username has actually permissions to access that server - // currently on kf dashboard, the user name comes in kubeflow-userid - services, err := kc.GetServiceDetails() if err != nil { return ServiceDetails{}, fmt.Errorf("failed to get service details: %w", err) @@ -259,3 +256,33 @@ func (kc *KubernetesClient) GetServiceDetailsByName(serviceName string) (Service return ServiceDetails{}, fmt.Errorf("service %s not found", serviceName) } + +func (kc *KubernetesClient) PerformSAR(user string) (bool, error) { + verbs := []string{"get", "list"} + resource := "services" + + for _, verb := range verbs { + sar := &authv1.SubjectAccessReview{ + Spec: authv1.SubjectAccessReviewSpec{ + User: user, + ResourceAttributes: &authv1.ResourceAttributes{ + Verb: verb, + Resource: resource, + }, + }, + } + + // Perform the SAR using the native KubernetesNativeClient client + response, err := kc.KubernetesNativeClient.AuthorizationV1().SubjectAccessReviews().Create(context.TODO(), sar, metav1.CreateOptions{}) + if err != nil { + return false, fmt.Errorf("failed to create SubjectAccessReview for verb %q on resource %q: %w", verb, resource, err) + } + + if !response.Status.Allowed { + kc.Logger.Warn("access denied", "user", user, "verb", verb, "resource", resource) + return false, nil + } + } + + return true, nil +} diff --git a/clients/ui/bff/internal/mocks/k8s_mock.go b/clients/ui/bff/internal/mocks/k8s_mock.go index 7ea5bf425..0a6f85931 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock.go +++ b/clients/ui/bff/internal/mocks/k8s_mock.go @@ -3,19 +3,22 @@ package mocks import ( "context" "fmt" - "log/slog" - "os" - "path/filepath" - "runtime" - k8s "github.com/kubeflow/model-registry/ui/bff/internal/integrations" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes/scheme" + "log/slog" + "os" + "path/filepath" + "runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" ) +const KubeflowUserIDHeaderValue = "user@example.com" + type KubernetesClientMock struct { *k8s.KubernetesClient testEnv *envtest.Environment @@ -63,6 +66,13 @@ func NewKubernetesClient(logger *slog.Logger, ctx context.Context, cancel contex os.Exit(1) } + nativeK8sClient, err := kubernetes.NewForConfig(cfg) + if err != nil { + logger.Error("failed to create native KubernetesNativeClient client", slog.String("error", err.Error())) + cancel() + os.Exit(1) + } + err = setupMock(mockK8sClient, ctx) if err != nil { logger.Error("failed on mock setup", slog.String("error", err.Error())) @@ -72,9 +82,10 @@ func NewKubernetesClient(logger *slog.Logger, ctx context.Context, cancel contex return &KubernetesClientMock{ KubernetesClient: &k8s.KubernetesClient{ - Client: mockK8sClient, - Logger: logger, - StopFn: cancel, + ControllerRuntimeClient: mockK8sClient, + KubernetesNativeClient: nativeK8sClient, + Logger: logger, + StopFn: cancel, }, testEnv: testEnv, }, nil @@ -115,6 +126,12 @@ func setupMock(mockK8sClient client.Client, ctx context.Context) error { if err != nil { return err } + + err = createRBAC(mockK8sClient, ctx, KubeflowUserIDHeaderValue) + if err != nil { + return fmt.Errorf("failed to create RBAC for KubeflowUserIDHeaderValue: %w", err) + } + return nil } @@ -207,6 +224,50 @@ func createService(k8sClient client.Client, ctx context.Context, name string, na return nil } +func createRBAC(k8sClient client.Client, ctx context.Context, username string) error { + clusterRole := &rbacv1.ClusterRole{ + ObjectMeta: metav1.ObjectMeta{ + Name: "service-access-role", + }, + Rules: []rbacv1.PolicyRule{ + { + APIGroups: []string{""}, // Core API group + Resources: []string{"services"}, + Verbs: []string{"get", "list"}, + }, + }, + } + + err := k8sClient.Create(ctx, clusterRole) + if err != nil { + return fmt.Errorf("failed to create ClusterRole: %w", err) + } + + clusterRoleBinding := &rbacv1.ClusterRoleBinding{ + ObjectMeta: metav1.ObjectMeta{ + Name: "service-access-binding", + }, + Subjects: []rbacv1.Subject{ + { + Kind: "User", + Name: username, + }, + }, + RoleRef: rbacv1.RoleRef{ + Kind: "ClusterRole", + Name: "service-access-role", + APIGroup: "rbac.authorization.k8s.io", + }, + } + + err = k8sClient.Create(ctx, clusterRoleBinding) + if err != nil { + return fmt.Errorf("failed to create ClusterRoleBinding: %w", err) + } + + return nil +} + func strPtr(s string) *string { return &s } diff --git a/clients/ui/bff/internal/mocks/k8s_mock_test.go b/clients/ui/bff/internal/mocks/k8s_mock_test.go index 9ef9f502c..5fce6bf75 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock_test.go +++ b/clients/ui/bff/internal/mocks/k8s_mock_test.go @@ -5,7 +5,7 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Kubernetes Client Test", func() { +var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { Context("with existing services", Ordered, func() { It("should retrieve the get all service successfully", func() { @@ -60,3 +60,23 @@ var _ = Describe("Kubernetes Client Test", func() { }) }) + +var _ = Describe("KubernetesNativeClient SAR Test", func() { + Context("Subject Access Review", func() { + + It("should allow allowed user to access services", func() { + By("performing SAR for Kubeflow User ID") + allowed, err := k8sClient.PerformSAR(KubeflowUserIDHeaderValue) + Expect(err).NotTo(HaveOccurred(), "Failed to perform SAR for Kubeflow User ID\"") + Expect(allowed).To(BeTrue(), "Expected Kubeflow User ID to have access") + }) + + It("should deny access for another user", func() { + By("performing SAR for another user") + allowed, err := k8sClient.PerformSAR("unauthorized-dora@example.com") + Expect(err).NotTo(HaveOccurred(), "Failed to perform SAR for unauthorized-dora@example.com") + Expect(allowed).To(BeFalse(), "Expected unauthorized-dora@example.com to be denied access") + }) + + }) +}) diff --git a/clients/ui/bff/internal/models/health_check.go b/clients/ui/bff/internal/models/health_check.go index daf9e72d2..cfee33ac8 100644 --- a/clients/ui/bff/internal/models/health_check.go +++ b/clients/ui/bff/internal/models/health_check.go @@ -7,4 +7,5 @@ type SystemInfo struct { type HealthCheckModel struct { Status string `json:"status"` SystemInfo SystemInfo `json:"system_info"` + UserID string `json:"user-id"` } diff --git a/clients/ui/bff/internal/repositories/health_check.go b/clients/ui/bff/internal/repositories/health_check.go index 2dedfc725..a513966a7 100644 --- a/clients/ui/bff/internal/repositories/health_check.go +++ b/clients/ui/bff/internal/repositories/health_check.go @@ -8,13 +8,14 @@ func NewHealthCheckRepository() *HealthCheckRepository { return &HealthCheckRepository{} } -func (r *HealthCheckRepository) HealthCheck(version string) (models.HealthCheckModel, error) { +func (r *HealthCheckRepository) HealthCheck(version string, userID string) (models.HealthCheckModel, error) { var res = models.HealthCheckModel{ Status: "available", SystemInfo: models.SystemInfo{ Version: version, }, + UserID: userID, } return res, nil diff --git a/clients/ui/frontend/docs/dev-setup.md b/clients/ui/frontend/docs/dev-setup.md index 2fe6e9653..76981767c 100644 --- a/clients/ui/frontend/docs/dev-setup.md +++ b/clients/ui/frontend/docs/dev-setup.md @@ -32,6 +32,8 @@ npm run build This is the default context for running a local UI. Make sure you build the project using the instructions above prior to running the command below. +You will need to inject your requests with a kubeflow-userid header for authorization purposes. For example, you can use the [Header Editor](https://chromewebstore.google.com/detail/eningockdidmgiojffjmkdblpjocbhgh) extension in Chrome to set the kubeflow-userid header to user@example.com. + ```bash npm run start:dev ``` diff --git a/clients/ui/manifests/user-rbac/kubeflow-dashboard-rbac.yaml b/clients/ui/manifests/user-rbac/kubeflow-dashboard-rbac.yaml new file mode 100644 index 000000000..e81048d8e --- /dev/null +++ b/clients/ui/manifests/user-rbac/kubeflow-dashboard-rbac.yaml @@ -0,0 +1,21 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: service-access-cluster-role +rules: + - apiGroups: [""] + resources: ["services"] + verbs: ["get", "list"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: service-access-cluster-binding +subjects: + - kind: User + name: user@example.com + apiGroup: rbac.authorization.k8s.io +roleRef: + kind: ClusterRole + name: service-access-cluster-role + apiGroup: rbac.authorization.k8s.io diff --git a/clients/ui/manifests/user-rbac/kustomization.yaml b/clients/ui/manifests/user-rbac/kustomization.yaml index cb01d8d05..3e513a329 100644 --- a/clients/ui/manifests/user-rbac/kustomization.yaml +++ b/clients/ui/manifests/user-rbac/kustomization.yaml @@ -2,4 +2,5 @@ apiVersion: kustomize.config.k8s.io/v1beta1 kind: Kustomization resources: - - admin-rbac.yaml \ No newline at end of file + - admin-rbac.yaml + - kubeflow-dashboard-rbac.yaml \ No newline at end of file diff --git a/clients/ui/scripts/deploy_kind_cluster.sh b/clients/ui/scripts/deploy_kind_cluster.sh index 43f8afb71..57f365cf5 100755 --- a/clients/ui/scripts/deploy_kind_cluster.sh +++ b/clients/ui/scripts/deploy_kind_cluster.sh @@ -52,10 +52,10 @@ echo "Applying admin user service account and rolebinding..." kubectl apply -k . # Step 6: Generate token for admin user and display it -echo "Generating token for admin user, copy the following token in the local storage with key 'x-forwarded-access-token'..." -echo -e "\033[32m$(kubectl -n kube-system create token admin-user)\033[0m" +echo "In your browser, you will need to inject your requests with a kubeflow-userid header for authorization purposes." +echo "For example, you can use the Header Editor - https://chromewebstore.google.com/detail/eningockdidmgiojffjmkdblpjocbhgh extension in Chrome to set the kubeflow-userid header to user@example.com." # Step 5: Port-forward the service -echo "Port-fowarding Model Registry UI..." +echo "Port-forwarding Model Registry UI..." echo -e "\033[32mDashboard available in http://localhost:8080\033[0m" kubectl port-forward svc/model-registry-ui-service -n kubeflow 8080:8080