From 64c8f0be13eb6b08237d6dfc3f9eef2feac12714 Mon Sep 17 00:00:00 2001 From: Alvin Lin Date: Fri, 26 Jan 2024 11:24:59 -0800 Subject: [PATCH] Fix issue 185 - request cannot rewind during retry (#186) * Ensure the proxied request is type that is rewindable * Add comment on why we copy request body * Fix test by dealing with nil cases --- handler/proxy_client.go | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/handler/proxy_client.go b/handler/proxy_client.go index 740d5df..492edf2 100644 --- a/handler/proxy_client.go +++ b/handler/proxy_client.go @@ -18,6 +18,7 @@ package handler import ( "bytes" "fmt" + "io" "io/ioutil" "net/http" "net/http/httputil" @@ -119,6 +120,14 @@ func chunked(transferEncoding []string) bool { return false } +func readDownStreamRequestBody(req *http.Request) ([]byte, error) { + if req.Body == nil { + return []byte{}, nil + } + defer req.Body.Close() + return io.ReadAll(req.Body) +} + func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { proxyURL := *req.URL if p.HostOverride != "" { @@ -140,7 +149,16 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { log.WithField("request", string(initialReqDump)).Debug("Initial request dump:") } - proxyReq, err := http.NewRequest(req.Method, proxyURL.String(), req.Body) + // Save the request body into memory so that it's rewindable during retry. + // See https://github.com/awslabs/aws-sigv4-proxy/issues/185 + // This may increase memory demand, but the demand should be ok for most cases. If there + // are cases proven to be very problematic, we can consider adding a flag to disable this. + proxyReqBody, err := readDownStreamRequestBody(req) + if err != nil { + return nil, err + } + + proxyReq, err := http.NewRequest(req.Method, proxyURL.String(), bytes.NewReader(proxyReqBody)) if err != nil { return nil, err } @@ -222,7 +240,7 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { } if (p.LogFailedRequest || log.GetLevel() == log.DebugLevel) && resp.StatusCode >= 400 { - b, _ := ioutil.ReadAll(resp.Body) + b, _ := io.ReadAll(resp.Body) log.WithField("request", fmt.Sprintf("%s %s", proxyReq.Method, proxyReq.URL)). WithField("status_code", resp.StatusCode). WithField("message", string(b)). @@ -230,7 +248,7 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { // Need to "reset" the response body because we consumed the stream above, otherwise caller will // get empty body. - resp.Body = ioutil.NopCloser(bytes.NewBuffer(b)) + resp.Body = io.NopCloser(bytes.NewBuffer(b)) } return resp, nil