From d40829db4e235f7dfe5d7f852fc894b9c8fd92f6 Mon Sep 17 00:00:00 2001 From: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com> Date: Thu, 26 Oct 2023 13:10:26 -0500 Subject: [PATCH] Included tests for GitHub attestations - Included tests for GitHub attestations and some simple clean up. Signed-off-by: naveensrinivasan <172697+naveensrinivasan@users.noreply.github.com> --- attestation/github/github.go | 30 ++++-- attestation/github/github_test.go | 154 ++++++++++++++++++++++++++++++ attestation/jwt/jwt.go | 8 +- 3 files changed, 180 insertions(+), 12 deletions(-) create mode 100644 attestation/github/github_test.go diff --git a/attestation/github/github.go b/attestation/github/github.go index ef394210..cf0298f6 100644 --- a/attestation/github/github.go +++ b/attestation/github/github.go @@ -51,18 +51,22 @@ var ( _ attestation.BackReffer = &Attestor{} ) +// init registers the github attestor. func init() { attestation.RegisterAttestation(Name, Type, RunType, func() attestation.Attestor { return New() }) } +// ErrNotGitlab is an error type that indicates the environment is not a github ci job. type ErrNotGitlab struct{} +// Error returns the error message for ErrNotGitlab. func (e ErrNotGitlab) Error() string { return "not in a github ci job" } +// Attestor is a struct that holds the necessary information for github attestation. type Attestor struct { JWT *jwt.Attestor `json:"jwt,omitempty"` CIConfigPath string `json:"ciconfigpath"` @@ -81,6 +85,7 @@ type Attestor struct { aud string } +// New creates and returns a new github attestor. func New() *Attestor { return &Attestor{ aud: tokenAudience, @@ -89,18 +94,22 @@ func New() *Attestor { } } +// Name returns the name of the attestor. func (a *Attestor) Name() string { return Name } +// Type returns the type of the attestor. func (a *Attestor) Type() string { return Type } +// RunType returns the run type of the attestor. func (a *Attestor) RunType() attestation.RunType { return RunType } +// Attest performs the attestation for the github environment. func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { if os.Getenv("GITHUB_ACTIONS") != "true" { return ErrNotGitlab{} @@ -108,7 +117,7 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { jwtString, err := fetchToken(a.tokenURL, os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN"), "witness") if err != nil { - return err + return fmt.Errorf("error on fething token %w", err) } spew.Dump(jwtString) @@ -116,7 +125,7 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { if jwtString != "" { a.JWT = jwt.New(jwt.WithToken(jwtString), jwt.WithJWKSUrl(a.jwksURL)) if err := a.JWT.Attest(ctx); err != nil { - return err + return fmt.Errorf("error on attesting jwt %w", err) } } @@ -134,6 +143,7 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { return nil } +// Subjects returns a map of subjects and their corresponding digest sets. func (a *Attestor) Subjects() map[string]cryptoutil.DigestSet { subjects := make(map[string]cryptoutil.DigestSet) hashes := []crypto.Hash{crypto.SHA256} @@ -152,6 +162,7 @@ func (a *Attestor) Subjects() map[string]cryptoutil.DigestSet { return subjects } +// BackRefs returns a map of back references and their corresponding digest sets. func (a *Attestor) BackRefs() map[string]cryptoutil.DigestSet { backRefs := make(map[string]cryptoutil.DigestSet) for subj, ds := range a.Subjects() { @@ -164,13 +175,14 @@ func (a *Attestor) BackRefs() map[string]cryptoutil.DigestSet { return backRefs } +// fetchToken fetches the token from the given URL. func fetchToken(tokenURL string, bearer string, audience string) (string, error) { client := &http.Client{} - //add audient "&audience=witness" to the end of the tokenURL, parse it, and then add it to the query + // add audience "&audience=witness" to the end of the tokenURL, parse it, and then add it to the query u, err := url.Parse(tokenURL) if err != nil { - return "", err + return "", fmt.Errorf("error on parsing token url %w", err) } q := u.Query() @@ -181,33 +193,35 @@ func fetchToken(tokenURL string, bearer string, audience string) (string, error) req, err := http.NewRequest("GET", reqURL, nil) if err != nil { - return "", err + return "", fmt.Errorf("error on creating request %w", err) } req.Header.Add("Authorization", "bearer "+bearer) resp, err := client.Do(req) if err != nil { - return "", err + return "", fmt.Errorf("error on request %w", err) } defer resp.Body.Close() body, err := readResponseBody(resp.Body) if err != nil { - return "", err + return "", fmt.Errorf("error on reading response body %w", err) } var tokenResponse GithubTokenResponse err = json.Unmarshal(body, &tokenResponse) if err != nil { - return "", err + return "", fmt.Errorf("error on unmarshaling token response %w", err) } return tokenResponse.Value, nil } +// GithubTokenResponse is a struct that holds the response from the github token request. type GithubTokenResponse struct { Count int `json:"count"` Value string `json:"value"` } +// readResponseBody reads the response body and returns it as a byte slice. func readResponseBody(body io.Reader) ([]byte, error) { var buf bytes.Buffer _, err := buf.ReadFrom(body) diff --git a/attestation/github/github_test.go b/attestation/github/github_test.go new file mode 100644 index 00000000..a09ee3c5 --- /dev/null +++ b/attestation/github/github_test.go @@ -0,0 +1,154 @@ +// Copyright 2021 The Witness Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package github + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/testifysec/go-witness/attestation" +) + +func createMockServer() *httptest.Server { + type Response struct { + Count int `json:"count"` + Value string `json:"value"` + } + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/valid" && r.Header.Get("Authorization") == "bearer validBearer" { + resp, _ := json.Marshal(Response{Count: 1, Value: "validJWTToken"}) + _, _ = w.Write(resp) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + })) +} + +func createTokenServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/valid" && r.Header.Get("Authorization") == "bearer validBearer" { + w.Write([]byte(`{"protected": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + "payload": "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ", + "signature": "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"}`)) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + })) +} + +func TestFetchToken(t *testing.T) { + testCases := []struct { + name string + tokenURL string + bearer string + audience string + wantToken string + wantErr bool + }{ + { + name: "valid token", + tokenURL: "/valid", + bearer: "validBearer", + audience: "validAudience", + wantToken: "validJWTToken", + wantErr: false, + }, + { + name: "invalid token url", + tokenURL: "/invalid", + bearer: "validBearer", + audience: "validAudience", + wantToken: "", + wantErr: true, + }, + { + name: "invalid bearer", + tokenURL: "/valid", + bearer: "invalidBearer", + audience: "validAudience", + wantToken: "", + wantErr: true, + }, + { + name: "invalid url", + tokenURL: "invalidURL", + bearer: "validBearer", + audience: "validAudience", + wantToken: "", + wantErr: true, + }, + } + + server := createMockServer() + defer server.Close() + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + gotToken, err := fetchToken(server.URL+testCase.tokenURL, testCase.bearer, testCase.audience) + if testCase.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.wantToken, gotToken) + } + }) + } +} + +func TestAttestorAttest(t *testing.T) { + tokenServer := createTokenServer() + defer tokenServer.Close() + t.Setenv("GITHUB_ACTIONS", "true") + t.Setenv("ACTIONS_ID_TOKEN_REQUEST_URL", tokenServer.URL+"/valid") + t.Setenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN", "validBearer") + + attestor := &Attestor{ + aud: tokenAudience, + jwksURL: tokenServer.URL, + tokenURL: os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL"), + } + + ctx := &attestation.AttestationContext{} + + err := attestor.Attest(ctx) + assert.NoError(t, err) +} + +func TestSubjects(t *testing.T) { + tokenServer := createTokenServer() + defer tokenServer.Close() + attestor := &Attestor{ + aud: "projecturl", + jwksURL: tokenServer.URL, + tokenURL: os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL"), + } + + subjects := attestor.Subjects() + assert.NotNil(t, subjects) + assert.Equal(t, 2, len(subjects)) + + expectedSubjects := []string{"pipelineurl:" + attestor.PipelineUrl, "projecturl:" + attestor.ProjectUrl} + for _, expectedSubject := range expectedSubjects { + _, ok := subjects[expectedSubject] + assert.True(t, ok, "Expected subject not found: %s", expectedSubject) + } + m := attestor.BackRefs() + assert.NotNil(t, m) + assert.Equal(t, 1, len(m)) +} diff --git a/attestation/jwt/jwt.go b/attestation/jwt/jwt.go index 56952543..5229eda7 100644 --- a/attestation/jwt/jwt.go +++ b/attestation/jwt/jwt.go @@ -93,23 +93,23 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { parsed, err := jwt.ParseSigned(a.token) if err != nil { - return err + return fmt.Errorf("error parsing token: %w", err) } resp, err := http.Get(a.jwksUrl) if err != nil { - return err + return fmt.Errorf("error fetching jwks: %w", err) } defer resp.Body.Close() jwks := jose.JSONWebKeySet{} decoder := json.NewDecoder(resp.Body) if err := decoder.Decode(&jwks); err != nil { - return err + return fmt.Errorf("error decoding jwks: %w", err) } if err := parsed.Claims(jwks, &a.Claims); err != nil { - return err + return fmt.Errorf("error parsing claims: %w", err) } keyID := ""