From febd18b1a2fe1c6007b56c74bf945c784f95f305 Mon Sep 17 00:00:00 2001 From: raul Date: Thu, 11 Jan 2024 15:15:14 +0100 Subject: [PATCH] Add azure basic auth verification --- azuredevops/azuredevops.go | 46 ++++++++++++++++++++++++++-- azuredevops/azuredevops_test.go | 53 +++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) diff --git a/azuredevops/azuredevops.go b/azuredevops/azuredevops.go index 8bb0ddf..460cae4 100644 --- a/azuredevops/azuredevops.go +++ b/azuredevops/azuredevops.go @@ -13,8 +13,9 @@ import ( // parse errors var ( - ErrInvalidHTTPMethod = errors.New("invalid HTTP Method") - ErrParsingPayload = errors.New("error parsing payload") + ErrInvalidHTTPMethod = errors.New("invalid HTTP Method") + ErrParsingPayload = errors.New("error parsing payload") + ErrBasicAuthVerificationFailed = errors.New("basic auth verification failed") ) // Event defines an Azure DevOps server hook event type @@ -29,13 +30,38 @@ const ( GitPushEventType Event = "git.push" ) +// Option is a configuration option for the webhook +type Option func(*Webhook) error + +// Options is a namespace var for configuration options +var Options = WebhookOptions{} + +// WebhookOptions is a namespace for configuration option methods +type WebhookOptions struct{} + +// BasicAuth verifies payload using basic auth +func (WebhookOptions) BasicAuth(username, password string) Option { + return func(hook *Webhook) error { + hook.username = username + hook.password = password + return nil + } +} + // Webhook instance contains all methods needed to process events type Webhook struct { + username string + password string } // New creates and returns a WebHook instance -func New() (*Webhook, error) { +func New(options ...Option) (*Webhook, error) { hook := new(Webhook) + for _, opt := range options { + if err := opt(hook); err != nil { + return nil, errors.New("Error applying Option") + } + } return hook, nil } @@ -46,6 +72,10 @@ func (hook Webhook) Parse(r *http.Request, events ...Event) (interface{}, error) _ = r.Body.Close() }() + if !hook.verifyBasicAuth(r) { + return nil, ErrBasicAuthVerificationFailed + } + if r.Method != http.MethodPost { return nil, ErrInvalidHTTPMethod } @@ -78,3 +108,13 @@ func (hook Webhook) Parse(r *http.Request, events ...Event) (interface{}, error) return nil, fmt.Errorf("unknown event %s", pl.EventType) } } + +func (hook Webhook) verifyBasicAuth(r *http.Request) bool { + // skip validation if username or password was not provided + if hook.username == "" && hook.password == "" { + return true + } + username, password, ok := r.BasicAuth() + + return ok && username == hook.username && password == hook.password +} diff --git a/azuredevops/azuredevops_test.go b/azuredevops/azuredevops_test.go index 1ca7591..1a6a9c3 100644 --- a/azuredevops/azuredevops_test.go +++ b/azuredevops/azuredevops_test.go @@ -1,6 +1,9 @@ package azuredevops import ( + "bytes" + "fmt" + "github.com/stretchr/testify/assert" "log" "net/http" "net/http/httptest" @@ -117,3 +120,53 @@ func TestWebhooks(t *testing.T) { }) } } + +func TestParseBasicAuth(t *testing.T) { + const validUser = "validUser" + const validPass = "pass123" + tests := []struct { + name string + webhookUser string + webhookPass string + reqUser string + reqPass string + expectedErr error + }{ + { + name: "valid basic auth", + webhookUser: validUser, + webhookPass: validPass, + reqUser: validUser, + reqPass: validPass, + expectedErr: fmt.Errorf("unknown event "), // no event passed, so this is expected + }, + { + name: "no basic auth provided", + expectedErr: fmt.Errorf("unknown event "), // no event passed, so this is expected + }, + { + name: "invalid basic auth", + webhookUser: validUser, + webhookPass: validPass, + reqUser: "fakeUser", + reqPass: "fakePass", + expectedErr: ErrBasicAuthVerificationFailed, + }, + } + + for _, tt := range tests { + h := Webhook{ + username: tt.webhookUser, + password: tt.webhookPass, + } + body := []byte(`{}`) + r, err := http.NewRequest(http.MethodPost, "", bytes.NewBuffer(body)) + assert.NoError(t, err) + r.SetBasicAuth(tt.reqUser, tt.reqPass) + + p, err := h.Parse(r) + + assert.Equal(t, err, tt.expectedErr) + assert.Nil(t, p) + } +}