diff --git a/bitbucket/webhook.go b/bitbucket/webhook.go index 85021ec..6e11ef8 100644 --- a/bitbucket/webhook.go +++ b/bitbucket/webhook.go @@ -19,15 +19,29 @@ const ( const maxPayloadSize = 10 * 1024 * 1024 // 10 MiB -func ParsePayload(r *http.Request, key []byte) (interface{}, error) { - p, err := validateSignature(r, key) +func ParsePayload(r *http.Request, key []byte) (interface{}, []byte, error) { + event, payload, err := ParsePayloadWithoutSignature(r) if err != nil { - return nil, err + return nil, nil, err + } + + err = ValidateSignature(r, payload, key) + if err != nil { + return nil, nil, err + } + + return event, payload, nil +} + +func ParsePayloadWithoutSignature(r *http.Request) (interface{}, []byte, error) { + payload, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) + if err != nil { + return nil, nil, fmt.Errorf("unable to parse payload: %w", err) } evk := r.Header.Get(EventKeyHeader) if evk == "" { - return nil, fmt.Errorf("unable find event key in request") + return nil, nil, fmt.Errorf("unable find event key in request") } k := EventKey(evk) var event interface{} @@ -37,48 +51,43 @@ func ParsePayload(r *http.Request, key []byte) (interface{}, error) { case EventKeyPullRequestOpened, EventKeyPullRequestFrom, EventkeyPullRequestModified, EventKeyPullRequestDeclined, EventKeyPullRequestDeleted, EventKeyPullRequestMerged: event = &PullRequestEvent{} default: - return nil, fmt.Errorf("event type not supported: %s", k) + return nil, nil, fmt.Errorf("event type not supported: %s", k) } - err = json.Unmarshal(p, event) + err = json.Unmarshal(payload, event) if err != nil { - return nil, fmt.Errorf("unable to parse event payload: %w", err) + return nil, nil, fmt.Errorf("unable to parse event payload: %w", err) } - return event, nil + return event, payload, nil } -func validateSignature(r *http.Request, key []byte) ([]byte, error) { +func ValidateSignature(r *http.Request, payload []byte, key []byte) error { sig := r.Header.Get(EventSignatureHeader) if sig == "" { - return nil, fmt.Errorf("no signature found") - } - - payload, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) - if err != nil { - return nil, fmt.Errorf("unable to parse payload: %w", err) + return fmt.Errorf("no signature found") } sp := strings.Split(sig, "=") if len(sp) != 2 { - return nil, fmt.Errorf("signatur format invalid") + return fmt.Errorf("signatur format invalid") } if sp[0] != "sha256" { - return nil, fmt.Errorf("unsupported hash algorithm: %s", sp[0]) + return fmt.Errorf("unsupported hash algorithm: %s", sp[0]) } sd, err := hex.DecodeString(sp[1]) if err != nil { - return nil, fmt.Errorf("unable to parse signature data: %w", err) + return fmt.Errorf("unable to parse signature data: %w", err) } h := hmac.New(sha256.New, key) h.Write([]byte(payload)) if !hmac.Equal(h.Sum(nil), sd) { - return nil, fmt.Errorf("signature does not match") + return fmt.Errorf("signature does not match") } - return payload, nil + return nil } diff --git a/bitbucket/webhook_test.go b/bitbucket/webhook_test.go index cd47523..9e1bbc5 100644 --- a/bitbucket/webhook_test.go +++ b/bitbucket/webhook_test.go @@ -17,11 +17,28 @@ func TestWebhook(t *testing.T) { req.Header.Add(EventKeyHeader, "repo:refs_changed") assert.NoError(t, err) - ev, err := ParsePayload(req, secretKey) + ev, payload, err := ParsePayload(req, secretKey) assert.NoError(t, err) - assert.NotNil(t, ev) + assert.NotNil(t, payload) - repoEv, ok := ev.(*RepositoryPushEvent) - assert.True(t, ok) - assert.Equal(t, "rep_1", repoEv.Repository.Slug) + if assert.NotNil(t, ev) { + repoEv, ok := ev.(*RepositoryPushEvent) + assert.True(t, ok) + assert.Equal(t, "rep_1", repoEv.Repository.Slug) + } +} + +func TestWebhookFailingSignature(t *testing.T) { + secretKey := []byte("abcdef0123456789") + + buf := bytes.NewBufferString(repoPushEvent01) + req, err := http.NewRequest("POST", "http://server.io/webhook", buf) + req.Header.Add(EventSignatureHeader, "sha256=d82c0422a140fc24335536d9450538aeaa978dbc741262a161ee12b99a6bf05d") + req.Header.Add(EventKeyHeader, "repo:refs_changed") + assert.NoError(t, err) + + ev, payload, err := ParsePayload(req, secretKey) + assert.Error(t, err) + assert.Nil(t, ev) + assert.Nil(t, payload) }