From eeb1a5d96f9a3c619c6e43b6aa31c3810af153c1 Mon Sep 17 00:00:00 2001 From: Tom Bamford Date: Tue, 20 Feb 2024 11:46:50 +0000 Subject: [PATCH 1/2] Improve binary unmarshalling and support per-request RetryFuncs --- sdk/client/client.go | 60 +++++++++++++++----------- sdk/client/client_test.go | 14 +++--- sdk/client/dataplane/client.go | 10 ----- sdk/client/dataplane/storage/client.go | 24 ++++------- sdk/client/request_options.go | 3 ++ sdk/client/resourcemanager/client.go | 2 +- 6 files changed, 55 insertions(+), 58 deletions(-) diff --git a/sdk/client/client.go b/sdk/client/client.go index b6689bbe77c..ee26b40a0b2 100644 --- a/sdk/client/client.go +++ b/sdk/client/client.go @@ -91,39 +91,43 @@ type Request struct { func (r *Request) Marshal(payload interface{}) error { contentType := strings.ToLower(r.Header.Get("Content-Type")) - if strings.Contains(contentType, "application/json") { + switch { + case strings.Contains(contentType, "application/json"): body, err := json.Marshal(payload) if err == nil { r.ContentLength = int64(len(body)) r.Body = io.NopCloser(bytes.NewReader(body)) } + return nil - } - if strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml") { + case strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml"): body, err := xml.Marshal(payload) if err == nil { r.ContentLength = int64(len(body)) r.Body = io.NopCloser(bytes.NewReader(body)) } + return nil } - if strings.Contains(contentType, "application/octet-stream") || strings.Contains(contentType, "text/powershell") { - switch v := payload.(type) { - case *[]byte: + switch v := payload.(type) { + case *[]byte: + if v == nil { + r.ContentLength = int64(len([]byte{})) + r.Body = io.NopCloser(bytes.NewReader([]byte{})) + } else { r.ContentLength = int64(len(*v)) r.Body = io.NopCloser(bytes.NewReader(*v)) - case []byte: - r.ContentLength = int64(len(v)) - r.Body = io.NopCloser(bytes.NewReader(v)) - default: - return fmt.Errorf("internal-error: `payload` must be []byte or *[]byte but got type %T", payload) } - return nil + case []byte: + r.ContentLength = int64(len(v)) + r.Body = io.NopCloser(bytes.NewReader(v)) + default: + return fmt.Errorf("internal-error: `payload` must be []byte or *[]byte but got type %T", payload) } - return fmt.Errorf("internal-error: unimplemented marshal function for content type %q", contentType) + return nil } // Execute invokes the Execute method for the Request's Client @@ -173,12 +177,14 @@ func (r *Response) Unmarshal(model interface{}) error { contentType = strings.ToLower(r.Request.Header.Get("Content-Type")) } } - // the maintenance API returns a 200 for a delete with no content-type and no content length so we should skip - // trying to unmarshal this + + // Some APIs (e.g. Maintenance) return 200 without a body, don't unmarshal these if r.ContentLength == 0 && (r.Body == nil || r.Body == http.NoBody) { return nil } - if strings.Contains(contentType, "application/json") { + + switch { + case strings.Contains(contentType, "application/json"): // Read the response body and close it respBody, err := io.ReadAll(r.Body) if err != nil { @@ -201,10 +207,10 @@ func (r *Response) Unmarshal(model interface{}) error { // Reassign the response body as downstream code may expect it r.Body = io.NopCloser(bytes.NewBuffer(respBody)) + return nil - } - if strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml") { + case strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml"): // Read the response body and close it respBody, err := io.ReadAll(r.Body) if err != nil { @@ -227,13 +233,13 @@ func (r *Response) Unmarshal(model interface{}) error { // Reassign the response body as downstream code may expect it r.Body = io.NopCloser(bytes.NewBuffer(respBody)) + return nil - } - if strings.Contains(contentType, "application/octet-stream") || strings.Contains(contentType, "text/powershell") { - ptr, ok := model.(**[]byte) + case strings.Contains(contentType, "application/octet-stream") || strings.Contains(contentType, "text/powershell"): + ptr, ok := model.(*[]byte) if !ok || ptr == nil { - return fmt.Errorf("internal-error: `model` must be a non-nil `**[]byte` but got %+v", model) + return fmt.Errorf("internal-error: `model` must be a non-nil `*[]byte` but got %[1]T: %+[1]v", model) } // Read the response body and close it @@ -243,14 +249,17 @@ func (r *Response) Unmarshal(model interface{}) error { } r.Body.Close() - // Trim away a BOM if present - respBody = bytes.TrimPrefix(respBody, []byte("\xef\xbb\xbf")) + if strings.HasPrefix(contentType, "text/") { + // Trim away a BOM if present + respBody = bytes.TrimPrefix(respBody, []byte("\xef\xbb\xbf")) + } // copy the byte stream across - *ptr = &respBody + *ptr = respBody // Reassign the response body as downstream code may expect it r.Body = io.NopCloser(bytes.NewBuffer(respBody)) + return nil } @@ -374,6 +383,7 @@ func (c *Client) NewRequest(ctx context.Context, input RequestOptions) (*Request Client: c, Request: req, Pager: input.Pager, + RetryFunc: input.RetryFunc, ValidStatusCodes: input.ExpectedStatusCodes, } diff --git a/sdk/client/client_test.go b/sdk/client/client_test.go index 7c804c5e803..ef9ef9fb432 100644 --- a/sdk/client/client_test.go +++ b/sdk/client/client_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "github.com/hashicorp/go-azure-helpers/lang/pointer" "io" "log" "net/http" @@ -16,7 +17,6 @@ import ( "reflect" "testing" - "github.com/hashicorp/go-azure-helpers/lang/pointer" "github.com/hashicorp/go-azure-sdk/sdk/internal/test" "github.com/hashicorp/go-azure-sdk/sdk/odata" ) @@ -355,12 +355,12 @@ func TestUnmarshalByteStreamAndPowerShell(t *testing.T) { Body: io.NopCloser(bytes.NewReader(expected)), }, } - var unmarshaled = pointer.To(make([]byte, 0)) + var unmarshaled = make([]byte, 0) if err := r.Unmarshal(&unmarshaled); err != nil { t.Fatalf("unmarshaling: %+v", err) } - if string(*unmarshaled) != "you serve butter" { - t.Fatalf("unexpected difference in decoded objects. Expected %q\n\nGot: %q", string(expected), string(*unmarshaled)) + if string(unmarshaled) != "you serve butter" { + t.Fatalf("unexpected difference in decoded objects. Expected %q\n\nGot: %q", string(expected), string(unmarshaled)) } } } @@ -373,7 +373,9 @@ func TestUnmarshalByteStreamAndPowerShellWithModel(t *testing.T) { var respModel = struct { HttpResponse *http.Response Model *[]byte - }{} + }{ + Model: pointer.To(make([]byte, 0)), + } expected := []byte("you serve butter") for _, contentType := range contentTypes { r := &Response{ @@ -384,7 +386,7 @@ func TestUnmarshalByteStreamAndPowerShellWithModel(t *testing.T) { Body: io.NopCloser(bytes.NewReader(expected)), }, } - if err := r.Unmarshal(&respModel.Model); err != nil { + if err := r.Unmarshal(respModel.Model); err != nil { t.Fatalf("unmarshaling: %+v", err) } if string(*respModel.Model) != "you serve butter" { diff --git a/sdk/client/dataplane/client.go b/sdk/client/dataplane/client.go index 199a8210da4..5e0dd27d9f0 100644 --- a/sdk/client/dataplane/client.go +++ b/sdk/client/dataplane/client.go @@ -4,8 +4,6 @@ package dataplane import ( - "context" - "github.com/hashicorp/go-azure-sdk/sdk/client" ) @@ -27,11 +25,3 @@ func NewDataPlaneClient(baseUri string, serviceName, apiVersion string) *Client } return client } - -func (c *Client) Execute(ctx context.Context, req *client.Request) (*client.Response, error) { - return c.Client.Execute(ctx, req) -} - -func (c *Client) ExecutePaged(ctx context.Context, req *client.Request) (*client.Response, error) { - return c.Client.ExecutePaged(ctx, req) -} diff --git a/sdk/client/dataplane/storage/client.go b/sdk/client/dataplane/storage/client.go index 0b0c63e1efe..27dedf17c00 100644 --- a/sdk/client/dataplane/storage/client.go +++ b/sdk/client/dataplane/storage/client.go @@ -12,31 +12,31 @@ import ( "github.com/hashicorp/go-azure-sdk/sdk/client/dataplane" ) -var _ client.BaseClient = &BaseClient{} +var _ client.BaseClient = &Client{} var storageDefaultRetryFunctions = []client.RequestRetryFunc{ // TODO: stuff n tings } -type BaseClient struct { +type Client struct { *dataplane.Client } -func NewBaseClient(baseUri string, componentName, apiVersion string) (*BaseClient, error) { +func NewStorageClient(baseUri string, componentName, apiVersion string) (*Client, error) { // NOTE: both the domain name _and_ the domain format can change entirely depending on the type of storage account being used // when provisioned in an edge zone, and when AzureDNSZone is used, as such we require the baseUri is provided here - return &BaseClient{ + return &Client{ Client: dataplane.NewDataPlaneClient(baseUri, fmt.Sprintf("storage/%s", componentName), apiVersion), }, nil } -func (c *BaseClient) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) { +func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) { // TODO move these validations to base client method if _, ok := ctx.Deadline(); !ok { - return nil, fmt.Errorf("the context used must have a deadline attached for polling purposes, but got no deadline") + return nil, fmt.Errorf("internal-error: the context used must have a deadline attached for polling purposes, but got no deadline") } if err := input.Validate(); err != nil { - return nil, fmt.Errorf("pre-validating request payload: %+v", err) + return nil, fmt.Errorf("internal-error: pre-validating request payload: %+v", err) } req, err := c.Client.Client.NewRequest(ctx, input) @@ -66,16 +66,8 @@ func (c *BaseClient) NewRequest(ctx context.Context, input client.RequestOptions } req.URL.RawQuery = query.Encode() - req.RetryFunc = client.RequestRetryAny(storageDefaultRetryFunctions...) + req.RetryFunc = client.RequestRetryAny(append(storageDefaultRetryFunctions, input.RetryFunc)...) req.ValidStatusCodes = input.ExpectedStatusCodes return req, nil } - -func (c *BaseClient) Execute(ctx context.Context, req *client.Request) (*client.Response, error) { - return c.Client.Execute(ctx, req) -} - -func (c *BaseClient) ExecutePaged(ctx context.Context, req *client.Request) (*client.Response, error) { - return c.Client.ExecutePaged(ctx, req) -} diff --git a/sdk/client/request_options.go b/sdk/client/request_options.go index 402b497c0b0..d76a028b4f2 100644 --- a/sdk/client/request_options.go +++ b/sdk/client/request_options.go @@ -28,6 +28,9 @@ type RequestOptions struct { // Path is the absolute URI for this request, with a leading slash. Path string + + // RetryFunc is an optional function to determine whether a request should be automatically retried + RetryFunc RequestRetryFunc } func (ro RequestOptions) Validate() error { diff --git a/sdk/client/resourcemanager/client.go b/sdk/client/resourcemanager/client.go index 104b4c002f2..8c9b67a9c4f 100644 --- a/sdk/client/resourcemanager/client.go +++ b/sdk/client/resourcemanager/client.go @@ -86,7 +86,7 @@ func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (* req.URL.RawQuery = query.Encode() req.Pager = input.Pager - req.RetryFunc = client.RequestRetryAny(defaultRetryFunctions...) + req.RetryFunc = client.RequestRetryAny(append(defaultRetryFunctions, input.RetryFunc)...) req.ValidStatusCodes = input.ExpectedStatusCodes return req, nil From 98498b437706dc78de735064f4ef83887a4da4b4 Mon Sep 17 00:00:00 2001 From: Tom Bamford Date: Fri, 23 Feb 2024 11:38:37 +0000 Subject: [PATCH 2/2] goimports --- sdk/client/client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/client/client_test.go b/sdk/client/client_test.go index ef9ef9fb432..668edf4c841 100644 --- a/sdk/client/client_test.go +++ b/sdk/client/client_test.go @@ -9,7 +9,6 @@ import ( "encoding/json" "encoding/xml" "fmt" - "github.com/hashicorp/go-azure-helpers/lang/pointer" "io" "log" "net/http" @@ -17,6 +16,7 @@ import ( "reflect" "testing" + "github.com/hashicorp/go-azure-helpers/lang/pointer" "github.com/hashicorp/go-azure-sdk/sdk/internal/test" "github.com/hashicorp/go-azure-sdk/sdk/odata" )