Skip to content

Commit

Permalink
Merge pull request #890 from hashicorp/f/binary-unmarshaling
Browse files Browse the repository at this point in the history
Improve binary unmarshalling and support per-request RetryFuncs
  • Loading branch information
manicminer authored Feb 23, 2024
2 parents a429f10 + 98498b4 commit d241b54
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 57 deletions.
60 changes: 35 additions & 25 deletions sdk/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,39 +91,43 @@ type Request struct {
func (r *Request) Marshal(payload interface{}) error {
contentType := strings.ToLower(r.Header.Get("Content-Type"))

if strings.Contains(contentType, "application/json") {
switch {
case strings.Contains(contentType, "application/json"):
body, err := json.Marshal(payload)
if err == nil {
r.ContentLength = int64(len(body))
r.Body = io.NopCloser(bytes.NewReader(body))
}

return nil
}

if strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml") {
case strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml"):
body, err := xml.Marshal(payload)
if err == nil {
r.ContentLength = int64(len(body))
r.Body = io.NopCloser(bytes.NewReader(body))
}

return nil
}

if strings.Contains(contentType, "application/octet-stream") || strings.Contains(contentType, "text/powershell") {
switch v := payload.(type) {
case *[]byte:
switch v := payload.(type) {
case *[]byte:
if v == nil {
r.ContentLength = int64(len([]byte{}))
r.Body = io.NopCloser(bytes.NewReader([]byte{}))
} else {
r.ContentLength = int64(len(*v))
r.Body = io.NopCloser(bytes.NewReader(*v))
case []byte:
r.ContentLength = int64(len(v))
r.Body = io.NopCloser(bytes.NewReader(v))
default:
return fmt.Errorf("internal-error: `payload` must be []byte or *[]byte but got type %T", payload)
}
return nil
case []byte:
r.ContentLength = int64(len(v))
r.Body = io.NopCloser(bytes.NewReader(v))
default:
return fmt.Errorf("internal-error: `payload` must be []byte or *[]byte but got type %T", payload)
}

return fmt.Errorf("internal-error: unimplemented marshal function for content type %q", contentType)
return nil
}

// Execute invokes the Execute method for the Request's Client
Expand Down Expand Up @@ -173,12 +177,14 @@ func (r *Response) Unmarshal(model interface{}) error {
contentType = strings.ToLower(r.Request.Header.Get("Content-Type"))
}
}
// the maintenance API returns a 200 for a delete with no content-type and no content length so we should skip
// trying to unmarshal this

// Some APIs (e.g. Maintenance) return 200 without a body, don't unmarshal these
if r.ContentLength == 0 && (r.Body == nil || r.Body == http.NoBody) {
return nil
}
if strings.Contains(contentType, "application/json") {

switch {
case strings.Contains(contentType, "application/json"):
// Read the response body and close it
respBody, err := io.ReadAll(r.Body)
if err != nil {
Expand All @@ -201,10 +207,10 @@ func (r *Response) Unmarshal(model interface{}) error {

// Reassign the response body as downstream code may expect it
r.Body = io.NopCloser(bytes.NewBuffer(respBody))

return nil
}

if strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml") {
case strings.Contains(contentType, "application/xml") || strings.Contains(contentType, "text/xml"):
// Read the response body and close it
respBody, err := io.ReadAll(r.Body)
if err != nil {
Expand All @@ -227,13 +233,13 @@ func (r *Response) Unmarshal(model interface{}) error {

// Reassign the response body as downstream code may expect it
r.Body = io.NopCloser(bytes.NewBuffer(respBody))

return nil
}

if strings.Contains(contentType, "application/octet-stream") || strings.Contains(contentType, "text/powershell") {
ptr, ok := model.(**[]byte)
case strings.Contains(contentType, "application/octet-stream") || strings.Contains(contentType, "text/powershell"):
ptr, ok := model.(*[]byte)
if !ok || ptr == nil {
return fmt.Errorf("internal-error: `model` must be a non-nil `**[]byte` but got %+v", model)
return fmt.Errorf("internal-error: `model` must be a non-nil `*[]byte` but got %[1]T: %+[1]v", model)
}

// Read the response body and close it
Expand All @@ -243,14 +249,17 @@ func (r *Response) Unmarshal(model interface{}) error {
}
r.Body.Close()

// Trim away a BOM if present
respBody = bytes.TrimPrefix(respBody, []byte("\xef\xbb\xbf"))
if strings.HasPrefix(contentType, "text/") {
// Trim away a BOM if present
respBody = bytes.TrimPrefix(respBody, []byte("\xef\xbb\xbf"))
}

// copy the byte stream across
*ptr = &respBody
*ptr = respBody

// Reassign the response body as downstream code may expect it
r.Body = io.NopCloser(bytes.NewBuffer(respBody))

return nil
}

Expand Down Expand Up @@ -374,6 +383,7 @@ func (c *Client) NewRequest(ctx context.Context, input RequestOptions) (*Request
Client: c,
Request: req,
Pager: input.Pager,
RetryFunc: input.RetryFunc,
ValidStatusCodes: input.ExpectedStatusCodes,
}

Expand Down
12 changes: 7 additions & 5 deletions sdk/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,12 @@ func TestUnmarshalByteStreamAndPowerShell(t *testing.T) {
Body: io.NopCloser(bytes.NewReader(expected)),
},
}
var unmarshaled = pointer.To(make([]byte, 0))
var unmarshaled = make([]byte, 0)
if err := r.Unmarshal(&unmarshaled); err != nil {
t.Fatalf("unmarshaling: %+v", err)
}
if string(*unmarshaled) != "you serve butter" {
t.Fatalf("unexpected difference in decoded objects. Expected %q\n\nGot: %q", string(expected), string(*unmarshaled))
if string(unmarshaled) != "you serve butter" {
t.Fatalf("unexpected difference in decoded objects. Expected %q\n\nGot: %q", string(expected), string(unmarshaled))
}
}
}
Expand All @@ -373,7 +373,9 @@ func TestUnmarshalByteStreamAndPowerShellWithModel(t *testing.T) {
var respModel = struct {
HttpResponse *http.Response
Model *[]byte
}{}
}{
Model: pointer.To(make([]byte, 0)),
}
expected := []byte("you serve butter")
for _, contentType := range contentTypes {
r := &Response{
Expand All @@ -384,7 +386,7 @@ func TestUnmarshalByteStreamAndPowerShellWithModel(t *testing.T) {
Body: io.NopCloser(bytes.NewReader(expected)),
},
}
if err := r.Unmarshal(&respModel.Model); err != nil {
if err := r.Unmarshal(respModel.Model); err != nil {
t.Fatalf("unmarshaling: %+v", err)
}
if string(*respModel.Model) != "you serve butter" {
Expand Down
10 changes: 0 additions & 10 deletions sdk/client/dataplane/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
package dataplane

import (
"context"

"github.com/hashicorp/go-azure-sdk/sdk/client"
)

Expand All @@ -27,11 +25,3 @@ func NewDataPlaneClient(baseUri string, serviceName, apiVersion string) *Client
}
return client
}

func (c *Client) Execute(ctx context.Context, req *client.Request) (*client.Response, error) {
return c.Client.Execute(ctx, req)
}

func (c *Client) ExecutePaged(ctx context.Context, req *client.Request) (*client.Response, error) {
return c.Client.ExecutePaged(ctx, req)
}
24 changes: 8 additions & 16 deletions sdk/client/dataplane/storage/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,31 @@ import (
"github.com/hashicorp/go-azure-sdk/sdk/client/dataplane"
)

var _ client.BaseClient = &BaseClient{}
var _ client.BaseClient = &Client{}

var storageDefaultRetryFunctions = []client.RequestRetryFunc{
// TODO: stuff n tings
}

type BaseClient struct {
type Client struct {
*dataplane.Client
}

func NewBaseClient(baseUri string, componentName, apiVersion string) (*BaseClient, error) {
func NewStorageClient(baseUri string, componentName, apiVersion string) (*Client, error) {
// NOTE: both the domain name _and_ the domain format can change entirely depending on the type of storage account being used
// when provisioned in an edge zone, and when AzureDNSZone is used, as such we require the baseUri is provided here
return &BaseClient{
return &Client{
Client: dataplane.NewDataPlaneClient(baseUri, fmt.Sprintf("storage/%s", componentName), apiVersion),
}, nil
}

func (c *BaseClient) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) {
func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) {
// TODO move these validations to base client method
if _, ok := ctx.Deadline(); !ok {
return nil, fmt.Errorf("the context used must have a deadline attached for polling purposes, but got no deadline")
return nil, fmt.Errorf("internal-error: the context used must have a deadline attached for polling purposes, but got no deadline")
}
if err := input.Validate(); err != nil {
return nil, fmt.Errorf("pre-validating request payload: %+v", err)
return nil, fmt.Errorf("internal-error: pre-validating request payload: %+v", err)
}

req, err := c.Client.Client.NewRequest(ctx, input)
Expand Down Expand Up @@ -66,16 +66,8 @@ func (c *BaseClient) NewRequest(ctx context.Context, input client.RequestOptions
}

req.URL.RawQuery = query.Encode()
req.RetryFunc = client.RequestRetryAny(storageDefaultRetryFunctions...)
req.RetryFunc = client.RequestRetryAny(append(storageDefaultRetryFunctions, input.RetryFunc)...)
req.ValidStatusCodes = input.ExpectedStatusCodes

return req, nil
}

func (c *BaseClient) Execute(ctx context.Context, req *client.Request) (*client.Response, error) {
return c.Client.Execute(ctx, req)
}

func (c *BaseClient) ExecutePaged(ctx context.Context, req *client.Request) (*client.Response, error) {
return c.Client.ExecutePaged(ctx, req)
}
3 changes: 3 additions & 0 deletions sdk/client/request_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type RequestOptions struct {

// Path is the absolute URI for this request, with a leading slash.
Path string

// RetryFunc is an optional function to determine whether a request should be automatically retried
RetryFunc RequestRetryFunc
}

func (ro RequestOptions) Validate() error {
Expand Down
2 changes: 1 addition & 1 deletion sdk/client/resourcemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (c *Client) NewRequest(ctx context.Context, input client.RequestOptions) (*

req.URL.RawQuery = query.Encode()
req.Pager = input.Pager
req.RetryFunc = client.RequestRetryAny(defaultRetryFunctions...)
req.RetryFunc = client.RequestRetryAny(append(defaultRetryFunctions, input.RetryFunc)...)
req.ValidStatusCodes = input.ExpectedStatusCodes

return req, nil
Expand Down

0 comments on commit d241b54

Please sign in to comment.