From 964d0e96055e0bb6ac10440db4012f8b5801bd91 Mon Sep 17 00:00:00 2001 From: samitab Date: Fri, 31 May 2024 14:35:59 +1000 Subject: [PATCH] [minor_change] Fix skipLoggingPayload usage in Authenticate to avoid data race condition issues. --- client/client.go | 51 ++++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/client/client.go b/client/client.go index fdc4fe1..0b4f0b1 100644 --- a/client/client.go +++ b/client/client.go @@ -321,6 +321,10 @@ func (c *Client) useInsecureHTTPClient(insecure bool) *http.Transport { // - Passwords with special chars have issues when using container // - For encoding/decoding func (c *Client) MakeRestRequestRaw(method string, rpath string, payload []byte, authenticated bool) (*http.Request, error) { + return c.makeRestRequestRaw(method, rpath, payload, authenticated, c.skipLoggingPayload) +} + +func (c *Client) makeRestRequestRaw(method string, rpath string, payload []byte, authenticated bool, skipLoggingPayload bool) (*http.Request, error) { pathURL, err := url.Parse(rpath) if err != nil { @@ -355,7 +359,7 @@ func (c *Client) MakeRestRequestRaw(method string, rpath string, payload []byte, return nil, err } - if c.skipLoggingPayload { + if skipLoggingPayload { log.Printf("HTTP request %s %s", method, rpath) } else { log.Printf("HTTP request %s %s %v", method, rpath, req) @@ -367,7 +371,7 @@ func (c *Client) MakeRestRequestRaw(method string, rpath string, payload []byte, } } - if !c.skipLoggingPayload { + if !skipLoggingPayload { log.Printf("HTTP request after injection %s %s %v", method, rpath, req) } @@ -375,6 +379,10 @@ func (c *Client) MakeRestRequestRaw(method string, rpath string, payload []byte, } func (c *Client) MakeRestRequest(method string, rpath string, body *container.Container, authenticated bool) (*http.Request, error) { + return c.makeRestRequest(method, rpath, body, authenticated, c.skipLoggingPayload) +} + +func (c *Client) makeRestRequest(method string, rpath string, body *container.Container, authenticated bool, skipLoggingPayload bool) (*http.Request, error) { pathURL, err := url.Parse(rpath) if err != nil { @@ -409,7 +417,7 @@ func (c *Client) MakeRestRequest(method string, rpath string, body *container.Co return nil, err } - if c.skipLoggingPayload { + if skipLoggingPayload { log.Printf("HTTP request %s %s", method, rpath) } else { log.Printf("HTTP request %s %s %v", method, rpath, req) @@ -421,7 +429,7 @@ func (c *Client) MakeRestRequest(method string, rpath string, body *container.Co } } - if !c.skipLoggingPayload { + if !skipLoggingPayload { log.Printf("HTTP request after injection %s %s %v", method, rpath, req) } @@ -430,9 +438,6 @@ func (c *Client) MakeRestRequest(method string, rpath string, body *container.Co // Authenticate is used to func (c *Client) Authenticate() error { - // Setting skipLoggingPayloadState to preserve state during call of the method - skipLoggingPayloadState := c.skipLoggingPayload - log.Printf("[DEBUG] Begining Authentication method") method := "POST" @@ -451,17 +456,12 @@ func (c *Client) Authenticate() error { authenticated = true } - // Setting skipLoggingPayload true so authentication details are not shown in logs - c.skipLoggingPayload = true - - req, err := c.MakeRestRequestRaw(method, path, body, authenticated) + req, err := c.makeRestRequestRaw(method, path, body, authenticated, true) if err != nil { return err } - obj, _, err := c.Do(req) - - c.skipLoggingPayload = skipLoggingPayloadState + obj, _, err := c.do(req, true) if err != nil { log.Printf("[DEBUG] Authentication ERROR: %s", err) @@ -470,7 +470,7 @@ func (c *Client) Authenticate() error { if obj == nil { return errors.New("Empty response") } - err = CheckForErrors(obj, method, c.skipLoggingPayload) + err = CheckForErrors(obj, method, true) if err != nil { return err } @@ -502,11 +502,16 @@ func (c *Client) Authenticate() error { return nil } + func StrtoInt(s string, startIndex int, bitSize int) (int64, error) { return strconv.ParseInt(s, startIndex, bitSize) - } + func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, error) { + return c.do(req, c.skipLoggingPayload) +} + +func (c *Client) do(req *http.Request, skipLoggingPayload bool) (*container.Container, *http.Response, error) { log.Printf("[DEBUG] Begining Do method %s", req.URL.String()) // retain the request body across multiple attempts @@ -520,7 +525,7 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er if c.maxRetries != 0 { req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) } - if !c.skipLoggingPayload { + if !skipLoggingPayload { log.Printf("[TRACE] HTTP Request Body: %v", req.Body) } @@ -536,7 +541,7 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er } } - if !c.skipLoggingPayload { + if !skipLoggingPayload { log.Printf("[TRACE] HTTP Response: %d %s %v", resp.StatusCode, resp.Status, resp) } else { log.Printf("[TRACE] HTTP Response: %d %s", resp.StatusCode, resp.Status) @@ -545,7 +550,7 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er bodyBytes, err := ioutil.ReadAll(resp.Body) bodyStr := string(bodyBytes) resp.Body.Close() - if !c.skipLoggingPayload { + if !skipLoggingPayload { log.Printf("[DEBUG] HTTP response unique string %s %s %s", req.Method, req.URL.String(), bodyStr) } @@ -589,6 +594,10 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er } func (c *Client) DoRaw(req *http.Request) (*http.Response, error) { + return c.doRaw(req, c.skipLoggingPayload) +} + +func (c *Client) doRaw(req *http.Request, skipLoggingPayload bool) (*http.Response, error) { log.Printf("[DEBUG] Begining DoRaw method %s", req.URL.String()) // retain the request body across multiple attempts @@ -602,7 +611,7 @@ func (c *Client) DoRaw(req *http.Request) (*http.Response, error) { if c.maxRetries != 0 { req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) } - if !c.skipLoggingPayload { + if !skipLoggingPayload { log.Printf("[TRACE] HTTP Request Body: %v", req.Body) } @@ -618,7 +627,7 @@ func (c *Client) DoRaw(req *http.Request) (*http.Response, error) { } } - if !c.skipLoggingPayload { + if !skipLoggingPayload { log.Printf("[TRACE] HTTP Response: %d %s %v", resp.StatusCode, resp.Status, resp) } else { log.Printf("[TRACE] HTTP Response: %d %s", resp.StatusCode, resp.Status)