diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 784d704e..9279e76f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,6 +8,7 @@ jobs: name: run tests runs-on: ubuntu-latest strategy: + fail-fast: false matrix: go: - "1.21" diff --git a/lambda/handler.go b/lambda/handler.go index cd55f7af..c455656d 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -23,6 +23,7 @@ type Handler interface { type handlerOptions struct { handlerFunc baseContext context.Context + contextValues map[interface{}]interface{} jsonRequestUseNumber bool jsonRequestDisallowUnknownFields bool jsonResponseEscapeHTML bool @@ -50,6 +51,23 @@ func WithContext(ctx context.Context) Option { }) } +// WithContextValue adds a value to the handler context. +// If a base context was set using WithContext, that base is used as the parent. +// +// Usage: +// +// lambda.StartWithOptions( +// func (ctx context.Context) (string, error) { +// return ctx.Value("foo"), nil +// }, +// lambda.WithContextValue("foo", "bar") +// ) +func WithContextValue(key interface{}, value interface{}) Option { + return Option(func(h *handlerOptions) { + h.contextValues[key] = value + }) +} + // WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder // // Usage: @@ -211,6 +229,7 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { } h := &handlerOptions{ baseContext: context.Background(), + contextValues: map[interface{}]interface{}{}, jsonResponseEscapeHTML: false, jsonResponseIndentPrefix: "", jsonResponseIndentValue: "", @@ -218,6 +237,9 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { for _, option := range options { option(h) } + for k, v := range h.contextValues { + h.baseContext = context.WithValue(h.baseContext, k, v) + } if h.enableSIGTERM { enableSIGTERM(h.sigtermCallbacks) } diff --git a/lambda/sigterm_test.go b/lambda/sigterm_test.go index 886a6d72..278d5926 100644 --- a/lambda/sigterm_test.go +++ b/lambda/sigterm_test.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "path" + "strconv" "strings" "testing" "time" @@ -17,10 +18,6 @@ import ( "github.com/stretchr/testify/require" ) -const ( - rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations" -) - func TestEnableSigterm(t *testing.T) { if _, err := exec.LookPath("aws-lambda-rie"); err != nil { t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err) @@ -34,6 +31,7 @@ func TestEnableSigterm(t *testing.T) { handlerBuild.Stdout = os.Stderr require.NoError(t, handlerBuild.Run()) + portI := 0 for name, opts := range map[string]struct { envVars []string assertLogs func(t *testing.T, logs string) @@ -53,8 +51,12 @@ func TestEnableSigterm(t *testing.T) { }, } { t.Run(name, func(t *testing.T) { + portI += 1 + addr1 := "localhost:" + strconv.Itoa(8000+portI) + addr2 := "localhost:" + strconv.Itoa(9000+portI) + rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations" // run the runtime interface emulator, capture the logs for assertion - cmd := exec.Command("aws-lambda-rie", "sigterm.handler") + cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "sigterm.handler") cmd.Env = append([]string{ "PATH=" + testDir, "AWS_LAMBDA_FUNCTION_TIMEOUT=2", diff --git a/lambdaurl/http_handler.go b/lambdaurl/http_handler.go index 70e73d0a..79a63e18 100644 --- a/lambdaurl/http_handler.go +++ b/lambdaurl/http_handler.go @@ -18,24 +18,76 @@ import ( "github.com/aws/aws-lambda-go/lambda" ) +type detectContentTypeContextKey struct{} + +// WithDetectContentType sets the behavior of content type detection when the Content-Type header is not already provided. +// When true, the first Write call will pass the intial bytes to http.DetectContentType. +// When false, and if no Content-Type is provided, no Content-Type will be sent back to Lambda, +// and the Lambda Function URL will fallback to it's default. +// +// Note: The http.ResponseWriter passed to the handler is unbuffered. +// This may result in different Content-Type headers in the Function URL response when compared to http.ListenAndServe. +// +// Usage: +// +// lambdaurl.Start( +// http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// w.Write("") +// }), +// lambdaurl.WithDetectContentType(true) +// ) +func WithDetectContentType(detectContentType bool) lambda.Option { + return lambda.WithContextValue(detectContentTypeContextKey{}, detectContentType) +} + type httpResponseWriter struct { + detectContentType bool + header http.Header + writer io.Writer + once sync.Once + ready chan<- header +} + +type header struct { + code int header http.Header - writer io.Writer - once sync.Once - status chan<- int } func (w *httpResponseWriter) Header() http.Header { + if w.header == nil { + w.header = http.Header{} + } return w.header } func (w *httpResponseWriter) Write(p []byte) (int, error) { - w.once.Do(func() { w.status <- http.StatusOK }) + w.writeHeader(http.StatusOK, p) return w.writer.Write(p) } func (w *httpResponseWriter) WriteHeader(statusCode int) { - w.once.Do(func() { w.status <- statusCode }) + w.writeHeader(statusCode, nil) +} + +func (w *httpResponseWriter) writeHeader(statusCode int, initialPayload []byte) { + w.once.Do(func() { + if w.detectContentType { + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", detectContentType(initialPayload)) + } + } + w.ready <- header{code: statusCode, header: w.header} + }) +} + +func detectContentType(p []byte) string { + // http.DetectContentType returns "text/plain; charset=utf-8" for nil and zero-length byte slices. + // This is a weird behavior, since otherwise it defaults to "application/octet-stream"! So we'll do that. + // This differs from http.ListenAndServe, which set no Content-Type when the initial Flush body is empty. + if len(p) == 0 { + return "application/octet-stream" + } + return http.DetectContentType(p) } type requestContextKey struct{} @@ -46,11 +98,13 @@ func RequestFromContext(ctx context.Context) (*events.LambdaFunctionURLRequest, return req, ok } -// Wrap converts an http.Handler into a lambda request handler. +// Wrap converts an http.Handler into a Lambda request handler. +// // Only Lambda Function URLs configured with `InvokeMode: RESPONSE_STREAM` are supported with the returned handler. -// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response` +// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`. func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) { return func(ctx context.Context, request *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) { + var body io.Reader = strings.NewReader(request.Body) if request.IsBase64Encoded { body = base64.NewDecoder(base64.StdEncoding, body) @@ -67,21 +121,28 @@ func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLR for k, v := range request.Headers { httpRequest.Header.Add(k, v) } - status := make(chan int) // Signals when it's OK to start returning the response body to Lambda - header := http.Header{} + + ready := make(chan header) // Signals when it's OK to start returning the response body to Lambda r, w := io.Pipe() + responseWriter := &httpResponseWriter{writer: w, ready: ready} + if detectContentType, ok := ctx.Value(detectContentTypeContextKey{}).(bool); ok { + responseWriter.detectContentType = detectContentType + } go func() { - defer close(status) + defer close(ready) defer w.Close() // TODO: recover and CloseWithError the any panic value once the runtime API client supports plumbing fatal errors through the reader - handler.ServeHTTP(&httpResponseWriter{writer: w, header: header, status: status}, httpRequest) + //nolint:errcheck + defer responseWriter.Write(nil) // force default status, headers, content type detection, if none occured during the execution of the handler + handler.ServeHTTP(responseWriter, httpRequest) }() + header := <-ready response := &events.LambdaFunctionURLStreamingResponse{ Body: r, - StatusCode: <-status, + StatusCode: header.code, } - if len(header) > 0 { - response.Headers = make(map[string]string, len(header)) - for k, v := range header { + if len(header.header) > 0 { + response.Headers = make(map[string]string, len(header.header)) + for k, v := range header.header { if k == "Set-Cookie" { response.Cookies = v } else { diff --git a/lambdaurl/http_handler_test.go b/lambdaurl/http_handler_test.go index a6e6aa8d..419cbd7c 100644 --- a/lambdaurl/http_handler_test.go +++ b/lambdaurl/http_handler_test.go @@ -13,6 +13,11 @@ import ( "io/ioutil" "log" "net/http" + "os" + "os/exec" + "path" + "strconv" + "strings" "testing" "time" @@ -35,12 +40,13 @@ var base64EncodedBodyRequest []byte func TestWrap(t *testing.T) { for name, params := range map[string]struct { - input []byte - handler http.HandlerFunc - expectStatus int - expectBody string - expectHeaders map[string]string - expectCookies []string + input []byte + handler http.HandlerFunc + detectContentType bool + expectStatus int + expectBody string + expectHeaders map[string]string + expectCookies []string }{ "hello": { input: helloRequest, @@ -58,10 +64,8 @@ func TestWrap(t *testing.T) { encoder := json.NewEncoder(w) _ = encoder.Encode(struct{ RequestQueryParams, Method any }{r.URL.Query(), r.Method}) }, - expectStatus: http.StatusTeapot, - expectHeaders: map[string]string{ - "Hello": "world1,world2", - }, + expectStatus: http.StatusTeapot, + expectHeaders: map[string]string{"Hello": "world1,world2"}, expectCookies: []string{ "yummy=cookie", "yummy=cake", @@ -110,6 +114,13 @@ func TestWrap(t *testing.T) { handler: func(w http.ResponseWriter, r *http.Request) {}, expectStatus: http.StatusOK, }, + "write status code only": { + input: helloRequest, + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + }, + expectStatus: http.StatusAccepted, + }, "base64request": { input: base64EncodedBodyRequest, handler: func(w http.ResponseWriter, r *http.Request) { @@ -118,12 +129,58 @@ func TestWrap(t *testing.T) { expectStatus: http.StatusOK, expectBody: "", }, + "detect content type: write status code only": { + input: helloRequest, + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + }, + detectContentType: true, + expectStatus: http.StatusAccepted, + expectHeaders: map[string]string{ + "Content-Type": "application/octet-stream", + }, + }, + "detect content type: empty handler": { + input: helloRequest, + handler: func(w http.ResponseWriter, r *http.Request) { + }, + detectContentType: true, + expectStatus: http.StatusOK, + expectHeaders: map[string]string{ + "Content-Type": "application/octet-stream", + }, + }, + "detect content type: writes html": { + input: helloRequest, + handler: func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("")) + }, + detectContentType: true, + expectBody: "", + expectStatus: http.StatusOK, + expectHeaders: map[string]string{ + "Content-Type": "text/html; charset=utf-8", + }, + }, + "detect content type: writes zeros": { + input: helloRequest, + handler: func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte{0, 0, 0, 0, 0}) + }, + detectContentType: true, + expectBody: "\x00\x00\x00\x00\x00", + expectStatus: http.StatusOK, + expectHeaders: map[string]string{ + "Content-Type": "application/octet-stream", + }, + }, } { t.Run(name, func(t *testing.T) { handler := Wrap(params.handler) var req events.LambdaFunctionURLRequest require.NoError(t, json.Unmarshal(params.input, &req)) - res, err := handler(context.Background(), &req) + ctx := context.WithValue(context.Background(), detectContentTypeContextKey{}, params.detectContentType) + res, err := handler(ctx, &req) require.NoError(t, err) resultBodyBytes, err := ioutil.ReadAll(res) require.NoError(t, err) @@ -155,3 +212,56 @@ func TestRequestContext(t *testing.T) { _, err := handler(context.Background(), req) require.NoError(t, err) } + +func TestStartViaEmulator(t *testing.T) { + addr1 := "localhost:" + strconv.Itoa(6001) + addr2 := "localhost:" + strconv.Itoa(7001) + rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations" + if _, err := exec.LookPath("aws-lambda-rie"); err != nil { + t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err) + } + + // compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie + testDir := t.TempDir() + handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "lambdaurl.handler"), "./testdata/lambdaurl.go") + handlerBuild.Stderr = os.Stderr + handlerBuild.Stdout = os.Stderr + require.NoError(t, handlerBuild.Run()) + + // run the runtime interface emulator, capture the logs for assertion + cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "lambdaurl.handler") + cmd.Env = []string{ + "PATH=" + testDir, + "AWS_LAMBDA_FUNCTION_TIMEOUT=2", + } + cmd.Stderr = os.Stderr + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + var logs string + done := make(chan interface{}) // closed on completion of log flush + go func() { + logBytes, err := ioutil.ReadAll(stdout) + require.NoError(t, err) + logs = string(logBytes) + close(done) + }() + require.NoError(t, cmd.Start()) + t.Cleanup(func() { _ = cmd.Process.Kill() }) + + // give a moment for the port to bind + time.Sleep(500 * time.Millisecond) + + client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie + resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}")) + require.NoError(t, err) + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + + expected := "{\"statusCode\":200,\"headers\":{\"Content-Type\":\"text/html; charset=utf-8\"}}\x00\x00\x00\x00\x00\x00\x00\x00\n\n\nHello World!\n\n\n" + assert.Equal(t, expected, string(body)) + + require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained + <-done + t.Logf("stdout:\n%s", logs) +} diff --git a/lambdaurl/testdata/lambdaurl.go b/lambdaurl/testdata/lambdaurl.go new file mode 100644 index 00000000..ee12462f --- /dev/null +++ b/lambdaurl/testdata/lambdaurl.go @@ -0,0 +1,25 @@ +package main + +import ( + "io" + "net/http" + "strings" + + "github.com/aws/aws-lambda-go/lambdaurl" +) + +const content = ` + + +Hello World! + + +` + +func main() { + lambdaurl.Start(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(w, strings.NewReader(content)) + }), + lambdaurl.WithDetectContentType(true), + ) +}