diff --git a/CHANGELOG.md b/CHANGELOG.md index ec1cc74..c42c762 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +## [1.3.0] - 2024-01-22 + +### Added + +- Added support to override default middleware with function `GetDefaultMiddlewaresWithOptions`. + ## [1.2.1] - 2023-01-22 ### Changed diff --git a/headers_inspection_handler_test.go b/headers_inspection_handler_test.go index 34f5acd..1f4bb5f 100644 --- a/headers_inspection_handler_test.go +++ b/headers_inspection_handler_test.go @@ -19,7 +19,7 @@ func TestHeadersInspectionOptionsImplementTheOptionInterface(t *testing.T) { options := NewHeadersInspectionOptions() assert.NotNil(t, options) _, ok := any(options).(abs.RequestOption) - assert.True(t, ok, "options does not implement RequestOption") + assert.True(t, ok, "options does not implement optionsType") } func TestItGetsRequestHeaders(t *testing.T) { diff --git a/kiota_client_factory.go b/kiota_client_factory.go index 9ef2f68..94126c2 100644 --- a/kiota_client_factory.go +++ b/kiota_client_factory.go @@ -3,6 +3,8 @@ package nethttplibrary import ( + "errors" + abs "github.com/microsoft/kiota-abstractions-go" nethttp "net/http" "net/url" "time" @@ -76,12 +78,76 @@ func getDefaultClientWithoutMiddleware() *nethttp.Client { // GetDefaultMiddlewares creates a new default set of middlewares for the Kiota request adapter func GetDefaultMiddlewares() []Middleware { - return []Middleware{ - NewRetryHandler(), - NewRedirectHandler(), - NewCompressionHandler(), - NewParametersNameDecodingHandler(), - NewUserAgentHandler(), - NewHeadersInspectionHandler(), + return getDefaultMiddleWare(make(map[abs.RequestOptionKey]Middleware)) +} + +// GetDefaultMiddlewaresWithOptions creates a new default set of middlewares for the Kiota request adapter with options +func GetDefaultMiddlewaresWithOptions(requestOptions ...abs.RequestOption) ([]Middleware, error) { + if len(requestOptions) == 0 { + return GetDefaultMiddlewares(), nil + } + + // map of middleware options + middlewareMap := make(map[abs.RequestOptionKey]Middleware) + + for _, element := range requestOptions { + switch v := element.(type) { + case *RetryHandlerOptions: + middlewareMap[retryKeyValue] = NewRetryHandlerWithOptions(*v) + case *RedirectHandlerOptions: + middlewareMap[redirectKeyValue] = NewRedirectHandlerWithOptions(*v) + case *CompressionOptions: + middlewareMap[compressKey] = NewCompressionHandlerWithOptions(*v) + case *ParametersNameDecodingOptions: + middlewareMap[parametersNameDecodingKeyValue] = NewParametersNameDecodingHandlerWithOptions(*v) + case *UserAgentHandlerOptions: + middlewareMap[userAgentKeyValue] = NewUserAgentHandlerWithOptions(v) + case *HeadersInspectionOptions: + middlewareMap[headersInspectionKeyValue] = NewHeadersInspectionHandlerWithOptions(*v) + default: + // none of the above types + return nil, errors.New("unsupported option type") + } + } + + middleware := getDefaultMiddleWare(middlewareMap) + return middleware, nil +} + +// getDefaultMiddleWare creates a new default set of middlewares for the Kiota request adapter +func getDefaultMiddleWare(middlewareMap map[abs.RequestOptionKey]Middleware) []Middleware { + middlewareSource := map[abs.RequestOptionKey]func() Middleware{ + retryKeyValue: func() Middleware { + return NewRetryHandler() + }, + redirectKeyValue: func() Middleware { + return NewRedirectHandler() + }, + compressKey: func() Middleware { + return NewCompressionHandler() + }, + parametersNameDecodingKeyValue: func() Middleware { + return NewParametersNameDecodingHandler() + }, + userAgentKeyValue: func() Middleware { + return NewUserAgentHandler() + }, + headersInspectionKeyValue: func() Middleware { + return NewHeadersInspectionHandler() + }, } + + // loop over middlewareSource and add any middleware that wasn't provided in the requestOptions + for key, value := range middlewareSource { + if _, ok := middlewareMap[key]; !ok { + middlewareMap[key] = value() + } + } + + var middleware []Middleware + for _, value := range middlewareMap { + middleware = append(middleware, value) + } + + return middleware } diff --git a/kiota_client_factory_test.go b/kiota_client_factory_test.go new file mode 100644 index 0000000..df5c741 --- /dev/null +++ b/kiota_client_factory_test.go @@ -0,0 +1,99 @@ +package nethttplibrary + +import ( + abstractions "github.com/microsoft/kiota-abstractions-go" + "github.com/stretchr/testify/assert" + nethttp "net/http" + "testing" + "time" +) + +func TestGetDefaultMiddleWareWithMultipleOptions(t *testing.T) { + retryOptions := RetryHandlerOptions{ + ShouldRetry: func(delay time.Duration, executionCount int, request *nethttp.Request, response *nethttp.Response) bool { + return false + }, + } + redirectHandlerOptions := RedirectHandlerOptions{ + MaxRedirects: defaultMaxRedirects, + ShouldRedirect: func(req *nethttp.Request, res *nethttp.Response) bool { + return true + }, + } + compressionOptions := NewCompressionOptions(false) + parametersNameDecodingOptions := ParametersNameDecodingOptions{ + Enable: true, + ParametersToDecode: []byte{'-', '.', '~', '$'}, + } + userAgentHandlerOptions := UserAgentHandlerOptions{ + Enabled: true, + ProductName: "kiota-go", + ProductVersion: "1.1.0", + } + headersInspectionOptions := HeadersInspectionOptions{ + RequestHeaders: abstractions.NewRequestHeaders(), + ResponseHeaders: abstractions.NewResponseHeaders(), + } + options, err := GetDefaultMiddlewaresWithOptions(&retryOptions, + &redirectHandlerOptions, + &compressionOptions, + ¶metersNameDecodingOptions, + &userAgentHandlerOptions, + &headersInspectionOptions, + ) + if err != nil { + t.Errorf(err.Error()) + } + if len(options) != 6 { + t.Errorf("expected 6 middleware, got %v", len(options)) + } + + for _, element := range options { + switch v := element.(type) { + case *CompressionHandler: + assert.Equal(t, v.options.ShouldCompress(), compressionOptions.ShouldCompress()) + } + } +} + +func TestGetDefaultMiddleWareWithInvalidOption(t *testing.T) { + chaosOptions := ChaosHandlerOptions{ + ChaosPercentage: 101, + ChaosStrategy: Random, + } + _, err := GetDefaultMiddlewaresWithOptions(&chaosOptions) + + assert.Equal(t, err.Error(), "unsupported option type") +} + +func TestGetDefaultMiddleWareWithOptions(t *testing.T) { + compression := NewCompressionOptions(false) + options, err := GetDefaultMiddlewaresWithOptions(&compression) + if err != nil { + t.Errorf(err.Error()) + } + if len(options) != 6 { + t.Errorf("expected 6 middleware, got %v", len(options)) + } + + for _, element := range options { + switch v := element.(type) { + case *CompressionHandler: + assert.Equal(t, v.options.ShouldCompress(), compression.ShouldCompress()) + } + } +} + +func TestGetDefaultMiddlewares(t *testing.T) { + options := GetDefaultMiddlewares() + if len(options) != 6 { + t.Errorf("expected 6 middleware, got %v", len(options)) + } + + for _, element := range options { + switch v := element.(type) { + case *CompressionHandler: + assert.True(t, v.options.ShouldCompress()) + } + } +}