Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tdx: support extracting host data from quote & fill in extensions #830

Merged
merged 5 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions coordinator/internal/authority/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@ package authority

import (
"context"
"encoding/asn1"
"errors"
"fmt"
"log/slog"
"net"
"time"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation"
"github.com/edgelesssys/contrast/internal/attestation/certcache"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/attestation/tdx"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/memstore"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -90,7 +89,7 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
}

for _, opt := range opts {
validator := snp.NewValidatorWithCallbacks(opt.VerifyOpts, opt.ValidateOpts, allowedHostDataEntries,
validator := snp.NewValidatorWithReportSetter(opt.VerifyOpts, opt.ValidateOpts, allowedHostDataEntries,
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}),
&authInfo)
validators = append(validators, validator)
Expand All @@ -101,8 +100,8 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A
return nil, nil, fmt.Errorf("generating TDX validation options: %w", err)
}
for _, opt := range tdxOpts {
validators = append(validators, tdx.NewValidator(&tdx.StaticValidateOptsGenerator{Opts: opt},
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "tdx"})))
validators = append(validators, tdx.NewValidatorWithReportSetter(&tdx.StaticValidateOptsGenerator{Opts: opt},
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "tdx"}), &authInfo))
}

serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, validators, c.attestationFailuresCounter)
Expand Down Expand Up @@ -153,11 +152,10 @@ type AuthInfo struct {
// State is the coordinator state at the time of the TLS handshake.
State *State
// Report is the attestation report sent by the peer.
Report *sevsnp.Report
Report attestation.Report
}

// ValidateCallback takes the validated report and attaches it to the [AuthInfo].
func (a *AuthInfo) ValidateCallback(_ context.Context, report *sevsnp.Report, _ asn1.ObjectIdentifier, _, _, _ []byte) error {
// SetReport takes the validated report and attaches it to the [AuthInfo].
func (a *AuthInfo) SetReport(report attestation.Report) {
a.Report = report
return nil
}
5 changes: 2 additions & 3 deletions coordinator/meshapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/edgelesssys/contrast/coordinator/internal/authority"
"github.com/edgelesssys/contrast/coordinator/internal/seedengine"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/meshapi"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
Expand Down Expand Up @@ -106,7 +105,7 @@ func (i *meshAPIServer) NewMeshCert(ctx context.Context, _ *meshapi.NewMeshCertR
return nil, fmt.Errorf("could not marshal public key: %w", err)
}

hostData := manifest.NewHexString(report.HostData)
hostData := manifest.NewHexString(report.HostData())
entry, ok := state.Manifest.Policies[hostData]
if !ok {
return nil, fmt.Errorf("report data %s not found in manifest", hostData)
Expand All @@ -118,7 +117,7 @@ func (i *meshAPIServer) NewMeshCert(ctx context.Context, _ *meshapi.NewMeshCertR
return nil, fmt.Errorf("failed to parse peer public key: %w", err)
}

extensions, err := snp.ClaimsToCertExtension(report)
extensions, err := report.ClaimsToCertExtension()
if err != nil {
return nil, fmt.Errorf("failed to construct extensions: %w", err)
}
Expand Down
11 changes: 3 additions & 8 deletions internal/atls/atls.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
)

const attestationTimeout = 30 * time.Second

var (
// NoValidators skips validation of the server's attestation document.
NoValidators = []Validator{}
Expand Down Expand Up @@ -94,7 +92,7 @@ type Issuer interface {
// Validator is able to validate an attestation document.
type Validator interface {
Getter
Validate(ctx context.Context, attDoc []byte, nonce []byte, peerPublicKey []byte) error
Validate(attDoc []byte, nonce []byte, peerPublicKey []byte) error
}

// getATLSConfigForClientFunc returns a config setup function that is called once for every client connecting to the server.
Expand Down Expand Up @@ -242,10 +240,7 @@ func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, peerPu
// We've found a matching validator. Let's validate the document.
foundMatchingValidator = true

ctx, cancel := context.WithTimeout(context.Background(), attestationTimeout)
defer cancel()

validationErr := validator.Validate(ctx, ex.Value, nonce, peerPublicKey)
validationErr := validator.Validate(ex.Value, nonce, peerPublicKey)
if validationErr == nil {
// The validator has successfully verified the document. We can exit.
return nil
Expand Down Expand Up @@ -429,7 +424,7 @@ func NewFakeValidators(oid Getter) []Validator {
}

// Validate unmarshals the attestation document and verifies the nonce.
func (v FakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte, _ []byte) error {
func (v FakeValidator) Validate(attDoc []byte, nonce []byte, _ []byte) error {
var doc FakeAttestationDoc
if err := json.Unmarshal(attDoc, &doc); err != nil {
return err
Expand Down
19 changes: 19 additions & 0 deletions internal/attestation/callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright 2024 Edgeless Systems GmbH
// SPDX-License-Identifier: AGPL-3.0-only

package attestation

import (
"crypto/x509/pkix"
)

// Report is a verified and validates TEE attestation report.
type Report interface {
HostData() []byte
ClaimsToCertExtension() ([]pkix.Extension, error)
}

// ReportSetter is called by a validator after it verified and validated an attestation report.
type ReportSetter interface {
SetReport(report Report)
}
87 changes: 87 additions & 0 deletions internal/attestation/extension/extension.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2024 Edgeless Systems GmbH
// SPDX-License-Identifier: AGPL-3.0-only

package extension

import (
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"

"golang.org/x/exp/constraints"
)

// NewBigIntExtension returns a new extension containing an unsigned integer value.
func NewBigIntExtension[T constraints.Unsigned](oid asn1.ObjectIdentifier, value T) Extension {
bigInt := &big.Int{}
bigInt.SetUint64(uint64(value))
return bigIntExtension{OID: oid, Val: bigInt}
}

type bigIntExtension struct {
OID asn1.ObjectIdentifier
Val *big.Int
}

func (b bigIntExtension) toExtension() (pkix.Extension, error) {
bytes, err := asn1.Marshal(b.Val)
if err != nil {
return pkix.Extension{}, fmt.Errorf("marshaling big int: %w", err)
}
return pkix.Extension{Id: b.OID, Value: bytes}, nil
}

// NewBytesExtension returns a new extension containing bytes.
func NewBytesExtension(oid asn1.ObjectIdentifier, val []byte) Extension {
return bytesExtension{OID: oid, Val: val}
}

type bytesExtension struct {
OID asn1.ObjectIdentifier
Val []byte
}

func (b bytesExtension) toExtension() (pkix.Extension, error) {
bytes, err := asn1.Marshal(b.Val)
if err != nil {
return pkix.Extension{}, fmt.Errorf("marshaling bytes: %w", err)
}
return pkix.Extension{Id: b.OID, Value: bytes}, nil
}

// NewBoolExtension returns a new extension containing a boolean value.
func NewBoolExtension(oid asn1.ObjectIdentifier, val bool) Extension {
return boolExtension{OID: oid, Val: val}
}

type boolExtension struct {
OID asn1.ObjectIdentifier
Val bool
}

func (b boolExtension) toExtension() (pkix.Extension, error) {
bytes, err := asn1.Marshal(b.Val)
if err != nil {
return pkix.Extension{}, fmt.Errorf("marshaling bool: %w", err)
}
return pkix.Extension{Id: b.OID, Value: bytes}, nil
}

// Extension is a yet-to-be-marshalled pkix extension.
type Extension interface {
toExtension() (pkix.Extension, error)
}

// ConvertExtensions converts the extensions into pkix extensions.
func ConvertExtensions(extensions []Extension) ([]pkix.Extension, error) {
var exts []pkix.Extension
for _, extension := range extensions {
ext, err := extension.toExtension()
if err != nil {
return nil, fmt.Errorf("converting extension to pkix: %w", err)
}
exts = append(exts, ext)
}
return exts, nil
}
Loading