Skip to content

Commit

Permalink
Adds ability to change log verbosity through cli / env var and also a…
Browse files Browse the repository at this point in the history
…dds debug logging statements for upstream requests

Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy committed Jan 8, 2025
1 parent 2a39432 commit 68cd131
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 9 deletions.
3 changes: 2 additions & 1 deletion clients/ui/bff/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ DEV_MODE_PORT ?= 8080
STANDALONE_MODE ?= true
# ENVTEST_K8S_VERSION refers to the version of kubebuilder assets to be downloaded by envtest binary.
ENVTEST_K8S_VERSION = 1.29.0
LOG_LEVEL ?= info

.PHONY: all
all: build
Expand Down Expand Up @@ -48,7 +49,7 @@ build: fmt vet test ## Builds the project to produce a binary executable.
.PHONY: run
run: fmt vet envtest ## Runs the project.
ENVTEST_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" \
go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT) --dev-mode=$(DEV_MODE) --dev-mode-port=$(DEV_MODE_PORT) --standalone-mode=$(STANDALONE_MODE)
go run ./cmd/main.go --port=$(PORT) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT) --dev-mode=$(DEV_MODE) --dev-mode-port=$(DEV_MODE_PORT) --standalone-mode=$(STANDALONE_MODE) --log-level=$(LOG_LEVEL)

.PHONY: docker-build
docker-build: ## Builds a container for the project.
Expand Down
30 changes: 29 additions & 1 deletion clients/ui/bff/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"os/signal"
"strings"
"syscall"

"github.com/kubeflow/model-registry/ui/bff/internal/api"
Expand All @@ -25,9 +26,12 @@ func main() {
flag.BoolVar(&cfg.DevMode, "dev-mode", false, "Use development mode for access to local K8s cluster")
flag.IntVar(&cfg.DevModePort, "dev-mode-port", getEnvAsInt("DEV_MODE_PORT", 8080), "Use port when in development mode")
flag.BoolVar(&cfg.StandaloneMode, "standalone-mode", false, "Use standalone mode for enabling endpoints in standalone mode")
flag.StringVar(&cfg.LogLevel, "log-level", getEnvAsString("LOG_LEVEL", "info"), "Sets server log level, possible values: debug, info, warn, error, fatal")
flag.Parse()

logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: getLogLevelFromString(cfg.LogLevel),
}))

app, err := api.NewApp(cfg, logger)
if err != nil {
Expand Down Expand Up @@ -87,3 +91,27 @@ func getEnvAsInt(name string, defaultVal int) int {
}
return defaultVal
}

func getEnvAsString(name string, defaultVal string) string {
if value, exists := os.LookupEnv(name); exists {
return value
}
return defaultVal
}

func getLogLevelFromString(level string) slog.Level {
switch strings.ToLower(level) {
case "debug":
return slog.LevelDebug
case "info":
return slog.LevelInfo
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
case "fatal":
return slog.LevelError

}
return slog.LevelInfo
}
2 changes: 1 addition & 1 deletion clients/ui/bff/internal/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,5 @@ func (app *App) Routes() http.Handler {
router.GET(NamespaceListPath, app.GetNamespacesHandler)
}

return app.RecoverPanic(app.enableCORS(app.InjectUserHeaders(router)))
return app.RecoverPanic(app.EnableTelemetry(app.enableCORS(app.InjectUserHeaders(router))))
}
50 changes: 46 additions & 4 deletions clients/ui/bff/internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"

"github.com/google/uuid"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/ui/bff/internal/config"
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"log/slog"
"net/http"
"strings"
)

type contextKey string

const traceIdKey contextKey = "traceIdKey"
const traceLoggerKey contextKey = "traceLoggerKey"

const (
ModelRegistryHttpClientKey contextKey = "ModelRegistryHttpClientKey"
NamespaceHeaderParameterKey contextKey = "namespace"
Expand Down Expand Up @@ -81,6 +85,32 @@ func (app *App) enableCORS(next http.Handler) http.Handler {
})
}

func (app *App) EnableTelemetry(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Adds a unique id to the context to allow tracing of requests
traceId := uuid.NewString()
ctx := context.WithValue(r.Context(), traceIdKey, traceId)

// logger will only be nil in tests.
if app.logger != nil {
traceLogger := app.logger.With(slog.String("trace_id", traceId))
ctx = context.WithValue(ctx, traceLoggerKey, traceLogger)

if traceLogger.Enabled(ctx, slog.LevelDebug) {
body, err := r.GetBody()
if err != nil {
traceLogger.Debug("Error reading request body for debug logging", "error", err)
}
//TODO (Alex) Log headers, BUT we must ensure we don't log confidential data like tokens etc.
traceLogger.Debug("Incoming HTTP request", "method", r.Method, "url", r.URL.String(), "body", integrations.StreamToString(body))
//logger.Debug("Making upstream HTTP request", "request_id", reqId, "method", req.Method, "url", req.URL.String(), "body", streamToString(body))
}
}

next.ServeHTTP(w, r.WithContext(ctx))
})
}

func (app *App) AttachRESTClient(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {

Expand All @@ -97,7 +127,19 @@ func (app *App) AttachRESTClient(next func(http.ResponseWriter, *http.Request, h
return
}

client, err := integrations.NewHTTPClient(modelRegistryID, modelRegistryBaseURL)
// Set up a child logger for the rest client that automatically adds the request id to all statements for
// tracing.
restClientLogger := app.logger
traceId, ok := r.Context().Value(traceIdKey).(string)
if app.logger != nil {
if ok {
restClientLogger = app.logger.With(slog.String("trace_id", traceId))
} else {
app.logger.Warn("Failed to set trace_id for tracing")
}
}

client, err := integrations.NewHTTPClient(restClientLogger, modelRegistryID, modelRegistryBaseURL)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("failed to create Kubernetes client: %v", err))
return
Expand Down
1 change: 1 addition & 0 deletions clients/ui/bff/internal/config/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ type EnvConfig struct {
DevMode bool
StandaloneMode bool
DevModePort int
LogLevel string
}
52 changes: 50 additions & 2 deletions clients/ui/bff/internal/integrations/http.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package integrations

import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"github.com/google/uuid"
"io"
"log/slog"
"net/http"
"strconv"
)

type HTTPClientInterface interface {
GetModelRegistryID() (modelRegistryService string)
GET(url string) ([]byte, error)
POST(url string, body io.Reader) ([]byte, error)
PATCH(url string, body io.Reader) ([]byte, error)
Expand All @@ -20,6 +23,7 @@ type HTTPClient struct {
client *http.Client
baseURL string
ModelRegistryID string
logger *slog.Logger
}

type ErrorResponse struct {
Expand All @@ -36,14 +40,15 @@ func (e *HTTPError) Error() string {
return fmt.Sprintf("HTTP %d: %s - %s", e.StatusCode, e.Code, e.Message)
}

func NewHTTPClient(modelRegistryID string, baseURL string) (HTTPClientInterface, error) {
func NewHTTPClient(logger *slog.Logger, modelRegistryID string, baseURL string) (HTTPClientInterface, error) {

return &HTTPClient{
client: &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}},
baseURL: baseURL,
ModelRegistryID: modelRegistryID,
logger: logger,
}, nil
}

Expand All @@ -52,26 +57,33 @@ func (c *HTTPClient) GetModelRegistryID() string {
}

func (c *HTTPClient) GET(url string) ([]byte, error) {
requestId := uuid.NewString()

fullURL := c.baseURL + url
req, err := http.NewRequest("GET", fullURL, nil)
if err != nil {
return nil, err
}

logUpstreamReq(c.logger, requestId, req)

response, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer response.Body.Close()

body, err := io.ReadAll(response.Body)
logUpstreamResp(c.logger, requestId, response, body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
return body, nil
}

func (c *HTTPClient) POST(url string, body io.Reader) ([]byte, error) {
requestId := uuid.NewString()

fullURL := c.baseURL + url
fmt.Println(fullURL)
req, err := http.NewRequest("POST", fullURL, body)
Expand All @@ -81,13 +93,16 @@ func (c *HTTPClient) POST(url string, body io.Reader) ([]byte, error) {

req.Header.Set("Content-Type", "application/json")

logUpstreamReq(c.logger, requestId, req)

response, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer response.Body.Close()

responseBody, err := io.ReadAll(response.Body)
logUpstreamResp(c.logger, requestId, response, responseBody)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
Expand Down Expand Up @@ -120,15 +135,20 @@ func (c *HTTPClient) PATCH(url string, body io.Reader) ([]byte, error) {
return nil, err
}

requestId := uuid.NewString()

req.Header.Set("Content-Type", "application/json")

logUpstreamReq(c.logger, requestId, req)

response, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer response.Body.Close()

responseBody, err := io.ReadAll(response.Body)
logUpstreamResp(c.logger, requestId, response, responseBody)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
Expand All @@ -152,3 +172,31 @@ func (c *HTTPClient) PATCH(url string, body io.Reader) ([]byte, error) {
}
return responseBody, nil
}

func logUpstreamReq(logger *slog.Logger, reqId string, req *http.Request) {
if logger.Enabled(context.TODO(), slog.LevelDebug) {
body, err := req.GetBody()
if err != nil {
logger.Debug("Error reading request body for debug logging", "requestId", reqId, "error", err)
}
logger.Debug("Making upstream HTTP request", "request_id", reqId, "method", req.Method, "url", req.URL.String(), "body", StreamToString(body))
}
}

func logUpstreamResp(logger *slog.Logger, reqId string, resp *http.Response, body []byte) {
if logger.Enabled(context.TODO(), slog.LevelDebug) {
logger.Debug("Received upstream HTTP response", "request_id", reqId, "status_code", resp.StatusCode, "body", body)
}
}

func StreamToString(stream io.Reader) string {
if stream == nil {
return ""
}
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(stream)
if err != nil {
return ""
}
return buf.String()
}

0 comments on commit 68cd131

Please sign in to comment.