Skip to content

Commit

Permalink
tmputil: introduce non-global encoding logic
Browse files Browse the repository at this point in the history
Right now it's not possible to use the tpm and tpm2 packages
simultaneously. Refactor the encoding logic into separate structs, so
using one doesn't impact the other, while maintaining the top level API.

As a follow up, refactor the tpm and tpm2 packages to use the encoders.
For example instead of:

	tpmutil.Pack(digest)

Use:

	var encoding = tpmutil.Encoding1_2
	encoding.Pack(digest)
  • Loading branch information
ericchiang committed Apr 12, 2019
1 parent 60fe40e commit e9ffeca
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 78 deletions.
2 changes: 1 addition & 1 deletion tpm/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import "github.com/google/go-tpm/tpmutil"

func init() {
// TPM 1.2 spec uses uint32 for length prefix of byte arrays.
tpmutil.UseTPM12LengthPrefixSize()
tpmutil.UseTPM12Encoding()
}

// Supported TPM commands.
Expand Down
2 changes: 1 addition & 1 deletion tpm2/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func init() {
tpmutil.UseTPM20LengthPrefixSize()
tpmutil.UseTPM20Encoding()
}

// MAX_DIGEST_BUFFER is the maximum size of []byte request or response fields.
Expand Down
124 changes: 86 additions & 38 deletions tpmutil/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,49 @@ import (
"reflect"
)

// lengthPrefixSize is the size in bytes of length prefix for byte slices.
//
// In TPM 1.2 this is 4 bytes.
// In TPM 2.0 this is 2 bytes.
var lengthPrefixSize int
// Encoding implements encoding logic for different versions of the TPM
// specification.
type Encoding struct {
// lengthPrefixSize is the size in bytes of length prefix for byte slices.
//
// In TPM 1.2 this is 4 bytes.
// In TPM 2.0 this is 2 bytes.
lengthPrefixSize int
}

var (
// Encoding1_2 implements TPM 1.2 encoding.
Encoding1_2 = &Encoding{
lengthPrefixSize: tpm12PrefixSize,
}
// Encoding2_0 implements TPM 2.0 encoding.
Encoding2_0 = &Encoding{
lengthPrefixSize: tpm20PrefixSize,
}

defaultEncoding *Encoding
)

const (
tpm12PrefixSize = 4
tpm20PrefixSize = 2
)

// UseTPM12LengthPrefixSize makes Pack/Unpack use TPM 1.2 encoding for byte
// arrays.
func UseTPM12LengthPrefixSize() {
lengthPrefixSize = tpm12PrefixSize
// UseTPM12Encoding makes the package level Pack/Unpack functions use
// TPM 1.2 encoding for byte arrays.
func UseTPM12Encoding() {
defaultEncoding = Encoding1_2
}

// UseTPM20LengthPrefixSize makes Pack/Unpack use TPM 2.0 encoding for byte
// arrays.
func UseTPM20LengthPrefixSize() {
lengthPrefixSize = tpm20PrefixSize
// UseTPM20Encoding makes the package level Pack/Unpack functions use
// TPM 2.0 encoding for byte arrays.
func UseTPM20Encoding() {
defaultEncoding = Encoding2_0
}

// packedSize computes the size of a sequence of types that can be passed to
// binary.Read or binary.Write.
func packedSize(elts ...interface{}) (int, error) {
func (enc *Encoding) packedSize(elts ...interface{}) (int, error) {
var size int
for _, e := range elts {
marshaler, ok := e.(SelfMarshaler)
Expand All @@ -59,15 +76,15 @@ func packedSize(elts ...interface{}) (int, error) {
v := reflect.ValueOf(e)
switch v.Kind() {
case reflect.Ptr:
s, err := packedSize(reflect.Indirect(v).Interface())
s, err := enc.packedSize(reflect.Indirect(v).Interface())
if err != nil {
return 0, err
}

size += s
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
s, err := packedSize(v.Field(i).Interface())
s, err := enc.packedSize(v.Field(i).Interface())
if err != nil {
return 0, err
}
Expand All @@ -77,7 +94,7 @@ func packedSize(elts ...interface{}) (int, error) {
case reflect.Slice:
switch s := e.(type) {
case []byte:
size += lengthPrefixSize + len(s)
size += enc.lengthPrefixSize + len(s)
case RawBytes:
size += len(s)
default:
Expand All @@ -100,16 +117,27 @@ func packedSize(elts ...interface{}) (int, error) {
// fixed length or slices of fixed-length types and packs them into a single
// byte array using binary.Write. It updates the CommandHeader to have the right
// length.
func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) {
func (enc *Encoding) packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) {
hdrSize := binary.Size(ch)
bodySize, err := packedSize(cmd...)
bodySize, err := enc.packedSize(cmd...)
if err != nil {
return nil, fmt.Errorf("couldn't compute packed size for message body: %v", err)
}
ch.Size = uint32(hdrSize + bodySize)
in := []interface{}{ch}
in = append(in, cmd...)
return Pack(in...)
return enc.Pack(in...)
}

// Pack encodes a set of elements using the package's default encoding.
//
// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling
// this method.
func Pack(elts ...interface{}) ([]byte, error) {
if defaultEncoding == nil {
return nil, errors.New("default encoding not initialized")
}
return defaultEncoding.Pack(elts...)
}

// Pack encodes a set of elements into a single byte array, using
Expand All @@ -119,13 +147,9 @@ func packWithHeader(ch commandHeader, cmd ...interface{}) ([]byte, error) {
// It has one difference from encoding/binary: it encodes byte slices with a
// prepended length, to match how the TPM encodes variable-length arrays. If
// you wish to add a byte slice without length prefix, use RawBytes.
func Pack(elts ...interface{}) ([]byte, error) {
if lengthPrefixSize == 0 {
return nil, errors.New("lengthPrefixSize must be initialized")
}

func (enc *Encoding) Pack(elts ...interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
if err := packType(buf, elts...); err != nil {
if err := enc.packType(buf, elts...); err != nil {
return nil, err
}

Expand All @@ -137,7 +161,7 @@ func Pack(elts ...interface{}) ([]byte, error) {
// lengthPrefixSize size followed by the bytes. The function unpackType
// performs the inverse operation of unpacking slices stored in this manner and
// using encoding/binary for everything else.
func packType(buf io.Writer, elts ...interface{}) error {
func (enc *Encoding) packType(buf io.Writer, elts ...interface{}) error {
for _, e := range elts {
marshaler, ok := e.(SelfMarshaler)
if ok {
Expand All @@ -149,20 +173,20 @@ func packType(buf io.Writer, elts ...interface{}) error {
v := reflect.ValueOf(e)
switch v.Kind() {
case reflect.Ptr:
if err := packType(buf, reflect.Indirect(v).Interface()); err != nil {
if err := enc.packType(buf, reflect.Indirect(v).Interface()); err != nil {
return err
}
case reflect.Struct:
// TODO(awly): Currently packType cannot handle non-struct fields that implement SelfMarshaler
for i := 0; i < v.NumField(); i++ {
if err := packType(buf, v.Field(i).Interface()); err != nil {
if err := enc.packType(buf, v.Field(i).Interface()); err != nil {
return err
}
}
case reflect.Slice:
switch s := e.(type) {
case []byte:
switch lengthPrefixSize {
switch enc.lengthPrefixSize {
case tpm20PrefixSize:
if err := binary.Write(buf, binary.BigEndian, uint16(len(s))); err != nil {
return err
Expand All @@ -172,7 +196,7 @@ func packType(buf io.Writer, elts ...interface{}) error {
return err
}
default:
return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", lengthPrefixSize)
return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", enc.lengthPrefixSize)
}
if err := binary.Write(buf, binary.BigEndian, s); err != nil {
return err
Expand All @@ -195,21 +219,45 @@ func packType(buf io.Writer, elts ...interface{}) error {
return nil
}

// Unpack is a convenience wrapper around UnpackBuf using the package's default
// encoding.
//
// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling
// this method.
func Unpack(b []byte, elts ...interface{}) (int, error) {
if defaultEncoding == nil {
return 0, errors.New("default encoding not initialized")
}
return defaultEncoding.Unpack(b, elts...)
}

// Unpack is a convenience wrapper around UnpackBuf. Unpack returns the number
// of bytes read from b to fill elts and error, if any.
func Unpack(b []byte, elts ...interface{}) (int, error) {
func (enc *Encoding) Unpack(b []byte, elts ...interface{}) (int, error) {
buf := bytes.NewBuffer(b)
err := UnpackBuf(buf, elts...)
err := enc.UnpackBuf(buf, elts...)
read := len(b) - buf.Len()
return read, err
}

// UnpackBuf recursively unpacks types from a reader using the package's default
// encoding.
//
// Callers must call UseTPM12Encoding() or UseTPM20Encoding() before calling
// this method.
func UnpackBuf(buf io.Reader, elts ...interface{}) error {
if defaultEncoding == nil {
return errors.New("default encoding not initialized")
}
return defaultEncoding.UnpackBuf(buf, elts...)
}

// UnpackBuf recursively unpacks types from a reader just as encoding/binary
// does under binary.BigEndian, but with one difference: it unpacks a byte
// slice by first reading an integer with lengthPrefixSize bytes, then reading
// that many bytes. It assumes that incoming values are pointers to values so
// that, e.g., underlying slices can be resized as needed.
func UnpackBuf(buf io.Reader, elts ...interface{}) error {
func (enc *Encoding) UnpackBuf(buf io.Reader, elts ...interface{}) error {
for _, e := range elts {
v := reflect.ValueOf(e)
k := v.Kind()
Expand All @@ -233,7 +281,7 @@ func UnpackBuf(buf io.Reader, elts ...interface{}) error {
case reflect.Struct:
// Decompose the struct and copy over the values.
for i := 0; i < iv.NumField(); i++ {
if err := UnpackBuf(buf, iv.Field(i).Addr().Interface()); err != nil {
if err := enc.UnpackBuf(buf, iv.Field(i).Addr().Interface()); err != nil {
return err
}
}
Expand All @@ -250,21 +298,21 @@ func UnpackBuf(buf io.Reader, elts ...interface{}) error {
}
size = int(tmpSize)
// TPM 2.0
case lengthPrefixSize == tpm20PrefixSize:
case enc.lengthPrefixSize == tpm20PrefixSize:
var tmpSize uint16
if err := binary.Read(buf, binary.BigEndian, &tmpSize); err != nil {
return err
}
size = int(tmpSize)
// TPM 1.2
case lengthPrefixSize == tpm12PrefixSize:
case enc.lengthPrefixSize == tpm12PrefixSize:
var tmpSize uint32
if err := binary.Read(buf, binary.BigEndian, &tmpSize); err != nil {
return err
}
size = int(tmpSize)
default:
return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", lengthPrefixSize)
return fmt.Errorf("lengthPrefixSize is %d, must be either 2 or 4", enc.lengthPrefixSize)
}

// A zero size is used by the TPM to signal that certain elements
Expand Down
Loading

0 comments on commit e9ffeca

Please sign in to comment.