Skip to content

Commit

Permalink
oidc/oidctest: add new package
Browse files Browse the repository at this point in the history
  • Loading branch information
ericchiang committed Nov 29, 2023
1 parent 6d6be43 commit 38beddf
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 0 deletions.
213 changes: 213 additions & 0 deletions oidc/oidctest/oidctest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
// Package oidctest implements a test OpenID Connect server.
//
// For convinence, methods in this package panic rather than returning errors.
// This package is NOT suitable for use outside of tests.
//
// This package is primarily intended to be used with the standard library's
// [net/http/httpttest] package. Users should configure a key pair and setup
// a server:
//
// priv, err := rsa.GenerateKey(rand.Reader, 2048)
// if err != nil {
// // ...
// }
// s := &oidctest.Server{
// PublicKeys: []oidctest.PublicKey{
// {
// PublicKey: priv.Public(),
// KeyID: "my-key-id",
// Algorithm: oidc.ES256,
// },
// },
// }
// srv := httptest.NewServer(s)
// defer srv.Close()
// s.SetIssuer(srv.URL)
//
// Then sign a token:
//
// rawClaims := `{
// "iss": "` + srv.URL + `",
// "aud": "my-client-id",
// "sub": "foo",
// "email": "[email protected]",
// "email_verified": true
// }`
// token := oidctest.SignIDToken(priv, "my-key-id", oidc.RS256, rawClaims)
//
// And finaly, verify through the oidc package:
//
// ctx := context.Background()
// p, err := oidc.NewProvider(ctx, srv.URL)
// if err != nil {
// // ...
// }
// config := &oidc.Config{
// ClientID: "my-client-id",
// SkipExpiryCheck: true,
// }
// v := p.VerifierContext(ctx, config)
//
// idToken, err := v.Verify(ctx, token)
// if err != nil {
// // ...
// }
package oidctest

import (
"crypto"
"encoding/json"
"fmt"
"net/http"
"net/url"

jose "github.com/go-jose/go-jose/v3"
)

// SignIDToken uses a private key to sign provided claims.
//
// A minimal set of claims may look like:
//
// rawClaims := `{
// "iss": "` + srv.URL + `",
// "aud": "my-client-id",
// "sub": "foo",
// "exp": ` + strconv.FormatInt(time.Now().Add(time.Hour).Unix(), 10) + `,
// "email": "[email protected]",
// "email_verified": true
// }`
// token := oidctest.SignIDToken(priv, "my-key-id", oidc.RS256, rawClaims)
func SignIDToken(priv crypto.PrivateKey, keyID, alg, claims string) string {
token, err := newToken(priv, keyID, alg, claims)
if err != nil {
panic("oidctest: generating token: " + err.Error())
}
return token
}

func newToken(priv crypto.PrivateKey, keyID, alg, claims string) (string, error) {
key := jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(alg),
Key: priv,
}
opts := &jose.SignerOptions{
ExtraHeaders: map[jose.HeaderKey]interface{}{
jose.HeaderKey("kid"): keyID,
},
}

signer, err := jose.NewSigner(key, opts)
if err != nil {
return "", fmt.Errorf("creating signer: %v", err)
}
sig, err := signer.Sign([]byte(claims))
if err != nil {
return "", fmt.Errorf("signing payload: %v", err)
}
jwt, err := sig.CompactSerialize()
if err != nil {
return "", fmt.Errorf("serializing jwt: %v", err)
}
return jwt, nil
}

// PublicKey holds a public key as well as additional metadata about that
// key.
type PublicKey struct {
// Either *rsa.PublicKey or *ecdsa.PublicKey.
PublicKey crypto.PublicKey
// The ID of the key. Should match the value passed to [SignIDToken].
KeyID string
// Signature algorithm used by the public key, such as "RS256" or "RS256".
Algorithm string
}

// Server holds configuration for the OpenID Connect test server.
type Server struct {
// Public keys advertised by the server that can be used to sign tokens.
PublicKeys []PublicKey
// The set of signing algorithms used by the server. Defaults to "RS256".
Algorithms []string

issuerURL *url.URL
}

// SetIssuer must be called before serving traffic. This is usually the
// [httptest.Server.URL].
func (s *Server) SetIssuer(issuerURL string) {
u, err := url.Parse(issuerURL)
if err != nil {
panic("oidctest: invalid issuer URL: " + err.Error())
}
s.issuerURL = u
}

// ServeHTTP is the server's implementation of [http.Handler].
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
if r.Method != http.MethodGet {
http.Error(w, "Expected GET request for discovery endpoint, got: "+r.Method,
http.StatusMethodNotAllowed)
return
}
s.serveDiscovery(w, r)
case "/keys":
if r.Method != http.MethodGet {
http.Error(w, "Expected GET request for keys endpoint, got: "+r.Method,
http.StatusMethodNotAllowed)
return
}
s.serveKeys(w, r)
default:
http.Error(w, "Unknown path: "+r.URL.Path, http.StatusNotFound)
}
}

func (s *Server) serveDiscovery(w http.ResponseWriter, r *http.Request) {
if s.issuerURL == nil {
http.Error(w, "oidctest: server called without SetIssuer()", http.StatusInternalServerError)
return
}

algs := s.Algorithms
if len(algs) == 0 {
algs = []string{"RS256"}
}
disc := struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
JWKs string `json:"jwks_uri"`
ResponseTypes []string `json:"response_types_supported"`
SubjectTypes []string `json:"subject_types_supported"`
Algs []string `json:"id_token_signing_alg_values_supported"`
}{
Issuer: s.issuerURL.String(),
Auth: s.issuerURL.JoinPath("/auth").String(),
Token: s.issuerURL.JoinPath("/token").String(),
JWKs: s.issuerURL.JoinPath("/keys").String(),
ResponseTypes: []string{"code", "id_token", "token id_token"},
SubjectTypes: []string{"public"},
Algs: algs,
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(disc)
}

func (s *Server) serveKeys(w http.ResponseWriter, r *http.Request) {
set := &jose.JSONWebKeySet{}
for _, pub := range s.PublicKeys {
k := jose.JSONWebKey{
Key: pub.PublicKey,
KeyID: pub.KeyID,
Algorithm: pub.Algorithm,
Use: "sig",
}
set.Keys = append(set.Keys, k)
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(set)
}
78 changes: 78 additions & 0 deletions oidc/oidctest/oidctest_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package oidctest_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"net/http/httptest"
"strconv"
"testing"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/coreos/go-oidc/v3/oidc/oidctest"
)

func TestServer(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("creating server: %v", err)
}

s := &oidctest.Server{
PublicKeys: []oidctest.PublicKey{
{
PublicKey: priv.Public(),
KeyID: "my-key-id",
Algorithm: oidc.RS256,
},
},
}
srv := httptest.NewServer(s)
defer srv.Close()
s.SetIssuer(srv.URL)

now := time.Now()
rawClaims := `{
"iss": "` + srv.URL + `",
"aud": "my-client-id",
"sub": "foo",
"exp": ` + strconv.FormatInt(now.Add(time.Hour).Unix(), 10) + `,
"email": "[email protected]",
"email_verified": true
}`
token := oidctest.SignIDToken(priv, "my-key-id", oidc.RS256, rawClaims)

ctx := context.Background()
p, err := oidc.NewProvider(ctx, srv.URL)
if err != nil {
t.Fatalf("new provider: %v", err)
}
config := &oidc.Config{
ClientID: "my-client-id",
Now: func() time.Time { return now },
}
v := p.VerifierContext(ctx, config)

idToken, err := v.Verify(ctx, token)
if err != nil {
t.Fatalf("verifying token: %v", err)
}
if want := "foo"; idToken.Subject != want {
t.Errorf("ID token returned unexpected subject, got=%v, want=%v", idToken.Subject, want)
}

var claims struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
t.Fatalf("parsing id token claims: %v", err)
}
if want := "[email protected]"; claims.Email != want {
t.Errorf("ID token returned unexpected email, got=%v, want=%v", claims.Email, want)
}
if want := true; claims.EmailVerified != want {
t.Errorf("ID token returned unexpected email_verified, got=%v, want=%v", claims.EmailVerified, want)
}
}

0 comments on commit 38beddf

Please sign in to comment.