Skip to content

Commit

Permalink
[minor_change] Fix skipLoggingPayload usage in Authenticate to avoid …
Browse files Browse the repository at this point in the history
…data race condition issues.
  • Loading branch information
samiib authored and lhercot committed Jun 13, 2024
1 parent 8672b70 commit 964d0e9
Showing 1 changed file with 30 additions and 21 deletions.
51 changes: 30 additions & 21 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ func (c *Client) useInsecureHTTPClient(insecure bool) *http.Transport {
// - Passwords with special chars have issues when using container
// - For encoding/decoding
func (c *Client) MakeRestRequestRaw(method string, rpath string, payload []byte, authenticated bool) (*http.Request, error) {
return c.makeRestRequestRaw(method, rpath, payload, authenticated, c.skipLoggingPayload)
}

func (c *Client) makeRestRequestRaw(method string, rpath string, payload []byte, authenticated bool, skipLoggingPayload bool) (*http.Request, error) {

pathURL, err := url.Parse(rpath)
if err != nil {
Expand Down Expand Up @@ -355,7 +359,7 @@ func (c *Client) MakeRestRequestRaw(method string, rpath string, payload []byte,
return nil, err
}

if c.skipLoggingPayload {
if skipLoggingPayload {
log.Printf("HTTP request %s %s", method, rpath)
} else {
log.Printf("HTTP request %s %s %v", method, rpath, req)
Expand All @@ -367,14 +371,18 @@ func (c *Client) MakeRestRequestRaw(method string, rpath string, payload []byte,
}
}

if !c.skipLoggingPayload {
if !skipLoggingPayload {
log.Printf("HTTP request after injection %s %s %v", method, rpath, req)
}

return req, nil
}

func (c *Client) MakeRestRequest(method string, rpath string, body *container.Container, authenticated bool) (*http.Request, error) {
return c.makeRestRequest(method, rpath, body, authenticated, c.skipLoggingPayload)
}

func (c *Client) makeRestRequest(method string, rpath string, body *container.Container, authenticated bool, skipLoggingPayload bool) (*http.Request, error) {

pathURL, err := url.Parse(rpath)
if err != nil {
Expand Down Expand Up @@ -409,7 +417,7 @@ func (c *Client) MakeRestRequest(method string, rpath string, body *container.Co
return nil, err
}

if c.skipLoggingPayload {
if skipLoggingPayload {
log.Printf("HTTP request %s %s", method, rpath)
} else {
log.Printf("HTTP request %s %s %v", method, rpath, req)
Expand All @@ -421,7 +429,7 @@ func (c *Client) MakeRestRequest(method string, rpath string, body *container.Co
}
}

if !c.skipLoggingPayload {
if !skipLoggingPayload {
log.Printf("HTTP request after injection %s %s %v", method, rpath, req)
}

Expand All @@ -430,9 +438,6 @@ func (c *Client) MakeRestRequest(method string, rpath string, body *container.Co

// Authenticate is used to
func (c *Client) Authenticate() error {
// Setting skipLoggingPayloadState to preserve state during call of the method
skipLoggingPayloadState := c.skipLoggingPayload

log.Printf("[DEBUG] Begining Authentication method")

method := "POST"
Expand All @@ -451,17 +456,12 @@ func (c *Client) Authenticate() error {
authenticated = true
}

// Setting skipLoggingPayload true so authentication details are not shown in logs
c.skipLoggingPayload = true

req, err := c.MakeRestRequestRaw(method, path, body, authenticated)
req, err := c.makeRestRequestRaw(method, path, body, authenticated, true)
if err != nil {
return err
}

obj, _, err := c.Do(req)

c.skipLoggingPayload = skipLoggingPayloadState
obj, _, err := c.do(req, true)

if err != nil {
log.Printf("[DEBUG] Authentication ERROR: %s", err)
Expand All @@ -470,7 +470,7 @@ func (c *Client) Authenticate() error {
if obj == nil {
return errors.New("Empty response")
}
err = CheckForErrors(obj, method, c.skipLoggingPayload)
err = CheckForErrors(obj, method, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -502,11 +502,16 @@ func (c *Client) Authenticate() error {

return nil
}

func StrtoInt(s string, startIndex int, bitSize int) (int64, error) {
return strconv.ParseInt(s, startIndex, bitSize)

}

func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, error) {
return c.do(req, c.skipLoggingPayload)
}

func (c *Client) do(req *http.Request, skipLoggingPayload bool) (*container.Container, *http.Response, error) {
log.Printf("[DEBUG] Begining Do method %s", req.URL.String())

// retain the request body across multiple attempts
Expand All @@ -520,7 +525,7 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er
if c.maxRetries != 0 {
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
}
if !c.skipLoggingPayload {
if !skipLoggingPayload {
log.Printf("[TRACE] HTTP Request Body: %v", req.Body)
}

Expand All @@ -536,7 +541,7 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er
}
}

if !c.skipLoggingPayload {
if !skipLoggingPayload {
log.Printf("[TRACE] HTTP Response: %d %s %v", resp.StatusCode, resp.Status, resp)
} else {
log.Printf("[TRACE] HTTP Response: %d %s", resp.StatusCode, resp.Status)
Expand All @@ -545,7 +550,7 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er
bodyBytes, err := ioutil.ReadAll(resp.Body)
bodyStr := string(bodyBytes)
resp.Body.Close()
if !c.skipLoggingPayload {
if !skipLoggingPayload {
log.Printf("[DEBUG] HTTP response unique string %s %s %s", req.Method, req.URL.String(), bodyStr)
}

Expand Down Expand Up @@ -589,6 +594,10 @@ func (c *Client) Do(req *http.Request) (*container.Container, *http.Response, er
}

func (c *Client) DoRaw(req *http.Request) (*http.Response, error) {
return c.doRaw(req, c.skipLoggingPayload)
}

func (c *Client) doRaw(req *http.Request, skipLoggingPayload bool) (*http.Response, error) {
log.Printf("[DEBUG] Begining DoRaw method %s", req.URL.String())

// retain the request body across multiple attempts
Expand All @@ -602,7 +611,7 @@ func (c *Client) DoRaw(req *http.Request) (*http.Response, error) {
if c.maxRetries != 0 {
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
}
if !c.skipLoggingPayload {
if !skipLoggingPayload {
log.Printf("[TRACE] HTTP Request Body: %v", req.Body)
}

Expand All @@ -618,7 +627,7 @@ func (c *Client) DoRaw(req *http.Request) (*http.Response, error) {
}
}

if !c.skipLoggingPayload {
if !skipLoggingPayload {
log.Printf("[TRACE] HTTP Response: %d %s %v", resp.StatusCode, resp.Status, resp)
} else {
log.Printf("[TRACE] HTTP Response: %d %s", resp.StatusCode, resp.Status)
Expand Down

0 comments on commit 964d0e9

Please sign in to comment.