Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ComputeAuthTimeout expiry overflow reproducer #261

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tpm2/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,9 @@ const (
TagAttestCertify tpmutil.Tag = 0x8017
TagAttestQuote tpmutil.Tag = 0x8018
TagAttestCreation tpmutil.Tag = 0x801a
TagAuthSecret tpmutil.Tag = 0x8023
TagHashCheck tpmutil.Tag = 0x8024
TagAuthSigned tpmutil.Tag = 0x8025
)

// StartupType instructs the TPM on how to handle its state during Shutdown or
Expand Down Expand Up @@ -470,6 +472,7 @@ const (
CmdSequenceUpdate tpmutil.Command = 0x0000015C
CmdSign tpmutil.Command = 0x0000015D
CmdUnseal tpmutil.Command = 0x0000015E
CmdPolicySigned tpmutil.Command = 0x00000160
CmdContextLoad tpmutil.Command = 0x00000161
CmdContextSave tpmutil.Command = 0x00000162
CmdECDHKeyGen tpmutil.Command = 0x00000163
Expand Down
21 changes: 21 additions & 0 deletions tpm2/structures.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,27 @@ type Signature struct {
ECC *SignatureECC
}

// Encode serializes a Signature structure in TPM wire format.
func (s Signature) Encode() ([]byte, error) {
head, err := tpmutil.Pack(s.Alg)
if err != nil {
return nil, fmt.Errorf("encoding Alg: %v", err)
}
var signature []byte
switch s.Alg {
case AlgRSASSA, AlgRSAPSS:
if signature, err = tpmutil.Pack(s.RSA); err != nil {
return nil, fmt.Errorf("encoding RSA: %v", err)
}
case AlgECDSA:
signature, err = tpmutil.Pack(s.ECC.HashAlg, tpmutil.U16Bytes(s.ECC.R.Bytes()), tpmutil.U16Bytes(s.ECC.S.Bytes()))
if err != nil {
return nil, fmt.Errorf("encoding ECC: %v", err)
}
}
return concat(head, signature)
}

// DecodeSignature decodes a serialized TPMT_SIGNATURE structure.
func DecodeSignature(in *bytes.Buffer) (*Signature, error) {
var sig Signature
Expand Down
263 changes: 256 additions & 7 deletions tpm2/test/tpm2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ import (
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"encoding/binary"
"flag"
"fmt"
"hash"
"io"
"math"
"math/big"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -77,8 +80,32 @@ var (
ExponentRaw: 1<<16 + 1,
},
}
defaultPassword = "\x01\x02\x03\x04"
emptyPassword = ""
defaultPassword = "\x01\x02\x03\x04"
emptyPassword = ""
defaultRsaSignerParams = Public{
Type: AlgRSA,
NameAlg: AlgSHA256,
Attributes: FlagSign | FlagSensitiveDataOrigin | FlagUserWithAuth,
RSAParameters: &RSAParams{
Sign: &SigScheme{
Alg: AlgRSASSA,
Hash: AlgSHA256,
},
KeyBits: 2048,
},
}
defaultEccSignerParams = Public{
Type: AlgECC,
NameAlg: AlgSHA256,
Attributes: FlagSign | FlagSensitiveDataOrigin | FlagUserWithAuth,
ECCParameters: &ECCParams{
Sign: &SigScheme{
Alg: AlgECDSA,
Hash: AlgSHA256,
},
CurveID: CurveNISTP256,
},
}
)

func min(a, b int) int {
Expand Down Expand Up @@ -1182,6 +1209,69 @@ func TestEncodeDecodePublicDefaultRSAExponent(t *testing.T) {
}
}

func TestEncodeDecodeSignature(t *testing.T) {
randRSASig := func() []byte {
// Key size 2048 bits
var size uint16 = 256
sizeU16 := make([]byte, 2)
binary.BigEndian.PutUint16(sizeU16, size)
key := make([]byte, size)
rand.Read(key)
return append(sizeU16, key...)
}

run := func(t *testing.T, s Signature) {
e, err := s.Encode()
if err != nil {
t.Fatalf("Signature{%+v}.Encode() returned error: %v", s, err)
}
d, err := DecodeSignature(bytes.NewBuffer(e))
if err != nil {
t.Fatalf("DecodeSignature{%v} returned error: %v", e, err)
}
if !reflect.DeepEqual(s, *d) {
t.Errorf("got decoded value:\n%v\nwant:\n%v", d, s)
}
}
t.Run("RSASSA", func(t *testing.T) {
run(t, Signature{
Alg: AlgRSASSA,
RSA: &SignatureRSA{
HashAlg: AlgSHA256,
Signature: randRSASig(),
},
})
})
t.Run("RSAPSS", func(t *testing.T) {
run(t, Signature{
Alg: AlgRSAPSS,
RSA: &SignatureRSA{
HashAlg: AlgSHA256,
Signature: randRSASig(),
},
})
})
t.Run("ECDSA", func(t *testing.T) {
// Key size 256 bits
size := 32
randBytes := make([]byte, size)
rand.Read(randBytes)
r := big.NewInt(0).SetBytes(randBytes)

rand.Read(randBytes)
s := big.NewInt(0).SetBytes(randBytes)

run(t, Signature{
Alg: AlgECDSA,
ECC: &SignatureECC{
HashAlg: AlgSHA256,
R: r,
S: s,
},
})
})
}

func TestCreateKeyWithSensitive(t *testing.T) {
rw := openTPM(t)
defer rw.Close()
Expand Down Expand Up @@ -1429,15 +1519,174 @@ func TestPolicySecret(t *testing.T) {
rw := openTPM(t)
defer rw.Close()

sessHandle, _, err := StartAuthSession(rw, HandleNull, HandleNull, make([]byte, 16), nil, SessionPolicy, AlgNull, AlgSHA256)
if err != nil {
t.Fatalf("StartAuthSession() failed: %v", err)
var nullTicket = Ticket{Type: TagAuthSecret, Hierarchy: HandleNull}

expirations := []int32{math.MinInt32, math.MinInt32 + 1, -1, 0, 1, math.MaxInt32}
for _, expiration := range expirations {
t.Run(t.Name()+fmt.Sprint(expiration), func(t *testing.T) {
sessHandle, nonce, err := StartAuthSession(rw, HandleNull, HandleNull, make([]byte, 16), nil, SessionPolicy, AlgNull, AlgSHA256)
if err != nil {
t.Fatalf("StartAuthSession() failed: %v", err)
}
defer FlushContext(rw, sessHandle)

_, tkt := testPolicySecret(t, rw, sessHandle, nonce, expiration)
if expiration < 0 && len(tkt.Digest) == 0 {
t.Fatalf("Got empty ticket digest, expected ticket with auth expiry")
} else if expiration >= 0 && !reflect.DeepEqual(*tkt, nullTicket) {
t.Fatalf("Got ticket with non-empty digest, expected NULL ticket")
}
})
}
defer FlushContext(rw, sessHandle)
}

if _, err := PolicySecret(rw, HandleEndorsement, AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession}, sessHandle, nil, nil, nil, 0); err != nil {
func testPolicySecret(t *testing.T, rw io.ReadWriter, sessHandle tpmutil.Handle, nonce []byte, expiration int32) ([]byte, *Ticket) {

timeout, tkt, err := PolicySecret(rw, HandleEndorsement, AuthCommand{Session: HandlePasswordSession, Attributes: AttrContinueSession}, sessHandle, nonce, nil, nil, expiration)
if err != nil {
t.Fatalf("PolicySecret() failed: %v", err)
}
return timeout, tkt
}

func TestPolicySigned(t *testing.T) {
rw := openTPM(t)
defer rw.Close()

var nullTicket = Ticket{Type: TagAuthSigned, Hierarchy: HandleNull}
signers := map[string]Public{
"RSA": defaultRsaSignerParams,
"ECC": defaultEccSignerParams,
}

expirations := []int32{math.MinInt32, math.MinInt32 + 1, 0, math.MaxInt32}
for _, expiration := range expirations {
for signerType, params := range signers {
t.Run(t.Name()+signerType+fmt.Sprint(expiration), func(t *testing.T) {
sessHandle, nonce, err := StartAuthSession(rw, HandleNull, HandleNull, make([]byte, 16), nil, SessionPolicy, AlgNull, AlgSHA256)
if err != nil {
t.Fatalf("StartAuthSession() failed: %v", err)
}
defer FlushContext(rw, sessHandle)

_, tkt := testPolicySigned(t, rw, sessHandle, nonce, expiration, params)
if expiration < 0 && len(tkt.Digest) == 0 {
t.Fatalf("Got empty ticket digest, expected ticket with auth expiry")
} else if expiration >= 0 && !reflect.DeepEqual(*tkt, nullTicket) {
t.Fatalf("Got ticket with non-empty digest, expected NULL ticket")
}
})
}
}
}

func testPolicySigned(t *testing.T, rw io.ReadWriter, sessHandle tpmutil.Handle, nonce []byte, expiration int32, signerParams Public) ([]byte, *Ticket) {
handle, _, err := CreatePrimary(rw, HandleOwner, PCRSelection{}, emptyPassword, emptyPassword, signerParams)
if err != nil {
t.Fatalf("CreatePrimary() failed: %s", err)
}
defer FlushContext(rw, handle)

// Sign the hash of the command parameters, as described in the TPM 2.0 spec, Part 3, Section 23.3.
// We only use expiration here.
expBytes := make([]byte, 4)
binary.BigEndian.PutUint32(expBytes, uint32(expiration))

// TPM2.0 spec, Revision 1.38, Part 3 nonce must be present if expiration is non-zero.
// aHash ≔ HauthAlg(nonceTPM || expiration || cpHashA || policyRef)
toDigest := append(nonce, expBytes...)

digest := sha256.Sum256(toDigest)

sig, err := Sign(rw, handle, emptyPassword, digest[:], nil, nil)
if err != nil {
t.Fatalf("Sign failed: %s", err)
}

signature, err := sig.Encode()
if err != nil {
t.Fatalf("Encode() failed: %v", err)
}

timeout, tkt, err := PolicySigned(rw, handle, sessHandle, nonce, nil, nil, expiration, signature)
if err != nil {
t.Fatalf("PolicySigned() failed: %v", err)
}
return timeout, tkt
}

func timeoutToUint64(timeout []byte) uint64 {
// The MSFT TPM simulator uses the MSB to indicate
// whether to expire on reset. Strip this out if set.
timeoutUint64 := binary.BigEndian.Uint64(timeout)
expiryBit := uint64(1) << 63
if timeoutUint64&expiryBit != 0 {
timeoutUint64 -= expiryBit
}
return timeoutUint64
}

// authTimeoutWithinExpectedRange expects the policy to not use nonceTpm.
func authTimeoutWithinExpectedRange(expiration int32, timeout []byte) bool {
// https://github.com/microsoft/ms-tpm-20-ref/blob/b94f9f92c579b723a16be72a69efbbf9c35ce44e/TPMCmd/tpm/src/command/EA/Policy_spt.c#L195

absExp := uint64(math.Abs(float64(expiration)))
var authTimeout uint64 = timeoutToUint64(timeout)

absExpInMs := absExp * 1000
if authTimeout < absExpInMs || authTimeout >= absExpInMs+1000 {
return false
}
return true
}

func skipOnUnsupportedRevision(t *testing.T, rw io.ReadWriter, revision uint32) {
props, _, err := GetCapability(rw, CapabilityTPMProperties, 3, uint32(FamilyIndicator))
if err != nil {
t.Fatalf("GetCapability failed: %v", err)
}
if props[2].(TaggedProperty).Value <= 116 {
t.Skipf("Test %v does not support TPM2, Revision %v", t.Name(), revision)
}
}

func TestComputeAuthTimeoutAbsValue(t *testing.T) {
// This test tests the absolue value function in
// https://github.com/microsoft/ms-tpm-20-ref/blob/b94f9f92c579b723a16be72a69efbbf9c35ce44e/TPMCmd/tpm/src/command/EA/Policy_spt.c#L189.
// ComputeAuthTimeout casts expiration to UINT64. This is undefined
// behavior for a negative number, either sign or zero-extended.
// This only invokes UB when the expiration is int32 min, as the
// abs value function is `expiration = -expiration`.
// This tests against revisions > 1.16, as ComputeAuthTimeout shows up in revisions 1.38 and 1.59.

rw := openTPM(t)
defer rw.Close()
skipOnUnsupportedRevision(t, rw, 116)

expirations := []int32{math.MinInt32, math.MinInt32 + 1}
for _, expiration := range expirations {
sessHandle, _, err := StartAuthSession(rw, HandleNull, HandleNull, make([]byte, 16), nil, SessionPolicy, AlgNull, AlgSHA256)
if err != nil {
t.Fatalf("StartAuthSession() failed: %v", err)
}
defer FlushContext(rw, sessHandle)

timeout, _ := testPolicySecret(t, rw, sessHandle, nil, expiration)
if len(timeout) == 0 {
t.Fatal("Expected a non-empty timeout!")
}
if !authTimeoutWithinExpectedRange(expiration, timeout) {
t.Errorf("The timeout %v is not in the expected range for expiration %v!", timeoutToUint64(timeout), expiration)
}

timeout, _ = testPolicySigned(t, rw, sessHandle, nil, expiration, defaultEccSignerParams)
if len(timeout) == 0 {
t.Fatal("Expected a non-empty timeout!")
}
if !authTimeoutWithinExpectedRange(expiration, timeout) {
t.Errorf("The timeout %v is not in the expected range for expiration %v!", timeoutToUint64(timeout), expiration)
}
}
}

func TestQuote(t *testing.T) {
Expand Down
Loading