Skip to content

Commit

Permalink
read/write encrypted filesystem (#24)
Browse files Browse the repository at this point in the history
* read/write encrypted filesystem
  • Loading branch information
hgarvison authored Sep 20, 2023
1 parent dc34755 commit ad3230d
Show file tree
Hide file tree
Showing 31 changed files with 1,166 additions and 452 deletions.
69 changes: 53 additions & 16 deletions cmd/attestation-container/attestation-container.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"encoding/base64"
"flag"
"log"
"net"
"os"
"path/filepath"
Expand All @@ -19,12 +18,16 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/sirupsen/logrus"
)

var (
socketAddress = flag.String("socket-address", "/tmp/attestation-container.sock", "The socket address of Unix domain socket (UDS)")
platformCertificateServer = flag.String("platform-certificate-server", "", "Server to fetch platform certificate. If set, certificates contained in security context directory are ignored. Value is either 'Azure' or 'AMD'")
insecureVirtual = flag.Bool("insecure-virtual", false, "If set, dummy attestation is returned (INSECURE: do not use in production)")
logLevel = flag.String("loglevel", "warning", "Logging Level: trace, debug, info, warning, error, fatal, panic.")
logFile = flag.String("logfile", "", "Logging Target: An optional file name/path. Omit for console output.")

platformCertificateValue *common.THIMCerts = nil
// UVM Endorsement (UVM reference info)
Expand All @@ -48,10 +51,11 @@ func (s *server) FetchAttestation(ctx context.Context, in *pb.FetchAttestationRe
}
copy(reportData[:], in.GetReportData())
if *insecureVirtual {
log.Println("Serving virtual attestation report")
logrus.Trace("Serving virtual attestation report")
return &pb.FetchAttestationReply{}, nil
}

logrus.Trace("Fetching attestation report...")
reportFetcher, err := attest.NewAttestationReportFetcher()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get attestation report fetcher: %s", err)
Expand All @@ -62,24 +66,30 @@ func (s *server) FetchAttestation(ctx context.Context, in *pb.FetchAttestationRe
return nil, status.Errorf(codes.Internal, "failed to fetch attestation report: %s", err)
}

logrus.Trace("Setting platform certificate...")
var platformCertificate []byte
if platformCertificateValue == nil {
logrus.Trace("Deserializing attestation report...")
var SNPReport attest.SNPAttestationReport
if err = SNPReport.DeserializeReport(reportBytes); err != nil {
return nil, status.Errorf(codes.Internal, "failed to deserialize attestation report: %s", err)
}
var certFetcher attest.CertFetcher
if *platformCertificateServer == "AMD" {
logrus.Trace("Setting AMD Certificate Fetcher...")
certFetcher = attest.DefaultAMDMilanCertFetcherNew()
} else {
// Use "Azure". The value of platformCertificateServer should be already checked.
logrus.Trace("Setting Azure Certificate Fetcher...")
certFetcher = attest.DefaultAzureCertFetcherNew()
}
logrus.Trace("Fetching platform certificate...")
platformCertificate, _, err = certFetcher.GetCertChain(SNPReport.ChipID, SNPReport.ReportedTCB)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to fetch platform certificate: %s", err)
}
} else {
logrus.Trace("Using platform certificate from UVM info...")
platformCertificate = append(platformCertificate, platformCertificateValue.VcekCert...)
platformCertificate = append(platformCertificate, platformCertificateValue.CertificateChain...)
}
Expand All @@ -89,66 +99,93 @@ func (s *server) FetchAttestation(ctx context.Context, in *pb.FetchAttestationRe

func validateFlags() {
if *platformCertificateServer != "" && *platformCertificateServer != "AMD" && *platformCertificateServer != "Azure" {
log.Fatalf("invalid --platform-certificate-server value %s (valid values: 'AMD', 'Azure')", *platformCertificateServer)
logrus.Fatalf("invalid --platform-certificate-server value %s (valid values: 'AMD', 'Azure')", *platformCertificateServer)
}
}

func main() {
flag.Parse()

// if logFile is not set, logrus defaults to stderr
if *logFile != "" {
// If the file doesn't exist, create it. If it exists, append to it.
file, err := os.OpenFile(*logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
logrus.Fatal(err)
}
defer file.Close()
logrus.SetOutput(file)
}

level, err := logrus.ParseLevel(*logLevel)
if err != nil {
logrus.Fatal(err)
}
logrus.SetLevel(level)
logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: false, DisableQuote: true, DisableTimestamp: true})

validateFlags()

log.Println("Attestation container started.")
logrus.Info("Attestation container started...")

if *insecureVirtual {
log.Printf("Warning: INSECURE virtual: do not use in production!")
logrus.Warn("Warning: INSECURE virtual: do not use in production!")
} else {
logrus.Trace("Checking if SNP device is detected...")
if attest.IsSNPVM5() {
log.Printf("%s is detected\n", attest.SNP_DEVICE_PATH_5)
logrus.Tracef("%s is detected\n", attest.SNP_DEVICE_PATH_5)
} else if attest.IsSNPVM6() {
log.Printf("%s is detected\n", attest.SNP_DEVICE_PATH_6)
logrus.Tracef("%s is detected\n", attest.SNP_DEVICE_PATH_6)
} else {
log.Fatalf("attestation-container is not running in SNP enabled VM")
logrus.Fatalf("attestation-container is not running in SNP enabled VM")
}

logrus.Trace("Getting UVM Information...")
uvmInfo, err := common.GetUvmInformation()
if err != nil {
log.Fatalf("Failed to get UVM information: %s", err)
logrus.Fatalf("Failed to get UVM information: %s", err)
}

logrus.Trace("Setting platform certificate server...")
if *platformCertificateServer == "" {
platformCertificateValue = &uvmInfo.InitialCerts
} else {
log.Printf("Platform certificates will be retrieved from server %s", *platformCertificateServer)
logrus.Tracef("Platform certificates will be retrieved from server %s", *platformCertificateServer)
}

logrus.Trace("Decoding UVM reference info...")
uvmEndorsementValue, err = base64.StdEncoding.DecodeString(uvmInfo.EncodedUvmReferenceInfo)
if err != nil {
log.Fatalf("Failed to decode base64 string: %s", err)
logrus.Fatalf("Failed to decode base64 string: %s", err)
}
}

// Cleanup
if _, err := os.Stat(*socketAddress); err == nil {
if err := os.RemoveAll(*socketAddress); err != nil {
log.Fatalf("Failed to clean up socket: %s", err)
logrus.Fatalf("Failed to clean up socket: %s", err)
} else {
logrus.Infof("Cleaned existing socket %s", *socketAddress)
}
} else {
logrus.Debugf("Failed to stat socket %s", *socketAddress)
}

// Create parent directory for socketAddress
socketDir := filepath.Dir(*socketAddress)
// os.MkdirAll doesn't return error when the directory already exists
if err := os.MkdirAll(socketDir, os.ModePerm); err != nil {
log.Fatalf("Failed to create directory for Unix domain socket: %s", err)
logrus.Fatalf("Failed to create directory for Unix domain socket: %s", err)
}

lis, err := net.Listen("unix", *socketAddress)
if err != nil {
log.Fatalf("failed to listen: %v", err)
logrus.Fatalf("Failed to listen: %v", err)
}
s := grpc.NewServer()
pb.RegisterAttestationContainerServer(s, &server{})
log.Printf("Server listening at %v", lis.Addr())
logrus.Infof("Server listening at %v", lis.Addr())
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
logrus.Fatalf("Failed to serve: %v", err)
}
}
75 changes: 58 additions & 17 deletions cmd/azmount/filemanager/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,42 @@ func tokenRefresher(credential azblob.TokenCredential) (t time.Duration) {
currentToken := credential.Token()
// JWT tokens comprise three fields. the second field is the payload (or claims).
// we care about the `aud` attribute of the payload
curentTokenFields := strings.Split(currentToken, ".")
payload, _ := base64.RawURLEncoding.DecodeString(curentTokenFields[1])
currentTokenFields := strings.Split(currentToken, ".")
logrus.Debugf("Current token fields: %v", currentTokenFields)

payload, err := base64.RawURLEncoding.DecodeString(currentTokenFields[1])
if err != nil {
logrus.Errorf("Error decoding base64 token payload: %s", err)
return 0
}
logrus.Debugf("Current token payload: %s", string(payload))

var payloadMap map[string]interface{}
json.Unmarshal([]byte(payload), &payloadMap)
err = json.Unmarshal([]byte(payload), &payloadMap)
if err != nil {
logrus.Errorf("Error unmarshalling token payload: %s", err)
return 0
}
audience := payloadMap["aud"].(string)

identity := common.Identity{
ClientId: payloadMap["appid"].(string),
}

// retrieve token using the existing's token audience
// retrieve token using the existing token audience
logrus.Debugf("Retrieving new token for audience %s and identity %s", audience, identity)
refreshToken, err := common.GetToken(audience, identity)

if err != nil {
logrus.Errorf("Error retrieving token: %s", err)
return 0
}
logrus.Debugf("Retrieved new token: %s", refreshToken.AccessToken)

// Duration expects nanosecond count
ExpiresInSeconds, err := strconv.ParseInt(refreshToken.ExpiresIn, 10, 64)
if err != nil {
logrus.Errorf("Error parsing token expiration to seconds: %s", err)
return 0
}
credential.SetToken(refreshToken.AccessToken)
Expand All @@ -67,74 +84,98 @@ func AzureSetup(urlString string, urlPrivate bool, identity common.Identity) err
// deserialization of HTTP response payloads, and more:
//
// https://pkg.go.dev/github.com/Azure/azure-storage-blob-go/azblob#hdr-URL_Types
logrus.Infof("Connecting to Azure...")
logrus.Info("Connecting to Azure...")
u, err := url.Parse(urlString)
if err != nil {
return errors.Wrapf(err, "can't parse URL string")
return errors.Wrapf(err, "Can't parse URL string %s", urlString)
}

if urlPrivate {
// we use token credentials to access private azure blob storage the blob's
// url Host denotes the scope/audience for which we need to get a token
logrus.Infof("Using token credentials")
logrus.Trace("Using token credentials to access private azure blob storage...")

var token common.TokenResponse
count := 0
logrus.Debugf("Getting token for https://%s", u.Host)
for {
token, err = common.GetToken("https://"+u.Host, identity)

if err != nil {
logrus.Infof("can't obtain a token required for accessing private blobs. will retry in case the ACI identity sidecar is not running yet.")
logrus.Info("Can't obtain a token required for accessing private blobs. Will retry in case the ACI identity sidecar is not running yet...")
time.Sleep(3 * time.Second)
count++
if count == 20 {
return errors.Wrapf(err, "timeout of 60 seconds expired. could not obtained token")
return errors.Wrapf(err, "Timeout of 60 seconds expired. Could not obtain token")
}
} else {
logrus.Infof("token obtained: %s continuing", token.AccessToken)
logrus.Debugf("Token obtained: %s", token.AccessToken)
break
}
}

tokenCredential := azblob.NewTokenCredential(token.AccessToken, tokenRefresher)
logrus.Debugf("Token credential created: %s", tokenCredential.Token())
fm.blobURL = azblob.NewPageBlobURL(*u, azblob.NewPipeline(tokenCredential, azblob.PipelineOptions{}))
logrus.Debugf("Blob URL created: %s", fm.blobURL)
} else {
// we can use anonymous credentials to access public azure blob storage
logrus.Infof("Using anonymous credentials")
logrus.Trace("Using anonymous credentials to access public azure blob storage...")

anonCredential := azblob.NewAnonymousCredential()
logrus.Debugf("Anonymous credential created: %s", anonCredential)
fm.blobURL = azblob.NewPageBlobURL(*u, azblob.NewPipeline(anonCredential, azblob.PipelineOptions{}))
logrus.Debugf("Blob URL created: %s", fm.blobURL)
}

// Use a never-expiring context
fm.ctx = context.Background()
logrus.Infof("Getting size of file...")

logrus.Trace("Getting size of file...")
// Get file size
getMetadata, err := fm.blobURL.GetProperties(fm.ctx, azblob.BlobAccessConditions{},
azblob.ClientProvidedKeyOptions{})
if err != nil {
return errors.Wrapf(err, "can't get size")
return errors.Wrapf(err, "Can't get blob file size")
}
fm.contentLength = getMetadata.ContentLength()
logrus.Infof("Size: %d bytes", fm.contentLength)
logrus.Tracef("Blob Size: %d bytes", fm.contentLength)

// Setup data downloader
// Setup data downloader and uploader
fm.downloadBlock = AzureDownloadBlock
fm.uploadBlock = AzureUploadBlock

return nil
}

func AzureUploadBlock(blockIndex int64, b []byte) (err error) {
logrus.Info("Uploading block...")
bytesInBlock := GetBlockSize()
var offset int64 = blockIndex * bytesInBlock
logrus.Tracef("Block offset %d = block index %d * bytes in block %d", offset, blockIndex, bytesInBlock)

r := bytes.NewReader(b)
_, err = fm.blobURL.UploadPages(fm.ctx, offset, r, azblob.PageBlobAccessConditions{},
nil, azblob.NewClientProvidedKeyOptions(nil, nil, nil))
if err != nil {
return errors.Wrapf(err, "Can't upload block")
}

return nil
}

func AzureDownloadBlock(blockIndex int64) (err error, b []byte) {
logrus.Info("Downloading block...")
bytesInBlock := GetBlockSize()
var offset int64 = blockIndex * bytesInBlock
logrus.Tracef("Block offset %d = block index %d * bytes in block %d", offset, blockIndex, bytesInBlock)
var count int64 = bytesInBlock

get, err := fm.blobURL.Download(fm.ctx, offset, count, azblob.BlobAccessConditions{},
false, azblob.ClientProvidedKeyOptions{})
if err != nil {
var empty []byte
return errors.Wrapf(err, "can't download block"), empty
return errors.Wrapf(err, "Can't download block"), empty
}

blobData := &bytes.Buffer{}
Expand All @@ -145,7 +186,7 @@ func AzureDownloadBlock(blockIndex int64) (err error, b []byte) {

if err != nil {
var empty []byte
return errors.Wrapf(err, "ReadFrom() failed"), empty
return errors.Wrapf(err, "ReadFrom() failed for block"), empty
}

return nil, blobData.Bytes()
Expand Down
Loading

0 comments on commit ad3230d

Please sign in to comment.