diff --git a/attestation/github/github.go b/attestation/github/github.go index c768e3b7..46bce445 100644 --- a/attestation/github/github.go +++ b/attestation/github/github.go @@ -25,7 +25,6 @@ import ( "os" "strings" - "github.com/davecgh/go-spew/spew" "github.com/in-toto/go-witness/attestation" "github.com/in-toto/go-witness/attestation/jwt" "github.com/in-toto/go-witness/cryptoutil" @@ -51,18 +50,22 @@ var ( _ attestation.BackReffer = &Attestor{} ) +// init registers the github attestor. func init() { attestation.RegisterAttestation(Name, Type, RunType, func() attestation.Attestor { return New() }) } +// ErrNotGitlab is an error type that indicates the environment is not a github ci job. type ErrNotGitlab struct{} +// Error returns the error message for ErrNotGitlab. func (e ErrNotGitlab) Error() string { return "not in a github ci job" } +// Attestor is a struct that holds the necessary information for github attestation. type Attestor struct { JWT *jwt.Attestor `json:"jwt,omitempty"` CIConfigPath string `json:"ciconfigpath"` @@ -81,6 +84,7 @@ type Attestor struct { aud string } +// New creates and returns a new github attestor. func New() *Attestor { return &Attestor{ aud: tokenAudience, @@ -89,18 +93,22 @@ func New() *Attestor { } } +// Name returns the name of the attestor. func (a *Attestor) Name() string { return Name } +// Type returns the type of the attestor. func (a *Attestor) Type() string { return Type } +// RunType returns the run type of the attestor. func (a *Attestor) RunType() attestation.RunType { return RunType } +// Attest performs the attestation for the github environment. func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { if os.Getenv("GITHUB_ACTIONS") != "true" { return ErrNotGitlab{} @@ -108,16 +116,16 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { jwtString, err := fetchToken(a.tokenURL, os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN"), "witness") if err != nil { - return err + return fmt.Errorf("error on fetching token %w", err) } - spew.Dump(jwtString) + if jwtString == "" { + return fmt.Errorf("empty JWT string") + } - if jwtString != "" { - a.JWT = jwt.New(jwt.WithToken(jwtString), jwt.WithJWKSUrl(a.jwksURL)) - if err := a.JWT.Attest(ctx); err != nil { - return err - } + a.JWT = jwt.New(jwt.WithToken(jwtString), jwt.WithJWKSUrl(a.jwksURL)) + if err := a.JWT.Attest(ctx); err != nil { + return fmt.Errorf("failed to attest github jwt: %w", err) } a.CIServerUrl = os.Getenv("GITHUB_SERVER_URL") @@ -134,6 +142,7 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { return nil } +// Subjects returns a map of subjects and their corresponding digest sets. func (a *Attestor) Subjects() map[string]cryptoutil.DigestSet { subjects := make(map[string]cryptoutil.DigestSet) hashes := []crypto.Hash{crypto.SHA256} @@ -152,6 +161,7 @@ func (a *Attestor) Subjects() map[string]cryptoutil.DigestSet { return subjects } +// BackRefs returns a map of back references and their corresponding digest sets. func (a *Attestor) BackRefs() map[string]cryptoutil.DigestSet { backRefs := make(map[string]cryptoutil.DigestSet) for subj, ds := range a.Subjects() { @@ -164,13 +174,14 @@ func (a *Attestor) BackRefs() map[string]cryptoutil.DigestSet { return backRefs } +// fetchToken fetches the token from the given URL. func fetchToken(tokenURL string, bearer string, audience string) (string, error) { client := &http.Client{} - //add audient "&audience=witness" to the end of the tokenURL, parse it, and then add it to the query + // add audience "&audience=witness" to the end of the tokenURL, parse it, and then add it to the query u, err := url.Parse(tokenURL) if err != nil { - return "", err + return "", fmt.Errorf("error on parsing token url %w", err) } q := u.Query() @@ -181,33 +192,35 @@ func fetchToken(tokenURL string, bearer string, audience string) (string, error) req, err := http.NewRequest("GET", reqURL, nil) if err != nil { - return "", err + return "", fmt.Errorf("error on creating request %w", err) } req.Header.Add("Authorization", "bearer "+bearer) resp, err := client.Do(req) if err != nil { - return "", err + return "", fmt.Errorf("error on request %w", err) } defer resp.Body.Close() body, err := readResponseBody(resp.Body) if err != nil { - return "", err + return "", fmt.Errorf("error on reading response body %w", err) } var tokenResponse GithubTokenResponse err = json.Unmarshal(body, &tokenResponse) if err != nil { - return "", err + return "", fmt.Errorf("error on unmarshaling token response %w", err) } return tokenResponse.Value, nil } +// GithubTokenResponse is a struct that holds the response from the github token request. type GithubTokenResponse struct { Count int `json:"count"` Value string `json:"value"` } +// readResponseBody reads the response body and returns it as a byte slice. func readResponseBody(body io.Reader) ([]byte, error) { var buf bytes.Buffer _, err := buf.ReadFrom(body) diff --git a/attestation/github/github_test.go b/attestation/github/github_test.go new file mode 100644 index 00000000..d2ac0e8f --- /dev/null +++ b/attestation/github/github_test.go @@ -0,0 +1,122 @@ +// Copyright 2021 The Witness Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package github + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func createMockServer() *httptest.Server { + type Response struct { + Count int `json:"count"` + Value string `json:"value"` + } + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/valid" && r.Header.Get("Authorization") == "bearer validBearer" { + resp, _ := json.Marshal(Response{Count: 1, Value: "validJWTToken"}) + _, _ = w.Write(resp) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + })) +} + +func TestFetchToken(t *testing.T) { + testCases := []struct { + name string + tokenURL string + bearer string + audience string + wantToken string + wantErr bool + }{ + { + name: "valid token", + tokenURL: "/valid", + bearer: "validBearer", + audience: "validAudience", + wantToken: "validJWTToken", + wantErr: false, + }, + { + name: "invalid token url", + tokenURL: "/invalid", + bearer: "validBearer", + audience: "validAudience", + wantToken: "", + wantErr: true, + }, + { + name: "invalid bearer", + tokenURL: "/valid", + bearer: "invalidBearer", + audience: "validAudience", + wantToken: "", + wantErr: true, + }, + { + name: "invalid url", + tokenURL: "invalidURL", + bearer: "validBearer", + audience: "validAudience", + wantToken: "", + wantErr: true, + }, + } + + server := createMockServer() + defer server.Close() + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + gotToken, err := fetchToken(server.URL+testCase.tokenURL, testCase.bearer, testCase.audience) + if testCase.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.wantToken, gotToken) + } + }) + } +} + +func TestSubjects(t *testing.T) { + tokenServer := createMockServer() + defer tokenServer.Close() + attestor := &Attestor{ + aud: "projecturl", + jwksURL: tokenServer.URL, + tokenURL: os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL"), + } + + subjects := attestor.Subjects() + assert.NotNil(t, subjects) + assert.Equal(t, 2, len(subjects)) + + expectedSubjects := []string{"pipelineurl:" + attestor.PipelineUrl, "projecturl:" + attestor.ProjectUrl} + for _, expectedSubject := range expectedSubjects { + _, ok := subjects[expectedSubject] + assert.True(t, ok, "Expected subject not found: %s", expectedSubject) + } + m := attestor.BackRefs() + assert.NotNil(t, m) + assert.Equal(t, 1, len(m)) +} diff --git a/attestation/jwt/jwt.go b/attestation/jwt/jwt.go index 1eb2ca12..02986308 100644 --- a/attestation/jwt/jwt.go +++ b/attestation/jwt/jwt.go @@ -93,23 +93,23 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error { parsed, err := jwt.ParseSigned(a.token) if err != nil { - return err + return fmt.Errorf("error parsing token: %w", err) } resp, err := http.Get(a.jwksUrl) if err != nil { - return err + return fmt.Errorf("error fetching jwks: %w", err) } defer resp.Body.Close() jwks := jose.JSONWebKeySet{} decoder := json.NewDecoder(resp.Body) if err := decoder.Decode(&jwks); err != nil { - return err + return fmt.Errorf("error decoding jwks: %w", err) } if err := parsed.Claims(jwks, &a.Claims); err != nil { - return err + return fmt.Errorf("error parsing claims: %w", err) } keyID := ""