From e9ffeca74ee3648a40bb30a4a2dbc7936e82be45 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Thu, 11 Apr 2019 17:08:18 -0700 Subject: [PATCH] tmputil: introduce non-global encoding logic Right now it's not possible to use the tpm and tpm2 packages simultaneously. Refactor the encoding logic into separate structs, so using one doesn't impact the other, while maintaining the top level API. As a follow up, refactor the tpm and tpm2 packages to use the encoders. For example instead of: tpmutil.Pack(digest) Use: var encoding = tpmutil.Encoding1_2 encoding.Pack(digest) --- tpm/constants.go | 2 +- tpm2/constants.go | 2 +- tpmutil/encoding.go | 124 +++++++++++++++++++++++++++------------ tpmutil/encoding_test.go | 60 +++++++++---------- tpmutil/run.go | 19 ++++-- 5 files changed, 129 insertions(+), 78 deletions(-) diff --git a/tpm/constants.go b/tpm/constants.go index 8dc8e4cf..c568987f 100644 --- a/tpm/constants.go +++ b/tpm/constants.go @@ -18,7 +18,7 @@ import "github.com/google/go-tpm/tpmutil" func init() { // TPM 1.2 spec uses uint32 for length prefix of byte arrays. - tpmutil.UseTPM12LengthPrefixSize() + tpmutil.UseTPM12Encoding() } // Supported TPM commands. diff --git a/tpm2/constants.go b/tpm2/constants.go index ea954e7a..0b835463 100644 --- a/tpm2/constants.go +++ b/tpm2/constants.go @@ -26,7 +26,7 @@ import ( ) func init() { - tpmutil.UseTPM20LengthPrefixSize() + tpmutil.UseTPM20Encoding() } // MAX_DIGEST_BUFFER is the maximum size of []byte request or response fields. diff --git a/tpmutil/encoding.go b/tpmutil/encoding.go index 9457777c..e05dd360 100644 --- a/tpmutil/encoding.go +++ b/tpmutil/encoding.go @@ -23,32 +23,49 @@ import ( "reflect" ) -// lengthPrefixSize is the size in bytes of length prefix for byte slices. -// -// In TPM 1.2 this is 4 bytes. -// In TPM 2.0 this is 2 bytes. -var lengthPrefixSize int +// Encoding implements encoding logic for different versions of the TPM +// specification. +type Encoding struct { + // lengthPrefixSize is the size in bytes of length prefix for byte slices. + // + // In TPM 1.2 this is 4 bytes. + // In TPM 2.0 this is 2 bytes. + lengthPrefixSize int +} + +var ( + // Encoding1_2 implements TPM 1.2 encoding. + Encoding1_2 = &Encoding{ + lengthPrefixSize: tpm12PrefixSize, + } + // Encoding2_0 implements TPM 2.0 encoding. + Encoding2_0 = &Encoding{ + lengthPrefixSize: tpm20PrefixSize, + } + + defaultEncoding *Encoding +) const ( tpm12PrefixSize = 4 tpm20PrefixSize = 2 ) -// UseTPM12LengthPrefixSize makes Pack/Unpack use TPM 1.2 encoding for byte -// arrays. -func UseTPM12LengthPrefixSize() { - lengthPrefixSize = tpm12PrefixSize +// UseTPM12Encoding makes the package level Pack/Unpack functions use +// TPM 1.2 encoding for byte arrays. +func UseTPM12Encoding() { + defaultEncoding = Encoding1_2 } -// UseTPM20LengthPrefixSize makes Pack/Unpack use TPM 2.0 encoding for byte -// arrays. -func UseTPM20LengthPrefixSize() { - lengthPrefixSize = tpm20PrefixSize +// UseTPM20Encoding makes the package level Pack/Unpack functions use +// TPM 2.0 encoding for byte arrays. +func UseTPM20Encoding() { + defaultEncoding = Encoding2_0 } // packedSize computes the size of a sequence of types that can be passed to // binary.Read or binary.Write. -func packedSize(elts ...interface{}) (int, error) { +func (enc *Encoding) packedSize(elts ...interface{}) (int, error) { var size int for _, e := range elts { marshaler, ok := e.(SelfMarshaler) @@ -59,7 +76,7 @@ func packedSize(elts ...interface{}) (int, error) { v := reflect.ValueOf(e) switch v.Kind() { case reflect.Ptr: - s, err := packedSize(reflect.Indirect(v).Interface()) + s, err := enc.packedSize(reflect.Indirect(v).Interface()) if err != nil { return 0, err } @@ -67,7 +84,7 @@ func packedSize(elts ...interface{}) (int, error) { size += s case reflect.Struct: for i := 0; i < v.NumField(); i++ { - s, err := packedSize(v.Field(i).Interface()) + s, err := enc.packedSize(v.Field(i).Interface()) if err != nil { return 0, err } @@ -77,7 +94,7 @@ func packedSize(elts ...interface{}) (int, error) { case reflect.Slice: switch s := e.(type) { case []byte: - size += lengthPrefixSize + len(s) + size += enc.lengthPrefixSize + len(s) case RawBytes: size += len(s) default: @@ -100,16 +117,27 @@ func packedSize(elts ...interface{}) (int, error) { // fixed length or slices of fixed-length types and packs them into a single // byte array using binary.Write. It updates the CommandHeader to have the right // length. -func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) { +func (enc *Encoding) packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) { hdrSize := binary.Size(ch) - bodySize, err := packedSize(cmd...) + bodySize, err := enc.packedSize(cmd...) if err != nil { return nil, fmt.Errorf("couldn't compute packed size for message body: %v", err) } ch.Size = uint32(hdrSize + bodySize) in := []interface{}{ch} in = append(in, cmd...) - return Pack(in...) + return enc.Pack(in...) +} + +// Pack encodes a set of elements using the package's default encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func Pack(elts ...interface{}) ([]byte, error) { + if defaultEncoding == nil { + return nil, errors.New("default encoding not initialized") + } + return defaultEncoding.Pack(elts...) } // Pack encodes a set of elements into a single byte array, using @@ -119,13 +147,9 @@ func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) { // It has one difference from encoding/binary: it encodes byte slices with a // prepended length, to match how the TPM encodes variable-length arrays. If // you wish to add a byte slice without length prefix, use RawBytes. -func Pack(elts ...interface{}) ([]byte, error) { - if lengthPrefixSize == 0 { - return nil, errors.New("lengthPrefixSize must be initialized") - } - +func (enc *Encoding) Pack(elts ...interface{}) ([]byte, error) { buf := new(bytes.Buffer) - if err := packType(buf, elts...); err != nil { + if err := enc.packType(buf, elts...); err != nil { return nil, err } @@ -137,7 +161,7 @@ func Pack(elts ...interface{}) ([]byte, error) { // lengthPrefixSize size followed by the bytes. The function unpackType // performs the inverse operation of unpacking slices stored in this manner and // using encoding/binary for everything else. -func packType(buf io.Writer, elts ...interface{}) error { +func (enc *Encoding) packType(buf io.Writer, elts ...interface{}) error { for _, e := range elts { marshaler, ok := e.(SelfMarshaler) if ok { @@ -149,20 +173,20 @@ func packType(buf io.Writer, elts ...interface{}) error { v := reflect.ValueOf(e) switch v.Kind() { case reflect.Ptr: - if err := packType(buf, reflect.Indirect(v).Interface()); err != nil { + if err := enc.packType(buf, reflect.Indirect(v).Interface()); err != nil { return err } case reflect.Struct: // TODO(awly): Currently packType cannot handle non-struct fields that implement SelfMarshaler for i := 0; i < v.NumField(); i++ { - if err := packType(buf, v.Field(i).Interface()); err != nil { + if err := enc.packType(buf, v.Field(i).Interface()); err != nil { return err } } case reflect.Slice: switch s := e.(type) { case []byte: - switch lengthPrefixSize { + switch enc.lengthPrefixSize { case tpm20PrefixSize: if err := binary.Write(buf, binary.BigEndian, uint16(len(s))); err != nil { return err @@ -172,7 +196,7 @@ func packType(buf io.Writer, elts ...interface{}) error { return err } default: - return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", lengthPrefixSize) + return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", enc.lengthPrefixSize) } if err := binary.Write(buf, binary.BigEndian, s); err != nil { return err @@ -195,21 +219,45 @@ func packType(buf io.Writer, elts ...interface{}) error { return nil } +// Unpack is a convenience wrapper around UnpackBuf using the package's default +// encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func Unpack(b []byte, elts ...interface{}) (int, error) { + if defaultEncoding == nil { + return 0, errors.New("default encoding not initialized") + } + return defaultEncoding.Unpack(b, elts...) +} + // Unpack is a convenience wrapper around UnpackBuf. Unpack returns the number // of bytes read from b to fill elts and error, if any. -func Unpack(b []byte, elts ...interface{}) (int, error) { +func (enc *Encoding) Unpack(b []byte, elts ...interface{}) (int, error) { buf := bytes.NewBuffer(b) - err := UnpackBuf(buf, elts...) + err := enc.UnpackBuf(buf, elts...) read := len(b) - buf.Len() return read, err } +// UnpackBuf recursively unpacks types from a reader using the package's default +// encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func UnpackBuf(buf io.Reader, elts ...interface{}) error { + if defaultEncoding == nil { + return errors.New("default encoding not initialized") + } + return defaultEncoding.UnpackBuf(buf, elts...) +} + // UnpackBuf recursively unpacks types from a reader just as encoding/binary // does under binary.BigEndian, but with one difference: it unpacks a byte // slice by first reading an integer with lengthPrefixSize bytes, then reading // that many bytes. It assumes that incoming values are pointers to values so // that, e.g., underlying slices can be resized as needed. -func UnpackBuf(buf io.Reader, elts ...interface{}) error { +func (enc *Encoding) UnpackBuf(buf io.Reader, elts ...interface{}) error { for _, e := range elts { v := reflect.ValueOf(e) k := v.Kind() @@ -233,7 +281,7 @@ func UnpackBuf(buf io.Reader, elts ...interface{}) error { case reflect.Struct: // Decompose the struct and copy over the values. for i := 0; i < iv.NumField(); i++ { - if err := UnpackBuf(buf, iv.Field(i).Addr().Interface()); err != nil { + if err := enc.UnpackBuf(buf, iv.Field(i).Addr().Interface()); err != nil { return err } } @@ -250,21 +298,21 @@ func UnpackBuf(buf io.Reader, elts ...interface{}) error { } size = int(tmpSize) // TPM 2.0 - case lengthPrefixSize == tpm20PrefixSize: + case enc.lengthPrefixSize == tpm20PrefixSize: var tmpSize uint16 if err := binary.Read(buf, binary.BigEndian, &tmpSize); err != nil { return err } size = int(tmpSize) // TPM 1.2 - case lengthPrefixSize == tpm12PrefixSize: + case enc.lengthPrefixSize == tpm12PrefixSize: var tmpSize uint32 if err := binary.Read(buf, binary.BigEndian, &tmpSize); err != nil { return err } size = int(tmpSize) default: - return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", lengthPrefixSize) + return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", enc.lengthPrefixSize) } // A zero size is used by the TPM to signal that certain elements diff --git a/tpmutil/encoding_test.go b/tpmutil/encoding_test.go index 9efe3155..106db04b 100644 --- a/tpmutil/encoding_test.go +++ b/tpmutil/encoding_test.go @@ -22,10 +22,6 @@ import ( "testing" ) -func init() { - UseTPM12LengthPrefixSize() -} - type invalidPacked struct { A []int B uint32 @@ -61,7 +57,7 @@ func testEncodingInvalidSlices(t *testing.T, f func(io.Writer, interface{}) erro func TestEncodingPackedSizeInvalid(t *testing.T) { f := func(w io.Writer, i interface{}) error { - _, err := packedSize(i) + _, err := Encoding1_2.packedSize(i) return err } @@ -70,7 +66,7 @@ func TestEncodingPackedSizeInvalid(t *testing.T) { func TestEncodingPackTypeInvalid(t *testing.T) { f := func(w io.Writer, i interface{}) error { - return packType(w, i) + return Encoding1_2.packType(w, i) } testEncodingInvalidSlices(t, f) @@ -106,7 +102,7 @@ func TestEncodingPackedSize(t *testing.T) { {[]byte(nil), 4}, } for _, tt := range tests { - if s, err := packedSize(tt.in); err != nil || s != tt.want { + if s, err := Encoding1_2.packedSize(tt.in); err != nil || s != tt.want { t.Errorf("packedSize(%#v): %d, want %d", tt.in, s, tt.want) } } @@ -125,7 +121,7 @@ func TestEncodingPackType(t *testing.T) { RawBytes(buf), } for _, i := range inputs { - if err := packType(ioutil.Discard, i); err != nil { + if err := Encoding1_2.packType(ioutil.Discard, i); err != nil { t.Errorf("packType(%#v): %v", i, err) } } @@ -140,7 +136,7 @@ func TestEncodingPackTypeWriteFail(t *testing.T) { {3, []byte(nil)}, } for _, tt := range tests { - if err := packType(&limitedDiscard{tt.limit}, tt.in); err == nil { + if err := Encoding1_2.packType(&limitedDiscard{tt.limit}, tt.in); err == nil { t.Errorf("packType(%#v) with write size limit %d returned nil, want error", tt.in, tt.limit) } } @@ -167,7 +163,7 @@ func (l *limitedDiscard) Write(p []byte) (n int, err error) { func TestEncodingCommandHeaderInvalidBody(t *testing.T) { var invalid []int ch := commandHeader{1, 0, 2} - _, err := packWithHeader(ch, invalid) + _, err := Encoding1_2.packWithHeader(ch, invalid) if err == nil { t.Fatal("packWithHeader incorrectly packed a body that with an invalid int slice member") } @@ -176,12 +172,12 @@ func TestEncodingCommandHeaderInvalidBody(t *testing.T) { func TestEncodingInvalidPack(t *testing.T) { var invalid []int ch := commandHeader{1, 0, 2} - _, err := packWithHeader(ch, invalid) + _, err := Encoding1_2.packWithHeader(ch, invalid) if err == nil { t.Fatal("packWithHeader incorrectly packed a body that with an invalid int slice member") } - _, err = Pack(invalid) + _, err = Encoding1_2.Pack(invalid) if err == nil { t.Fatal("pack incorrectly packed a slice of int") } @@ -192,14 +188,14 @@ func TestEncodingCommandHeaderEncoding(t *testing.T) { var c uint32 = 137 in := c - b, err := packWithHeader(ch, in) + b, err := Encoding1_2.packWithHeader(ch, in) if err != nil { t.Fatal("Couldn't pack the bytes:", err) } var hdr commandHeader var size uint32 - if _, err := Unpack(b, &hdr, &size); err != nil { + if _, err := Encoding1_2.Unpack(b, &hdr, &size); err != nil { t.Fatal("Couldn't unpack the packed bytes") } @@ -214,19 +210,19 @@ func TestEncodingInvalidUnpack(t *testing.T) { // The value ui is a serialization of uint32(0). ui := []byte{0, 0, 0, 0} uiBuf := bytes.NewBuffer(ui) - if err := UnpackBuf(uiBuf, i); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf, i); err == nil { t.Fatal("UnpackBuf incorrectly deserialized into a nil pointer") } var ii uint32 - if err := UnpackBuf(uiBuf, ii); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf, ii); err == nil { t.Fatal("UnpackBuf incorrectly deserialized into a non pointer") } var b []byte var empty []byte emptyBuf := bytes.NewBuffer(empty) - if err := UnpackBuf(emptyBuf, &b); err == nil { + if err := Encoding1_2.UnpackBuf(emptyBuf, &b); err == nil { t.Fatal("UnpackBuf incorrectly deserialized an empty byte array into a byte slice") } @@ -234,14 +230,14 @@ func TestEncodingInvalidUnpack(t *testing.T) { // The slice ui represents uint32(1), which is the length of an empty byte array. ui2 := []byte{0, 0, 0, 1} uiBuf2 := bytes.NewBuffer(ui2) - if err := UnpackBuf(uiBuf2, &b); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf2, &b); err == nil { t.Fatal("UnpackBuf incorrectly deserialized a byte array that didn't have enough bytes available") } var iii []int ui3 := []byte{0, 0, 0, 1} uiBuf3 := bytes.NewBuffer(ui3) - if err := UnpackBuf(uiBuf3, &iii); err == nil { + if err := Encoding1_2.UnpackBuf(uiBuf3, &iii); err == nil { t.Fatal("UnpackBuf incorrectly deserialized into a slice of ints (only byte slices are supported)") } @@ -253,14 +249,14 @@ func TestEncodingUnpack(t *testing.T) { // The slice ui represents uint32(0), which is the length of an empty byte array. ui := []byte{0, 0, 0, 0} uiBuf := bytes.NewBuffer(ui) - if err := UnpackBuf(uiBuf, &b); err != nil { + if err := Encoding1_2.UnpackBuf(uiBuf, &b); err != nil { t.Fatal("UnpackBuf failed to unpack the empty byte array") } // A byte slice of length 1 with a single entry: b[0] == 137 ui2 := []byte{0, 0, 0, 1, 137} uiBuf2 := bytes.NewBuffer(ui2) - if err := UnpackBuf(uiBuf2, &b); err != nil { + if err := Encoding1_2.UnpackBuf(uiBuf2, &b); err != nil { t.Fatal("UnpackBuf failed to unpack a byte array with a single value in it") } @@ -269,12 +265,12 @@ func TestEncodingUnpack(t *testing.T) { } sp := simplePacked{137, 138} - bsp, err := Pack(sp) + bsp, err := Encoding1_2.Pack(sp) if err != nil { t.Fatal("Couldn't pack a simple struct:", err) } var sp2 simplePacked - if _, err := Unpack(bsp, &sp2); err != nil { + if _, err := Encoding1_2.Unpack(bsp, &sp2); err != nil { t.Fatal("Couldn't unpack a simple struct:", err) } @@ -283,17 +279,17 @@ func TestEncodingUnpack(t *testing.T) { } // Try unpacking a version that's missing a byte at the end. - if _, err := Unpack(bsp[:len(bsp)-1], &sp2); err == nil { + if _, err := Encoding1_2.Unpack(bsp[:len(bsp)-1], &sp2); err == nil { t.Fatal("unpack incorrectly unpacked from a byte array that didn't have enough values") } np := nestedPacked{sp, 139} - bnp, err := Pack(np) + bnp, err := Encoding1_2.Pack(np) if err != nil { t.Fatal("Couldn't pack a nested struct") } var np2 nestedPacked - if _, err := Unpack(bnp, &np2); err != nil { + if _, err := Encoding1_2.Unpack(bnp, &np2); err != nil { t.Fatal("Couldn't unpack a nested struct:", err) } if np.SP.A != np2.SP.A || np.SP.B != np2.SP.B || np.C != np2.C { @@ -301,12 +297,12 @@ func TestEncodingUnpack(t *testing.T) { } ns := nestedSlice{137, b} - bns, err := Pack(ns) + bns, err := Encoding1_2.Pack(ns) if err != nil { t.Fatal("Couldn't pack a struct with a nested byte slice:", err) } var ns2 nestedSlice - if _, err := Unpack(bns, &ns2); err != nil { + if _, err := Encoding1_2.Unpack(bns, &ns2); err != nil { t.Fatal("Couldn't unpacked a struct with a nested slice:", err) } if ns.A != ns2.A || !bytes.Equal(ns.S, ns2.S) { @@ -314,7 +310,7 @@ func TestEncodingUnpack(t *testing.T) { } var hs []Handle - if _, err := Unpack([]byte{0, 3, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, &hs); err != nil { + if _, err := Encoding1_2.Unpack([]byte{0, 3, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, &hs); err != nil { t.Fatal("Couldn't unpack a list of Handles:", err) } if want := []Handle{0x01020304, 0x05060708, 0x090a0b0c}; !reflect.DeepEqual(want, hs) { @@ -324,20 +320,20 @@ func TestEncodingUnpack(t *testing.T) { func TestPartialUnpack(t *testing.T) { u1, u2 := uint32(1), uint32(2) - buf, err := Pack(u1, u2) + buf, err := Encoding1_2.Pack(u1, u2) if err != nil { t.Fatalf("packing uint32 value: %v", err) } var gu1, gu2 uint32 - read1, err := Unpack(buf, &gu1) + read1, err := Encoding1_2.Unpack(buf, &gu1) if err != nil { t.Fatalf("unpacking first uint32 value: %v", err) } if gu1 != u1 { t.Errorf("first unpacked value: got %d, want %d", gu1, u1) } - read2, err := Unpack(buf[read1:], &gu2) + read2, err := Encoding1_2.Unpack(buf[read1:], &gu2) if err != nil { t.Fatalf("unpacking second uint32 value: %v", err) } diff --git a/tpmutil/run.go b/tpmutil/run.go index 641febef..4eeb8c31 100644 --- a/tpmutil/run.go +++ b/tpmutil/run.go @@ -13,10 +13,6 @@ // limitations under the License. // Package tpmutil provides common utility functions for both TPM 1.2 and TPM 2.0 devices. -// -// Users should call either UseTPM12LengthPrefixSize or -// UseTPM20LengthPrefixSize before using this package, depending on their type -// of TPM device. package tpmutil import ( @@ -31,16 +27,27 @@ import ( // returning a header and a body in separate responses. const maxTPMResponse = 4096 +// RunCommand executes cmd with the package's default encoding. +// +// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling +// this method. +func RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) { + if defaultEncoding == nil { + return nil, 0, errors.New("default encoding not initialized") + } + return defaultEncoding.RunCommand(rw, tag, cmd, in...) +} + // RunCommand executes cmd with given tag and arguments. Returns TPM response // body (without response header) and response code from the header. Returned // error may be nil if response code is not RCSuccess; caller should check // both. -func RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) { +func (enc *Encoding) RunCommand(rw io.ReadWriter, tag Tag, cmd Command, in ...interface{}) ([]byte, ResponseCode, error) { if rw == nil { return nil, 0, errors.New("nil TPM handle") } ch := commandHeader{tag, 0, cmd} - inb, err := packWithHeader(ch, in...) + inb, err := enc.packWithHeader(ch, in...) if err != nil { return nil, 0, err }