diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 1a05c919..7ef309d6 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -4,7 +4,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: [ main, Proton ] jobs: @@ -44,4 +44,4 @@ jobs: run: go test -v ./... -run RandomizeFast -count=512 - name: Randomized test suite 2 - run: go test -v ./... -run RandomizeSlow -count=32 \ No newline at end of file + run: go test -v ./... -run RandomizeSlow -count=32 diff --git a/openpgp/ecdh/ecdh.go b/openpgp/ecdh/ecdh.go index db8fb163..85a06b17 100644 --- a/openpgp/ecdh/ecdh.go +++ b/openpgp/ecdh/ecdh.go @@ -12,13 +12,50 @@ import ( "io" "github.com/ProtonMail/go-crypto/openpgp/aes/keywrap" + pgperrors "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc" + "github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve25519" +) + +const ( + KDFVersion1 = 1 + KDFVersionForwarding = 255 ) type KDF struct { - Hash algorithm.Hash - Cipher algorithm.Cipher + Version int // Defaults to v1; 255 for forwarding + Hash algorithm.Hash + Cipher algorithm.Cipher + ReplacementFingerprint []byte // (forwarding only) fingerprint to use instead of recipient's (20 octets) +} + +func (kdf *KDF) Serialize(w io.Writer) (err error) { + switch kdf.Version { + case 0, KDFVersion1: // Default to v1 if unspecified + return kdf.serializeForHash(w) + case KDFVersionForwarding: + // Length || Version || Hash || Cipher || Replacement Fingerprint + length := byte(3 + len(kdf.ReplacementFingerprint)) + if _, err := w.Write([]byte{length, KDFVersionForwarding, kdf.Hash.Id(), kdf.Cipher.Id()}); err != nil { + return err + } + if _, err := w.Write(kdf.ReplacementFingerprint); err != nil { + return err + } + + return nil + default: + return errors.New("ecdh: invalid KDF version") + } +} + +func (kdf *KDF) serializeForHash(w io.Writer) (err error) { + // Length || Version || Hash || Cipher + if _, err := w.Write([]byte{3, KDFVersion1, kdf.Hash.Id(), kdf.Cipher.Id()}); err != nil { + return err + } + return nil } type PublicKey struct { @@ -32,13 +69,10 @@ type PrivateKey struct { D []byte } -func NewPublicKey(curve ecc.ECDHCurve, kdfHash algorithm.Hash, kdfCipher algorithm.Cipher) *PublicKey { +func NewPublicKey(curve ecc.ECDHCurve, kdf KDF) *PublicKey { return &PublicKey{ curve: curve, - KDF: KDF{ - Hash: kdfHash, - Cipher: kdfCipher, - }, + KDF: kdf, } } @@ -149,21 +183,31 @@ func Decrypt(priv *PrivateKey, vsG, c, curveOID, fingerprint []byte) (msg []byte } func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLeading, stripTrailing bool) ([]byte, error) { - // Param = curve_OID_len || curve_OID || public_key_alg_ID || 03 - // || 01 || KDF_hash_ID || KEK_alg_ID for AESKeyWrap + // Param = curve_OID_len || curve_OID || public_key_alg_ID + // || KDF_params for AESKeyWrap // || "Anonymous Sender " || recipient_fingerprint; param := new(bytes.Buffer) if _, err := param.Write(curveOID); err != nil { return nil, err } - algKDF := []byte{18, 3, 1, pub.KDF.Hash.Id(), pub.KDF.Cipher.Id()} - if _, err := param.Write(algKDF); err != nil { + algo := []byte{18} + if _, err := param.Write(algo); err != nil { + return nil, err + } + + if err := pub.KDF.serializeForHash(param); err != nil { return nil, err } + if _, err := param.Write([]byte("Anonymous Sender ")); err != nil { return nil, err } - if _, err := param.Write(fingerprint[:]); err != nil { + + if pub.KDF.ReplacementFingerprint != nil { + fingerprint = pub.KDF.ReplacementFingerprint + } + + if _, err := param.Write(fingerprint); err != nil { return nil, err } @@ -204,3 +248,40 @@ func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLead func Validate(priv *PrivateKey) error { return priv.curve.ValidateECDH(priv.Point, priv.D) } + +func DeriveProxyParam(recipientKey, forwardeeKey *PrivateKey) (proxyParam []byte, err error) { + if recipientKey.GetCurve().GetCurveName() != "curve25519" { + return nil, pgperrors.InvalidArgumentError("recipient subkey is not curve25519") + } + + if forwardeeKey.GetCurve().GetCurveName() != "curve25519" { + return nil, pgperrors.InvalidArgumentError("forwardee subkey is not curve25519") + } + + c := ecc.NewCurve25519() + + // Clamp and reverse two secrets + proxyParam, err = curve25519.DeriveProxyParam(c.MarshalByteSecret(recipientKey.D), c.MarshalByteSecret(forwardeeKey.D)) + + return proxyParam, err +} + +func ProxyTransform(ephemeral, proxyParam []byte) ([]byte, error) { + c := ecc.NewCurve25519() + + parsedEphemeral := c.UnmarshalBytePoint(ephemeral) + if parsedEphemeral == nil { + return nil, pgperrors.InvalidArgumentError("invalid ephemeral") + } + + if len(proxyParam) != curve25519.ParamSize { + return nil, pgperrors.InvalidArgumentError("invalid proxy parameter") + } + + transformed, err := curve25519.ProxyTransform(parsedEphemeral, proxyParam) + if err != nil { + return nil, err + } + + return c.MarshalBytePoint(transformed), nil +} diff --git a/openpgp/ecdh/ecdh_test.go b/openpgp/ecdh/ecdh_test.go index 1f70b7dd..0e79778f 100644 --- a/openpgp/ecdh/ecdh_test.go +++ b/openpgp/ecdh/ecdh_test.go @@ -42,7 +42,7 @@ func TestCurves(t *testing.T) { } func testGenerate(t *testing.T, curve ecc.ECDHCurve) *PrivateKey { - kdf := KDF{ + kdf := KDF { Hash: algorithm.SHA512, Cipher: algorithm.AES256, } @@ -88,7 +88,7 @@ func testMarshalUnmarshal(t *testing.T, priv *PrivateKey) { p := priv.MarshalPoint() d := priv.MarshalByteSecret() - parsed := NewPrivateKey(*NewPublicKey(priv.GetCurve(), priv.KDF.Hash, priv.KDF.Cipher)) + parsed := NewPrivateKey(*NewPublicKey(priv.GetCurve(), priv.KDF)) if err := parsed.UnmarshalPoint(p); err != nil { t.Fatalf("unable to unmarshal point: %s", err) @@ -112,3 +112,37 @@ func testMarshalUnmarshal(t *testing.T, priv *PrivateKey) { t.Fatal("failed to marshal/unmarshal correctly") } } + +func TestKDFParamsWrite(t *testing.T) { + kdf := KDF{ + Hash: algorithm.SHA512, + Cipher: algorithm.AES256, + } + byteBuffer := new(bytes.Buffer) + + testFingerprint := make([]byte, 20) + + expectBytesV1 := []byte{3, 1, kdf.Hash.Id(), kdf.Cipher.Id()} + kdf.Serialize(byteBuffer) + gotBytes := byteBuffer.Bytes() + if !bytes.Equal(gotBytes, expectBytesV1) { + t.Errorf("error serializing KDF params, got %x, want: %x", gotBytes, expectBytesV1) + } + byteBuffer.Reset() + + kdfV2 := KDF{ + Version: KDFVersionForwarding, + Hash: algorithm.SHA512, + Cipher: algorithm.AES256, + ReplacementFingerprint: testFingerprint, + } + expectBytesV2 := []byte{23, 0xFF, kdfV2.Hash.Id(), kdfV2.Cipher.Id()} + expectBytesV2 = append(expectBytesV2, testFingerprint...) + + kdfV2.Serialize(byteBuffer) + gotBytes = byteBuffer.Bytes() + if !bytes.Equal(gotBytes, expectBytesV2) { + t.Errorf("error serializing KDF params v2, got %x, want: %x", gotBytes, expectBytesV2) + } + byteBuffer.Reset() +} diff --git a/openpgp/errors/errors.go b/openpgp/errors/errors.go index c42b01cb..c20292c6 100644 --- a/openpgp/errors/errors.go +++ b/openpgp/errors/errors.go @@ -33,6 +33,8 @@ func (i InvalidArgumentError) Error() string { return "openpgp: invalid argument: " + string(i) } +var InvalidForwardeeKeyError = InvalidArgumentError("invalid forwardee key") + // SignatureError indicates that a syntactically valid signature failed to // validate. type SignatureError string diff --git a/openpgp/forwarding.go b/openpgp/forwarding.go new file mode 100644 index 00000000..ae45c3c2 --- /dev/null +++ b/openpgp/forwarding.go @@ -0,0 +1,163 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package openpgp + +import ( + goerrors "errors" + + "github.com/ProtonMail/go-crypto/openpgp/ecdh" + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/packet" +) + +// NewForwardingEntity generates a new forwardee key and derives the proxy parameters from the entity e. +// If strict, it will return an error if encryption-capable non-revoked subkeys with a wrong algorithm are found, +// instead of ignoring them +func (e *Entity) NewForwardingEntity( + name, comment, email string, config *packet.Config, strict bool, +) ( + forwardeeKey *Entity, instances []packet.ForwardingInstance, err error, +) { + if e.PrimaryKey.Version != 4 { + return nil, nil, errors.InvalidArgumentError("unsupported key version") + } + + now := config.Now() + i := e.PrimaryIdentity() + if e.PrimaryKey.KeyExpired(i.SelfSignature, now) || // primary key has expired + i.SelfSignature.SigExpired(now) || // user ID self-signature has expired + e.Revoked(now) || // primary key has been revoked + i.Revoked(now) { // user ID has been revoked + return nil, nil, errors.InvalidArgumentError("primary key is expired") + } + + // Generate a new Primary key for the forwardee + config.Algorithm = packet.PubKeyAlgoEdDSA + config.Curve = packet.Curve25519 + keyLifetimeSecs := config.KeyLifetime() + + forwardeePrimaryPrivRaw, err := newSigner(config) + if err != nil { + return nil, nil, err + } + + primary := packet.NewSignerPrivateKey(now, forwardeePrimaryPrivRaw) + + forwardeeKey = &Entity{ + PrimaryKey: &primary.PublicKey, + PrivateKey: primary, + Identities: make(map[string]*Identity), + Subkeys: []Subkey{}, + } + + err = forwardeeKey.addUserId(name, comment, email, config, now, keyLifetimeSecs, true) + if err != nil { + return nil, nil, err + } + + // Init empty instances + instances = []packet.ForwardingInstance{} + + // Handle all forwarder subkeys + for _, forwarderSubKey := range e.Subkeys { + // Filter flags + if !forwarderSubKey.PublicKey.PubKeyAlgo.CanEncrypt() { + continue + } + + // Filter expiration & revokal + if forwarderSubKey.PublicKey.KeyExpired(forwarderSubKey.Sig, now) || + forwarderSubKey.Sig.SigExpired(now) || + forwarderSubKey.Revoked(now) { + continue + } + + if forwarderSubKey.PublicKey.PubKeyAlgo != packet.PubKeyAlgoECDH { + if strict { + return nil, nil, errors.InvalidArgumentError("encryption subkey is not algorithm 18 (ECDH)") + } else { + continue + } + } + + forwarderEcdhKey, ok := forwarderSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, errors.InvalidArgumentError("malformed key") + } + + err = forwardeeKey.addEncryptionSubkey(config, now, 0) + if err != nil { + return nil, nil, err + } + + forwardeeSubKey := forwardeeKey.Subkeys[len(forwardeeKey.Subkeys)-1] + + forwardeeEcdhKey, ok := forwardeeSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, goerrors.New("wrong forwarding sub key generation") + } + + instance := packet.ForwardingInstance{ + KeyVersion: 4, + ForwarderFingerprint: forwarderSubKey.PublicKey.Fingerprint, + } + + instance.ProxyParameter, err = ecdh.DeriveProxyParam(forwarderEcdhKey, forwardeeEcdhKey) + if err != nil { + return nil, nil, err + } + + kdf := ecdh.KDF{ + Version: ecdh.KDFVersionForwarding, + Hash: forwarderEcdhKey.KDF.Hash, + Cipher: forwarderEcdhKey.KDF.Cipher, + } + + // If deriving a forwarding key from a forwarding key + if forwarderSubKey.Sig.FlagForward { + if forwarderEcdhKey.KDF.Version != ecdh.KDFVersionForwarding { + return nil, nil, goerrors.New("malformed forwarder key") + } + kdf.ReplacementFingerprint = forwarderEcdhKey.KDF.ReplacementFingerprint + } else { + kdf.ReplacementFingerprint = forwarderSubKey.PublicKey.Fingerprint + } + + err = forwardeeSubKey.PublicKey.ReplaceKDF(kdf) + if err != nil { + return nil, nil, err + } + + // Extract fingerprint after changing the KDF + instance.ForwardeeFingerprint = forwardeeSubKey.PublicKey.Fingerprint + + // 0x04 - This key may be used to encrypt communications. + forwardeeSubKey.Sig.FlagEncryptCommunications = false + + // 0x08 - This key may be used to encrypt storage. + forwardeeSubKey.Sig.FlagEncryptStorage = false + + // 0x10 - The private component of this key may have been split by a secret-sharing mechanism. + forwardeeSubKey.Sig.FlagSplitKey = true + + // 0x40 - This key may be used for forwarded communications. + forwardeeSubKey.Sig.FlagForward = true + + // Re-sign subkey binding signature + err = forwardeeSubKey.Sig.SignKey(forwardeeSubKey.PublicKey, forwardeeKey.PrivateKey, config) + if err != nil { + return nil, nil, err + } + + // Append each valid instance to the list + instances = append(instances, instance) + } + + if len(instances) == 0 { + return nil, nil, errors.InvalidArgumentError("no valid subkey found") + } + + return forwardeeKey, instances, nil +} diff --git a/openpgp/forwarding_test.go b/openpgp/forwarding_test.go new file mode 100644 index 00000000..7bc16718 --- /dev/null +++ b/openpgp/forwarding_test.go @@ -0,0 +1,224 @@ +package openpgp + +import ( + "bytes" + "crypto/rand" + goerrors "errors" + "io" + "io/ioutil" + "strings" + "testing" + + "github.com/ProtonMail/go-crypto/openpgp/packet" + "golang.org/x/crypto/openpgp/armor" +) + +const forwardeeKey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xVgEZAdtGBYJKwYBBAHaRw8BAQdAcNgHyRGEaqGmzEqEwCobfUkyrJnY8faBvsf9 +R2c5ZzYAAP9bFL4nPBdo04ei0C2IAh5RXOpmuejGC3GAIn/UmL5cYQ+XzRtjaGFy +bGVzIDxjaGFybGVzQHByb3Rvbi5tZT7CigQTFggAPAUCZAdtGAmQFXJtmBzDhdcW +IQRl2gNflypl1XjRUV8Vcm2YHMOF1wIbAwIeAQIZAQILBwIVCAIWAAIiAQAAJKYA +/2qY16Ozyo5erNz51UrKViEoWbEpwY3XaFVNzrw+b54YAQC7zXkf/t5ieylvjmA/ +LJz3/qgH5GxZRYAH9NTpWyW1AsdxBGQHbRgSCisGAQQBl1UBBQEBB0CxmxoJsHTW +TiETWh47ot+kwNA1hCk1IYB9WwKxkXYyIBf/CgmKXzV1ODP/mRmtiBYVV+VQk5MF +EAAA/1NW8D8nMc2ky140sPhQrwkeR7rVLKP2fe5n4BEtAnVQEB3CeAQYFggAKgUC +ZAdtGAmQFXJtmBzDhdcWIQRl2gNflypl1XjRUV8Vcm2YHMOF1wIbUAAAl/8A/iIS +zWBsBR8VnoOVfEE+VQk6YAi7cTSjcMjfsIez9FYtAQDKo9aCMhUohYyqvhZjn8aS +3t9mIZPc+zRJtCHzQYmhDg== +=lESj +-----END PGP PRIVATE KEY BLOCK-----` + +const forwardedMessage = `-----BEGIN PGP MESSAGE----- + +wV4DB27Wn97eACkSAQdA62TlMU2QoGmf5iBLnIm4dlFRkLIg+6MbaatghwxK+Ccw +yGZuVVMAK/ypFfebDf4D/rlEw3cysv213m8aoK8nAUO8xQX3XQq3Sg+EGm0BNV8E +0kABEPyCWARoo5klT1rHPEhelnz8+RQXiOIX3G685XCWdCmaV+tzW082D0xGXSlC +7lM8r1DumNnO8srssko2qIja +=pVRa +-----END PGP MESSAGE-----` + +const forwardedPlaintext = "Message for Bob" + +func TestForwardingStatic(t *testing.T) { + charlesKey, err := ReadArmoredKeyRing(bytes.NewBufferString(forwardeeKey)) + if err != nil { + t.Error(err) + return + } + + ciphertext, err := armor.Decode(strings.NewReader(forwardedMessage)) + if err != nil { + t.Error(err) + return + } + + m, err := ReadMessage(ciphertext.Body, charlesKey, nil, nil) + if err != nil { + t.Fatal(err) + } + + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, []byte(forwardedPlaintext)) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func TestForwardingFull(t *testing.T) { + keyConfig := &packet.Config{ + Algorithm: packet.PubKeyAlgoEdDSA, + Curve: packet.Curve25519, + } + + plaintext := make([]byte, 1024) + rand.Read(plaintext) + + bobEntity, err := NewEntity("bob", "", "bob@proton.me", keyConfig) + if err != nil { + t.Fatal(err) + } + + charlesEntity, instances, err := bobEntity.NewForwardingEntity("charles", "", "charles@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + charlesEntity = serializeAndParseForwardeeKey(t, charlesEntity) + + if len(instances) != 1 { + t.Fatalf("invalid number of instances, expected 1 got %d", len(instances)) + } + + if !bytes.Equal(instances[0].ForwarderFingerprint, bobEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwarder key ID, expected: %x, got: %x", bobEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwarderFingerprint) + } + + if !bytes.Equal(instances[0].ForwardeeFingerprint, charlesEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwardee key ID, expected: %x, got: %x", charlesEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwardeeFingerprint) + } + + // Encrypt message + buf := bytes.NewBuffer(nil) + w, err := Encrypt(buf, []*Entity{bobEntity}, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + + _, err = w.Write(plaintext) + if err != nil { + t.Fatal(err) + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + encrypted := buf.Bytes() + + // Decrypt message for Bob + m, err := ReadMessage(bytes.NewBuffer(encrypted), EntityList([]*Entity{bobEntity}), nil, nil) + if err != nil { + t.Fatal(err) + } + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("decrypted does not match original") + } + + // Forward message + + transformed := transformTestMessage(t, encrypted, instances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(transformed), EntityList([]*Entity{charlesEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } + + // Setup further forwarding + danielEntity, secondForwardInstances, err := charlesEntity.NewForwardingEntity("Daniel", "", "daniel@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + danielEntity = serializeAndParseForwardeeKey(t, danielEntity) + + secondTransformed := transformTestMessage(t, transformed, secondForwardInstances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(secondTransformed), EntityList([]*Entity{danielEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func transformTestMessage(t *testing.T, encrypted []byte, instance packet.ForwardingInstance) []byte { + bytesReader := bytes.NewReader(encrypted) + packets := packet.NewReader(bytesReader) + splitPoint := int64(0) + transformedEncryptedKey := bytes.NewBuffer(nil) + +Loop: + for { + p, err := packets.Next() + if goerrors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("error in parsing message: %s", err) + } + switch p := p.(type) { + case *packet.EncryptedKey: + tp, err := p.ProxyTransform(instance) + if err != nil { + t.Fatalf("error transforming PKESK: %s", err) + } + + splitPoint = bytesReader.Size() - int64(bytesReader.Len()) + + err = tp.Serialize(transformedEncryptedKey) + if err != nil { + t.Fatalf("error serializing transformed PKESK: %s", err) + } + break Loop + } + } + + transformed := transformedEncryptedKey.Bytes() + transformed = append(transformed, encrypted[splitPoint:]...) + + return transformed +} + +func serializeAndParseForwardeeKey(t *testing.T, key *Entity) *Entity { + serializedEntity := bytes.NewBuffer(nil) + err := key.SerializePrivateWithoutSigning(serializedEntity, nil) + if err != nil { + t.Fatalf("Error in serializing forwardee key: %s", err) + } + el, err := ReadKeyRing(serializedEntity) + if err != nil { + t.Fatalf("Error in reading forwardee key: %s", err) + } + + if len(el) != 1 { + t.Fatalf("Wrong number of entities in parsing, expected 1, got %d", len(el)) + } + + return el[0] +} diff --git a/openpgp/internal/ecc/curve25519.go b/openpgp/internal/ecc/curve25519.go index 888767c4..a6721ff9 100644 --- a/openpgp/internal/ecc/curve25519.go +++ b/openpgp/internal/ecc/curve25519.go @@ -3,10 +3,9 @@ package ecc import ( "crypto/subtle" - "io" - "github.com/ProtonMail/go-crypto/openpgp/errors" x25519lib "github.com/cloudflare/circl/dh/x25519" + "io" ) type curve25519 struct{} diff --git a/openpgp/internal/ecc/curve25519/curve25519.go b/openpgp/internal/ecc/curve25519/curve25519.go new file mode 100644 index 00000000..21670a82 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/curve25519.go @@ -0,0 +1,122 @@ +// Package curve25519 implements custom field operations without clamping for forwarding. +package curve25519 + +import ( + "crypto/subtle" + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve25519/field" + x25519lib "github.com/cloudflare/circl/dh/x25519" + "math/big" +) + +var curveGroupByte = [x25519lib.Size]byte{ + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x14, 0xde, 0xf9, 0xde, 0xa2, 0xf7, 0x9c, 0xd6, 0x58, 0x12, 0x63, 0x1a, 0x5c, 0xf5, 0xd3, 0xed, +} + +const ParamSize = x25519lib.Size + +func DeriveProxyParam(recipientSecretByte, forwardeeSecretByte []byte) (proxyParam []byte, err error) { + curveGroup := new(big.Int).SetBytes(curveGroupByte[:]) + recipientSecret := new(big.Int).SetBytes(recipientSecretByte) + forwardeeSecret := new(big.Int).SetBytes(forwardeeSecretByte) + + proxyTransform := new(big.Int).Mod( + new(big.Int).Mul( + new(big.Int).ModInverse(forwardeeSecret, curveGroup), + recipientSecret, + ), + curveGroup, + ) + + rawProxyParam := proxyTransform.Bytes() + + // pad and convert to small endian + proxyParam = make([]byte, x25519lib.Size) + l := len(rawProxyParam) + for i := 0; i < l; i++ { + proxyParam[i] = rawProxyParam[l-i-1] + } + + return proxyParam, nil +} + +func ProxyTransform(ephemeral, proxyParam []byte) ([]byte, error) { + var transformed, safetyCheck [x25519lib.Size]byte + + var scalarEight = make([]byte, x25519lib.Size) + scalarEight[0] = 0x08 + err := ScalarMult(&safetyCheck, scalarEight, ephemeral) + if err != nil { + return nil, err + } + + err = ScalarMult(&transformed, proxyParam, ephemeral) + if err != nil { + return nil, err + } + + return transformed[:], nil +} + +func ScalarMult(dst *[32]byte, scalar, point []byte) error { + var in, base, zero [32]byte + copy(in[:], scalar) + copy(base[:], point) + + scalarMult(dst, &in, &base) + if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 { + return errors.InvalidArgumentError("invalid ephemeral: low order point") + } + + return nil +} + +func scalarMult(dst, scalar, point *[32]byte) { + var e [32]byte + + copy(e[:], scalar[:]) + + var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element + x1.SetBytes(point[:]) + x2.One() + x3.Set(&x1) + z3.One() + + swap := 0 + for pos := 254; pos >= 0; pos-- { + b := e[pos/8] >> uint(pos&7) + b &= 1 + swap ^= int(b) + x2.Swap(&x3, swap) + z2.Swap(&z3, swap) + swap = int(b) + + tmp0.Subtract(&x3, &z3) + tmp1.Subtract(&x2, &z2) + x2.Add(&x2, &z2) + z2.Add(&x3, &z3) + z3.Multiply(&tmp0, &x2) + z2.Multiply(&z2, &tmp1) + tmp0.Square(&tmp1) + tmp1.Square(&x2) + x3.Add(&z3, &z2) + z2.Subtract(&z3, &z2) + x2.Multiply(&tmp1, &tmp0) + tmp1.Subtract(&tmp1, &tmp0) + z2.Square(&z2) + + z3.Mult32(&tmp1, 121666) + x3.Square(&x3) + tmp0.Add(&tmp0, &z3) + z3.Multiply(&x1, &z2) + z2.Multiply(&tmp1, &tmp0) + } + + x2.Swap(&x3, swap) + z2.Swap(&z3, swap) + + z2.Invert(&z2) + x2.Multiply(&x2, &z2) + copy(dst[:], x2.Bytes()) +} diff --git a/openpgp/internal/ecc/curve25519/curve25519_test.go b/openpgp/internal/ecc/curve25519/curve25519_test.go new file mode 100644 index 00000000..88921267 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/curve25519_test.go @@ -0,0 +1,89 @@ +// Package curve25519 implements custom field operations without clamping for forwarding. +package curve25519 + +import ( + "bytes" + "encoding/hex" + "testing" +) + +const ( + hexBobSecret = "5989216365053dcf9e35a04b2a1fc19b83328426be6bb7d0a2ae78105e2e3188" + hexCharlesSecret = "684da6225bcd44d880168fc5bec7d2f746217f014c8019005f144cc148f16a00" + hexExpectedProxyParam = "e89786987c3a3ec761a679bc372cd11a425eda72bd5265d78ad0f5f32ee64f02" + + hexMessagePoint = "aaea7b3bb92f5f545d023ccb15b50f84ba1bdd53be7f5cfadcfb0106859bf77e" + hexInputProxyParam = "83c57cbe645a132477af55d5020281305860201608e81a1de43ff83f245fb302" + hexExpectedTransformedPoint = "ec31bb937d7ef08c451d516be1d7976179aa7171eea598370661d1152b85005a" + + hexSmallSubgroupPoint = "ecffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f" +) + +func TestDeriveProxyParam(t *testing.T) { + bobSecret, err := hex.DecodeString(hexBobSecret) + if err != nil { + t.Fatalf("Unexpected error in decoding recipient secret: %s", err) + } + + charlesSecret, err := hex.DecodeString(hexCharlesSecret) + if err != nil { + t.Fatalf("Unexpected error in decoding forwardee secret: %s", err) + } + + expectedProxyParam, err := hex.DecodeString(hexExpectedProxyParam) + if err != nil { + t.Fatalf("Unexpected error in parameter decoding expected proxy parameter: %s", err) + } + + proxyParam, err := DeriveProxyParam(bobSecret, charlesSecret) + if err != nil { + t.Fatalf("Unexpected error in parameter derivation: %s", err) + } + + if bytes.Compare(proxyParam, expectedProxyParam) != 0 { + t.Errorf("Computed wrong proxy parameter, expected %x got %x", expectedProxyParam, proxyParam) + } +} + +func TestTransformMessage(t *testing.T) { + proxyParam, err := hex.DecodeString(hexInputProxyParam) + if err != nil { + t.Fatalf("Unexpected error in decoding proxy parameter: %s", err) + } + + messagePoint, err := hex.DecodeString(hexMessagePoint) + if err != nil { + t.Fatalf("Unexpected error in decoding message point: %s", err) + } + + expectedTransformed, err := hex.DecodeString(hexExpectedTransformedPoint) + if err != nil { + t.Fatalf("Unexpected error in parameter decoding expected transformed point: %s", err) + } + + transformed, err := ProxyTransform(messagePoint, proxyParam) + if err != nil { + t.Fatalf("Unexpected error in parameter derivation: %s", err) + } + + if bytes.Compare(transformed, expectedTransformed) != 0 { + t.Errorf("Computed wrong proxy parameter, expected %x got %x", expectedTransformed, transformed) + } +} + +func TestTransformSmallSubgroup(t *testing.T) { + proxyParam, err := hex.DecodeString(hexInputProxyParam) + if err != nil { + t.Fatalf("Unexpected error in decoding proxy parameter: %s", err) + } + + messagePoint, err := hex.DecodeString(hexSmallSubgroupPoint) + if err != nil { + t.Fatalf("Unexpected error in decoding small sugroup point: %s", err) + } + + _, err = ProxyTransform(messagePoint, proxyParam) + if err == nil { + t.Error("Expected small subgroup error") + } +} diff --git a/openpgp/internal/ecc/curve25519/field/fe.go b/openpgp/internal/ecc/curve25519/field/fe.go new file mode 100644 index 00000000..ca841ad9 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe.go @@ -0,0 +1,416 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package field implements fast arithmetic modulo 2^255-19. +package field + +import ( + "crypto/subtle" + "encoding/binary" + "math/bits" +) + +// Element represents an element of the field GF(2^255-19). Note that this +// is not a cryptographically secure group, and should only be used to interact +// with edwards25519.Point coordinates. +// +// This type works similarly to math/big.Int, and all arguments and receivers +// are allowed to alias. +// +// The zero value is a valid zero element. +type Element struct { + // An element t represents the integer + // t.l0 + t.l1*2^51 + t.l2*2^102 + t.l3*2^153 + t.l4*2^204 + // + // Between operations, all limbs are expected to be lower than 2^52. + l0 uint64 + l1 uint64 + l2 uint64 + l3 uint64 + l4 uint64 +} + +const maskLow51Bits uint64 = (1 << 51) - 1 + +var feZero = &Element{0, 0, 0, 0, 0} + +// Zero sets v = 0, and returns v. +func (v *Element) Zero() *Element { + *v = *feZero + return v +} + +var feOne = &Element{1, 0, 0, 0, 0} + +// One sets v = 1, and returns v. +func (v *Element) One() *Element { + *v = *feOne + return v +} + +// reduce reduces v modulo 2^255 - 19 and returns it. +func (v *Element) reduce() *Element { + v.carryPropagate() + + // After the light reduction we now have a field element representation + // v < 2^255 + 2^13 * 19, but need v < 2^255 - 19. + + // If v >= 2^255 - 19, then v + 19 >= 2^255, which would overflow 2^255 - 1, + // generating a carry. That is, c will be 0 if v < 2^255 - 19, and 1 otherwise. + c := (v.l0 + 19) >> 51 + c = (v.l1 + c) >> 51 + c = (v.l2 + c) >> 51 + c = (v.l3 + c) >> 51 + c = (v.l4 + c) >> 51 + + // If v < 2^255 - 19 and c = 0, this will be a no-op. Otherwise, it's + // effectively applying the reduction identity to the carry. + v.l0 += 19 * c + + v.l1 += v.l0 >> 51 + v.l0 = v.l0 & maskLow51Bits + v.l2 += v.l1 >> 51 + v.l1 = v.l1 & maskLow51Bits + v.l3 += v.l2 >> 51 + v.l2 = v.l2 & maskLow51Bits + v.l4 += v.l3 >> 51 + v.l3 = v.l3 & maskLow51Bits + // no additional carry + v.l4 = v.l4 & maskLow51Bits + + return v +} + +// Add sets v = a + b, and returns v. +func (v *Element) Add(a, b *Element) *Element { + v.l0 = a.l0 + b.l0 + v.l1 = a.l1 + b.l1 + v.l2 = a.l2 + b.l2 + v.l3 = a.l3 + b.l3 + v.l4 = a.l4 + b.l4 + // Using the generic implementation here is actually faster than the + // assembly. Probably because the body of this function is so simple that + // the compiler can figure out better optimizations by inlining the carry + // propagation. TODO + return v.carryPropagateGeneric() +} + +// Subtract sets v = a - b, and returns v. +func (v *Element) Subtract(a, b *Element) *Element { + // We first add 2 * p, to guarantee the subtraction won't underflow, and + // then subtract b (which can be up to 2^255 + 2^13 * 19). + v.l0 = (a.l0 + 0xFFFFFFFFFFFDA) - b.l0 + v.l1 = (a.l1 + 0xFFFFFFFFFFFFE) - b.l1 + v.l2 = (a.l2 + 0xFFFFFFFFFFFFE) - b.l2 + v.l3 = (a.l3 + 0xFFFFFFFFFFFFE) - b.l3 + v.l4 = (a.l4 + 0xFFFFFFFFFFFFE) - b.l4 + return v.carryPropagate() +} + +// Negate sets v = -a, and returns v. +func (v *Element) Negate(a *Element) *Element { + return v.Subtract(feZero, a) +} + +// Invert sets v = 1/z mod p, and returns v. +// +// If z == 0, Invert returns v = 0. +func (v *Element) Invert(z *Element) *Element { + // Inversion is implemented as exponentiation with exponent p − 2. It uses the + // same sequence of 255 squarings and 11 multiplications as [Curve25519]. + var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t Element + + z2.Square(z) // 2 + t.Square(&z2) // 4 + t.Square(&t) // 8 + z9.Multiply(&t, z) // 9 + z11.Multiply(&z9, &z2) // 11 + t.Square(&z11) // 22 + z2_5_0.Multiply(&t, &z9) // 31 = 2^5 - 2^0 + + t.Square(&z2_5_0) // 2^6 - 2^1 + for i := 0; i < 4; i++ { + t.Square(&t) // 2^10 - 2^5 + } + z2_10_0.Multiply(&t, &z2_5_0) // 2^10 - 2^0 + + t.Square(&z2_10_0) // 2^11 - 2^1 + for i := 0; i < 9; i++ { + t.Square(&t) // 2^20 - 2^10 + } + z2_20_0.Multiply(&t, &z2_10_0) // 2^20 - 2^0 + + t.Square(&z2_20_0) // 2^21 - 2^1 + for i := 0; i < 19; i++ { + t.Square(&t) // 2^40 - 2^20 + } + t.Multiply(&t, &z2_20_0) // 2^40 - 2^0 + + t.Square(&t) // 2^41 - 2^1 + for i := 0; i < 9; i++ { + t.Square(&t) // 2^50 - 2^10 + } + z2_50_0.Multiply(&t, &z2_10_0) // 2^50 - 2^0 + + t.Square(&z2_50_0) // 2^51 - 2^1 + for i := 0; i < 49; i++ { + t.Square(&t) // 2^100 - 2^50 + } + z2_100_0.Multiply(&t, &z2_50_0) // 2^100 - 2^0 + + t.Square(&z2_100_0) // 2^101 - 2^1 + for i := 0; i < 99; i++ { + t.Square(&t) // 2^200 - 2^100 + } + t.Multiply(&t, &z2_100_0) // 2^200 - 2^0 + + t.Square(&t) // 2^201 - 2^1 + for i := 0; i < 49; i++ { + t.Square(&t) // 2^250 - 2^50 + } + t.Multiply(&t, &z2_50_0) // 2^250 - 2^0 + + t.Square(&t) // 2^251 - 2^1 + t.Square(&t) // 2^252 - 2^2 + t.Square(&t) // 2^253 - 2^3 + t.Square(&t) // 2^254 - 2^4 + t.Square(&t) // 2^255 - 2^5 + + return v.Multiply(&t, &z11) // 2^255 - 21 +} + +// Set sets v = a, and returns v. +func (v *Element) Set(a *Element) *Element { + *v = *a + return v +} + +// SetBytes sets v to x, which must be a 32-byte little-endian encoding. +// +// Consistent with RFC 7748, the most significant bit (the high bit of the +// last byte) is ignored, and non-canonical values (2^255-19 through 2^255-1) +// are accepted. Note that this is laxer than specified by RFC 8032. +func (v *Element) SetBytes(x []byte) *Element { + if len(x) != 32 { + panic("edwards25519: invalid field element input size") + } + + // Bits 0:51 (bytes 0:8, bits 0:64, shift 0, mask 51). + v.l0 = binary.LittleEndian.Uint64(x[0:8]) + v.l0 &= maskLow51Bits + // Bits 51:102 (bytes 6:14, bits 48:112, shift 3, mask 51). + v.l1 = binary.LittleEndian.Uint64(x[6:14]) >> 3 + v.l1 &= maskLow51Bits + // Bits 102:153 (bytes 12:20, bits 96:160, shift 6, mask 51). + v.l2 = binary.LittleEndian.Uint64(x[12:20]) >> 6 + v.l2 &= maskLow51Bits + // Bits 153:204 (bytes 19:27, bits 152:216, shift 1, mask 51). + v.l3 = binary.LittleEndian.Uint64(x[19:27]) >> 1 + v.l3 &= maskLow51Bits + // Bits 204:251 (bytes 24:32, bits 192:256, shift 12, mask 51). + // Note: not bytes 25:33, shift 4, to avoid overread. + v.l4 = binary.LittleEndian.Uint64(x[24:32]) >> 12 + v.l4 &= maskLow51Bits + + return v +} + +// Bytes returns the canonical 32-byte little-endian encoding of v. +func (v *Element) Bytes() []byte { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [32]byte + return v.bytes(&out) +} + +func (v *Element) bytes(out *[32]byte) []byte { + t := *v + t.reduce() + + var buf [8]byte + for i, l := range [5]uint64{t.l0, t.l1, t.l2, t.l3, t.l4} { + bitsOffset := i * 51 + binary.LittleEndian.PutUint64(buf[:], l<= len(out) { + break + } + out[off] |= bb + } + } + + return out[:] +} + +// Equal returns 1 if v and u are equal, and 0 otherwise. +func (v *Element) Equal(u *Element) int { + sa, sv := u.Bytes(), v.Bytes() + return subtle.ConstantTimeCompare(sa, sv) +} + +// mask64Bits returns 0xffffffff if cond is 1, and 0 otherwise. +func mask64Bits(cond int) uint64 { return ^(uint64(cond) - 1) } + +// Select sets v to a if cond == 1, and to b if cond == 0. +func (v *Element) Select(a, b *Element, cond int) *Element { + m := mask64Bits(cond) + v.l0 = (m & a.l0) | (^m & b.l0) + v.l1 = (m & a.l1) | (^m & b.l1) + v.l2 = (m & a.l2) | (^m & b.l2) + v.l3 = (m & a.l3) | (^m & b.l3) + v.l4 = (m & a.l4) | (^m & b.l4) + return v +} + +// Swap swaps v and u if cond == 1 or leaves them unchanged if cond == 0, and returns v. +func (v *Element) Swap(u *Element, cond int) { + m := mask64Bits(cond) + t := m & (v.l0 ^ u.l0) + v.l0 ^= t + u.l0 ^= t + t = m & (v.l1 ^ u.l1) + v.l1 ^= t + u.l1 ^= t + t = m & (v.l2 ^ u.l2) + v.l2 ^= t + u.l2 ^= t + t = m & (v.l3 ^ u.l3) + v.l3 ^= t + u.l3 ^= t + t = m & (v.l4 ^ u.l4) + v.l4 ^= t + u.l4 ^= t +} + +// IsNegative returns 1 if v is negative, and 0 otherwise. +func (v *Element) IsNegative() int { + return int(v.Bytes()[0] & 1) +} + +// Absolute sets v to |u|, and returns v. +func (v *Element) Absolute(u *Element) *Element { + return v.Select(new(Element).Negate(u), u, u.IsNegative()) +} + +// Multiply sets v = x * y, and returns v. +func (v *Element) Multiply(x, y *Element) *Element { + feMul(v, x, y) + return v +} + +// Square sets v = x * x, and returns v. +func (v *Element) Square(x *Element) *Element { + feSquare(v, x) + return v +} + +// Mult32 sets v = x * y, and returns v. +func (v *Element) Mult32(x *Element, y uint32) *Element { + x0lo, x0hi := mul51(x.l0, y) + x1lo, x1hi := mul51(x.l1, y) + x2lo, x2hi := mul51(x.l2, y) + x3lo, x3hi := mul51(x.l3, y) + x4lo, x4hi := mul51(x.l4, y) + v.l0 = x0lo + 19*x4hi // carried over per the reduction identity + v.l1 = x1lo + x0hi + v.l2 = x2lo + x1hi + v.l3 = x3lo + x2hi + v.l4 = x4lo + x3hi + // The hi portions are going to be only 32 bits, plus any previous excess, + // so we can skip the carry propagation. + return v +} + +// mul51 returns lo + hi * 2⁵¹ = a * b. +func mul51(a uint64, b uint32) (lo uint64, hi uint64) { + mh, ml := bits.Mul64(a, uint64(b)) + lo = ml & maskLow51Bits + hi = (mh << 13) | (ml >> 51) + return +} + +// Pow22523 set v = x^((p-5)/8), and returns v. (p-5)/8 is 2^252-3. +func (v *Element) Pow22523(x *Element) *Element { + var t0, t1, t2 Element + + t0.Square(x) // x^2 + t1.Square(&t0) // x^4 + t1.Square(&t1) // x^8 + t1.Multiply(x, &t1) // x^9 + t0.Multiply(&t0, &t1) // x^11 + t0.Square(&t0) // x^22 + t0.Multiply(&t1, &t0) // x^31 + t1.Square(&t0) // x^62 + for i := 1; i < 5; i++ { // x^992 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // x^1023 -> 1023 = 2^10 - 1 + t1.Square(&t0) // 2^11 - 2 + for i := 1; i < 10; i++ { // 2^20 - 2^10 + t1.Square(&t1) + } + t1.Multiply(&t1, &t0) // 2^20 - 1 + t2.Square(&t1) // 2^21 - 2 + for i := 1; i < 20; i++ { // 2^40 - 2^20 + t2.Square(&t2) + } + t1.Multiply(&t2, &t1) // 2^40 - 1 + t1.Square(&t1) // 2^41 - 2 + for i := 1; i < 10; i++ { // 2^50 - 2^10 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // 2^50 - 1 + t1.Square(&t0) // 2^51 - 2 + for i := 1; i < 50; i++ { // 2^100 - 2^50 + t1.Square(&t1) + } + t1.Multiply(&t1, &t0) // 2^100 - 1 + t2.Square(&t1) // 2^101 - 2 + for i := 1; i < 100; i++ { // 2^200 - 2^100 + t2.Square(&t2) + } + t1.Multiply(&t2, &t1) // 2^200 - 1 + t1.Square(&t1) // 2^201 - 2 + for i := 1; i < 50; i++ { // 2^250 - 2^50 + t1.Square(&t1) + } + t0.Multiply(&t1, &t0) // 2^250 - 1 + t0.Square(&t0) // 2^251 - 2 + t0.Square(&t0) // 2^252 - 4 + return v.Multiply(&t0, x) // 2^252 - 3 -> x^(2^252-3) +} + +// sqrtM1 is 2^((p-1)/4), which squared is equal to -1 by Euler's Criterion. +var sqrtM1 = &Element{1718705420411056, 234908883556509, + 2233514472574048, 2117202627021982, 765476049583133} + +// SqrtRatio sets r to the non-negative square root of the ratio of u and v. +// +// If u/v is square, SqrtRatio returns r and 1. If u/v is not square, SqrtRatio +// sets r according to Section 4.3 of draft-irtf-cfrg-ristretto255-decaf448-00, +// and returns r and 0. +func (r *Element) SqrtRatio(u, v *Element) (rr *Element, wasSquare int) { + var a, b Element + + // r = (u * v3) * (u * v7)^((p-5)/8) + v2 := a.Square(v) + uv3 := b.Multiply(u, b.Multiply(v2, v)) + uv7 := a.Multiply(uv3, a.Square(v2)) + r.Multiply(uv3, r.Pow22523(uv7)) + + check := a.Multiply(v, a.Square(r)) // check = v * r^2 + + uNeg := b.Negate(u) + correctSignSqrt := check.Equal(u) + flippedSignSqrt := check.Equal(uNeg) + flippedSignSqrtI := check.Equal(uNeg.Multiply(uNeg, sqrtM1)) + + rPrime := b.Multiply(r, sqrtM1) // r_prime = SQRT_M1 * r + // r = CT_SELECT(r_prime IF flipped_sign_sqrt | flipped_sign_sqrt_i ELSE r) + r.Select(rPrime, r, flippedSignSqrt|flippedSignSqrtI) + + r.Absolute(r) // Choose the nonnegative square root. + return r, correctSignSqrt | flippedSignSqrt +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_alias_test.go b/openpgp/internal/ecc/curve25519/field/fe_alias_test.go new file mode 100644 index 00000000..5ad81df0 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_alias_test.go @@ -0,0 +1,126 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import ( + "testing" + "testing/quick" +) + +func checkAliasingOneArg(f func(v, x *Element) *Element) func(v, x Element) bool { + return func(v, x Element) bool { + x1, v1 := x, x + + // Calculate a reference f(x) without aliasing. + if out := f(&v, &x); out != &v && isInBounds(out) { + return false + } + + // Test aliasing the argument and the receiver. + if out := f(&v1, &v1); out != &v1 || v1 != v { + return false + } + + // Ensure the arguments was not modified. + return x == x1 + } +} + +func checkAliasingTwoArgs(f func(v, x, y *Element) *Element) func(v, x, y Element) bool { + return func(v, x, y Element) bool { + x1, y1, v1 := x, y, Element{} + + // Calculate a reference f(x, y) without aliasing. + if out := f(&v, &x, &y); out != &v && isInBounds(out) { + return false + } + + // Test aliasing the first argument and the receiver. + v1 = x + if out := f(&v1, &v1, &y); out != &v1 || v1 != v { + return false + } + // Test aliasing the second argument and the receiver. + v1 = y + if out := f(&v1, &x, &v1); out != &v1 || v1 != v { + return false + } + + // Calculate a reference f(x, x) without aliasing. + if out := f(&v, &x, &x); out != &v { + return false + } + + // Test aliasing the first argument and the receiver. + v1 = x + if out := f(&v1, &v1, &x); out != &v1 || v1 != v { + return false + } + // Test aliasing the second argument and the receiver. + v1 = x + if out := f(&v1, &x, &v1); out != &v1 || v1 != v { + return false + } + // Test aliasing both arguments and the receiver. + v1 = x + if out := f(&v1, &v1, &v1); out != &v1 || v1 != v { + return false + } + + // Ensure the arguments were not modified. + return x == x1 && y == y1 + } +} + +// TestAliasing checks that receivers and arguments can alias each other without +// leading to incorrect results. That is, it ensures that it's safe to write +// +// v.Invert(v) +// +// or +// +// v.Add(v, v) +// +// without any of the inputs getting clobbered by the output being written. +func TestAliasing(t *testing.T) { + type target struct { + name string + oneArgF func(v, x *Element) *Element + twoArgsF func(v, x, y *Element) *Element + } + for _, tt := range []target{ + {name: "Absolute", oneArgF: (*Element).Absolute}, + {name: "Invert", oneArgF: (*Element).Invert}, + {name: "Negate", oneArgF: (*Element).Negate}, + {name: "Set", oneArgF: (*Element).Set}, + {name: "Square", oneArgF: (*Element).Square}, + {name: "Multiply", twoArgsF: (*Element).Multiply}, + {name: "Add", twoArgsF: (*Element).Add}, + {name: "Subtract", twoArgsF: (*Element).Subtract}, + { + name: "Select0", + twoArgsF: func(v, x, y *Element) *Element { + return (*Element).Select(v, x, y, 0) + }, + }, + { + name: "Select1", + twoArgsF: func(v, x, y *Element) *Element { + return (*Element).Select(v, x, y, 1) + }, + }, + } { + var err error + switch { + case tt.oneArgF != nil: + err = quick.Check(checkAliasingOneArg(tt.oneArgF), &quick.Config{MaxCountScale: 1 << 8}) + case tt.twoArgsF != nil: + err = quick.Check(checkAliasingTwoArgs(tt.twoArgsF), &quick.Config{MaxCountScale: 1 << 8}) + } + if err != nil { + t.Errorf("%v: %v", tt.name, err) + } + } +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_amd64.go b/openpgp/internal/ecc/curve25519/field/fe_amd64.go new file mode 100644 index 00000000..44dc8e8c --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_amd64.go @@ -0,0 +1,13 @@ +// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT. + +// +build amd64,gc,!purego + +package field + +// feMul sets out = a * b. It works like feMulGeneric. +//go:noescape +func feMul(out *Element, a *Element, b *Element) + +// feSquare sets out = a * a. It works like feSquareGeneric. +//go:noescape +func feSquare(out *Element, a *Element) diff --git a/openpgp/internal/ecc/curve25519/field/fe_amd64.s b/openpgp/internal/ecc/curve25519/field/fe_amd64.s new file mode 100644 index 00000000..293f013c --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_amd64.s @@ -0,0 +1,379 @@ +// Code generated by command: go run fe_amd64_asm.go -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field. DO NOT EDIT. + +//go:build amd64 && gc && !purego +// +build amd64,gc,!purego + +#include "textflag.h" + +// func feMul(out *Element, a *Element, b *Element) +TEXT ·feMul(SB), NOSPLIT, $0-24 + MOVQ a+8(FP), CX + MOVQ b+16(FP), BX + + // r0 = a0×b0 + MOVQ (CX), AX + MULQ (BX) + MOVQ AX, DI + MOVQ DX, SI + + // r0 += 19×a1×b4 + MOVQ 8(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a2×b3 + MOVQ 16(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a3×b2 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 16(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r0 += 19×a4×b1 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 8(BX) + ADDQ AX, DI + ADCQ DX, SI + + // r1 = a0×b1 + MOVQ (CX), AX + MULQ 8(BX) + MOVQ AX, R9 + MOVQ DX, R8 + + // r1 += a1×b0 + MOVQ 8(CX), AX + MULQ (BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a2×b4 + MOVQ 16(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a3×b3 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r1 += 19×a4×b2 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 16(BX) + ADDQ AX, R9 + ADCQ DX, R8 + + // r2 = a0×b2 + MOVQ (CX), AX + MULQ 16(BX) + MOVQ AX, R11 + MOVQ DX, R10 + + // r2 += a1×b1 + MOVQ 8(CX), AX + MULQ 8(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += a2×b0 + MOVQ 16(CX), AX + MULQ (BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += 19×a3×b4 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r2 += 19×a4×b3 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(BX) + ADDQ AX, R11 + ADCQ DX, R10 + + // r3 = a0×b3 + MOVQ (CX), AX + MULQ 24(BX) + MOVQ AX, R13 + MOVQ DX, R12 + + // r3 += a1×b2 + MOVQ 8(CX), AX + MULQ 16(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += a2×b1 + MOVQ 16(CX), AX + MULQ 8(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += a3×b0 + MOVQ 24(CX), AX + MULQ (BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r3 += 19×a4×b4 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(BX) + ADDQ AX, R13 + ADCQ DX, R12 + + // r4 = a0×b4 + MOVQ (CX), AX + MULQ 32(BX) + MOVQ AX, R15 + MOVQ DX, R14 + + // r4 += a1×b3 + MOVQ 8(CX), AX + MULQ 24(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a2×b2 + MOVQ 16(CX), AX + MULQ 16(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a3×b1 + MOVQ 24(CX), AX + MULQ 8(BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // r4 += a4×b0 + MOVQ 32(CX), AX + MULQ (BX) + ADDQ AX, R15 + ADCQ DX, R14 + + // First reduction chain + MOVQ $0x0007ffffffffffff, AX + SHLQ $0x0d, DI, SI + SHLQ $0x0d, R9, R8 + SHLQ $0x0d, R11, R10 + SHLQ $0x0d, R13, R12 + SHLQ $0x0d, R15, R14 + ANDQ AX, DI + IMUL3Q $0x13, R14, R14 + ADDQ R14, DI + ANDQ AX, R9 + ADDQ SI, R9 + ANDQ AX, R11 + ADDQ R8, R11 + ANDQ AX, R13 + ADDQ R10, R13 + ANDQ AX, R15 + ADDQ R12, R15 + + // Second reduction chain (carryPropagate) + MOVQ DI, SI + SHRQ $0x33, SI + MOVQ R9, R8 + SHRQ $0x33, R8 + MOVQ R11, R10 + SHRQ $0x33, R10 + MOVQ R13, R12 + SHRQ $0x33, R12 + MOVQ R15, R14 + SHRQ $0x33, R14 + ANDQ AX, DI + IMUL3Q $0x13, R14, R14 + ADDQ R14, DI + ANDQ AX, R9 + ADDQ SI, R9 + ANDQ AX, R11 + ADDQ R8, R11 + ANDQ AX, R13 + ADDQ R10, R13 + ANDQ AX, R15 + ADDQ R12, R15 + + // Store output + MOVQ out+0(FP), AX + MOVQ DI, (AX) + MOVQ R9, 8(AX) + MOVQ R11, 16(AX) + MOVQ R13, 24(AX) + MOVQ R15, 32(AX) + RET + +// func feSquare(out *Element, a *Element) +TEXT ·feSquare(SB), NOSPLIT, $0-16 + MOVQ a+8(FP), CX + + // r0 = l0×l0 + MOVQ (CX), AX + MULQ (CX) + MOVQ AX, SI + MOVQ DX, BX + + // r0 += 38×l1×l4 + MOVQ 8(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, SI + ADCQ DX, BX + + // r0 += 38×l2×l3 + MOVQ 16(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 24(CX) + ADDQ AX, SI + ADCQ DX, BX + + // r1 = 2×l0×l1 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 8(CX) + MOVQ AX, R8 + MOVQ DX, DI + + // r1 += 38×l2×l4 + MOVQ 16(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, R8 + ADCQ DX, DI + + // r1 += 19×l3×l3 + MOVQ 24(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 24(CX) + ADDQ AX, R8 + ADCQ DX, DI + + // r2 = 2×l0×l2 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 16(CX) + MOVQ AX, R10 + MOVQ DX, R9 + + // r2 += l1×l1 + MOVQ 8(CX), AX + MULQ 8(CX) + ADDQ AX, R10 + ADCQ DX, R9 + + // r2 += 38×l3×l4 + MOVQ 24(CX), AX + IMUL3Q $0x26, AX, AX + MULQ 32(CX) + ADDQ AX, R10 + ADCQ DX, R9 + + // r3 = 2×l0×l3 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 24(CX) + MOVQ AX, R12 + MOVQ DX, R11 + + // r3 += 2×l1×l2 + MOVQ 8(CX), AX + IMUL3Q $0x02, AX, AX + MULQ 16(CX) + ADDQ AX, R12 + ADCQ DX, R11 + + // r3 += 19×l4×l4 + MOVQ 32(CX), AX + IMUL3Q $0x13, AX, AX + MULQ 32(CX) + ADDQ AX, R12 + ADCQ DX, R11 + + // r4 = 2×l0×l4 + MOVQ (CX), AX + SHLQ $0x01, AX + MULQ 32(CX) + MOVQ AX, R14 + MOVQ DX, R13 + + // r4 += 2×l1×l3 + MOVQ 8(CX), AX + IMUL3Q $0x02, AX, AX + MULQ 24(CX) + ADDQ AX, R14 + ADCQ DX, R13 + + // r4 += l2×l2 + MOVQ 16(CX), AX + MULQ 16(CX) + ADDQ AX, R14 + ADCQ DX, R13 + + // First reduction chain + MOVQ $0x0007ffffffffffff, AX + SHLQ $0x0d, SI, BX + SHLQ $0x0d, R8, DI + SHLQ $0x0d, R10, R9 + SHLQ $0x0d, R12, R11 + SHLQ $0x0d, R14, R13 + ANDQ AX, SI + IMUL3Q $0x13, R13, R13 + ADDQ R13, SI + ANDQ AX, R8 + ADDQ BX, R8 + ANDQ AX, R10 + ADDQ DI, R10 + ANDQ AX, R12 + ADDQ R9, R12 + ANDQ AX, R14 + ADDQ R11, R14 + + // Second reduction chain (carryPropagate) + MOVQ SI, BX + SHRQ $0x33, BX + MOVQ R8, DI + SHRQ $0x33, DI + MOVQ R10, R9 + SHRQ $0x33, R9 + MOVQ R12, R11 + SHRQ $0x33, R11 + MOVQ R14, R13 + SHRQ $0x33, R13 + ANDQ AX, SI + IMUL3Q $0x13, R13, R13 + ADDQ R13, SI + ANDQ AX, R8 + ADDQ BX, R8 + ANDQ AX, R10 + ADDQ DI, R10 + ANDQ AX, R12 + ADDQ R9, R12 + ANDQ AX, R14 + ADDQ R11, R14 + + // Store output + MOVQ out+0(FP), AX + MOVQ SI, (AX) + MOVQ R8, 8(AX) + MOVQ R10, 16(AX) + MOVQ R12, 24(AX) + MOVQ R14, 32(AX) + RET diff --git a/openpgp/internal/ecc/curve25519/field/fe_amd64_noasm.go b/openpgp/internal/ecc/curve25519/field/fe_amd64_noasm.go new file mode 100644 index 00000000..ddb6c9b8 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_amd64_noasm.go @@ -0,0 +1,12 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !amd64 || !gc || purego +// +build !amd64 !gc purego + +package field + +func feMul(v, x, y *Element) { feMulGeneric(v, x, y) } + +func feSquare(v, x *Element) { feSquareGeneric(v, x) } diff --git a/openpgp/internal/ecc/curve25519/field/fe_arm64.go b/openpgp/internal/ecc/curve25519/field/fe_arm64.go new file mode 100644 index 00000000..af459ef5 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_arm64.go @@ -0,0 +1,16 @@ +// Copyright (c) 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build arm64 && gc && !purego +// +build arm64,gc,!purego + +package field + +//go:noescape +func carryPropagate(v *Element) + +func (v *Element) carryPropagate() *Element { + carryPropagate(v) + return v +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_arm64.s b/openpgp/internal/ecc/curve25519/field/fe_arm64.s new file mode 100644 index 00000000..5c91e458 --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_arm64.s @@ -0,0 +1,43 @@ +// Copyright (c) 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build arm64 && gc && !purego +// +build arm64,gc,!purego + +#include "textflag.h" + +// carryPropagate works exactly like carryPropagateGeneric and uses the +// same AND, ADD, and LSR+MADD instructions emitted by the compiler, but +// avoids loading R0-R4 twice and uses LDP and STP. +// +// See https://golang.org/issues/43145 for the main compiler issue. +// +// func carryPropagate(v *Element) +TEXT ·carryPropagate(SB),NOFRAME|NOSPLIT,$0-8 + MOVD v+0(FP), R20 + + LDP 0(R20), (R0, R1) + LDP 16(R20), (R2, R3) + MOVD 32(R20), R4 + + AND $0x7ffffffffffff, R0, R10 + AND $0x7ffffffffffff, R1, R11 + AND $0x7ffffffffffff, R2, R12 + AND $0x7ffffffffffff, R3, R13 + AND $0x7ffffffffffff, R4, R14 + + ADD R0>>51, R11, R11 + ADD R1>>51, R12, R12 + ADD R2>>51, R13, R13 + ADD R3>>51, R14, R14 + // R4>>51 * 19 + R10 -> R10 + LSR $51, R4, R21 + MOVD $19, R22 + MADD R22, R10, R21, R10 + + STP (R10, R11), 0(R20) + STP (R12, R13), 16(R20) + MOVD R14, 32(R20) + + RET diff --git a/openpgp/internal/ecc/curve25519/field/fe_arm64_noasm.go b/openpgp/internal/ecc/curve25519/field/fe_arm64_noasm.go new file mode 100644 index 00000000..234a5b2e --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_arm64_noasm.go @@ -0,0 +1,12 @@ +// Copyright (c) 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !arm64 || !gc || purego +// +build !arm64 !gc purego + +package field + +func (v *Element) carryPropagate() *Element { + return v.carryPropagateGeneric() +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_bench_test.go b/openpgp/internal/ecc/curve25519/field/fe_bench_test.go new file mode 100644 index 00000000..77dc06cf --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_bench_test.go @@ -0,0 +1,36 @@ +// Copyright (c) 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import "testing" + +func BenchmarkAdd(b *testing.B) { + var x, y Element + x.One() + y.Add(feOne, feOne) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Add(&x, &y) + } +} + +func BenchmarkMultiply(b *testing.B) { + var x, y Element + x.One() + y.Add(feOne, feOne) + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Multiply(&x, &y) + } +} + +func BenchmarkMult32(b *testing.B) { + var x Element + x.One() + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Mult32(&x, 0xaa42aa42) + } +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_generic.go b/openpgp/internal/ecc/curve25519/field/fe_generic.go new file mode 100644 index 00000000..7b5b78cb --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_generic.go @@ -0,0 +1,264 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import "math/bits" + +// uint128 holds a 128-bit number as two 64-bit limbs, for use with the +// bits.Mul64 and bits.Add64 intrinsics. +type uint128 struct { + lo, hi uint64 +} + +// mul64 returns a * b. +func mul64(a, b uint64) uint128 { + hi, lo := bits.Mul64(a, b) + return uint128{lo, hi} +} + +// addMul64 returns v + a * b. +func addMul64(v uint128, a, b uint64) uint128 { + hi, lo := bits.Mul64(a, b) + lo, c := bits.Add64(lo, v.lo, 0) + hi, _ = bits.Add64(hi, v.hi, c) + return uint128{lo, hi} +} + +// shiftRightBy51 returns a >> 51. a is assumed to be at most 115 bits. +func shiftRightBy51(a uint128) uint64 { + return (a.hi << (64 - 51)) | (a.lo >> 51) +} + +func feMulGeneric(v, a, b *Element) { + a0 := a.l0 + a1 := a.l1 + a2 := a.l2 + a3 := a.l3 + a4 := a.l4 + + b0 := b.l0 + b1 := b.l1 + b2 := b.l2 + b3 := b.l3 + b4 := b.l4 + + // Limb multiplication works like pen-and-paper columnar multiplication, but + // with 51-bit limbs instead of digits. + // + // a4 a3 a2 a1 a0 x + // b4 b3 b2 b1 b0 = + // ------------------------ + // a4b0 a3b0 a2b0 a1b0 a0b0 + + // a4b1 a3b1 a2b1 a1b1 a0b1 + + // a4b2 a3b2 a2b2 a1b2 a0b2 + + // a4b3 a3b3 a2b3 a1b3 a0b3 + + // a4b4 a3b4 a2b4 a1b4 a0b4 = + // ---------------------------------------------- + // r8 r7 r6 r5 r4 r3 r2 r1 r0 + // + // We can then use the reduction identity (a * 2²⁵⁵ + b = a * 19 + b) to + // reduce the limbs that would overflow 255 bits. r5 * 2²⁵⁵ becomes 19 * r5, + // r6 * 2³⁰⁶ becomes 19 * r6 * 2⁵¹, etc. + // + // Reduction can be carried out simultaneously to multiplication. For + // example, we do not compute r5: whenever the result of a multiplication + // belongs to r5, like a1b4, we multiply it by 19 and add the result to r0. + // + // a4b0 a3b0 a2b0 a1b0 a0b0 + + // a3b1 a2b1 a1b1 a0b1 19×a4b1 + + // a2b2 a1b2 a0b2 19×a4b2 19×a3b2 + + // a1b3 a0b3 19×a4b3 19×a3b3 19×a2b3 + + // a0b4 19×a4b4 19×a3b4 19×a2b4 19×a1b4 = + // -------------------------------------- + // r4 r3 r2 r1 r0 + // + // Finally we add up the columns into wide, overlapping limbs. + + a1_19 := a1 * 19 + a2_19 := a2 * 19 + a3_19 := a3 * 19 + a4_19 := a4 * 19 + + // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1) + r0 := mul64(a0, b0) + r0 = addMul64(r0, a1_19, b4) + r0 = addMul64(r0, a2_19, b3) + r0 = addMul64(r0, a3_19, b2) + r0 = addMul64(r0, a4_19, b1) + + // r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2) + r1 := mul64(a0, b1) + r1 = addMul64(r1, a1, b0) + r1 = addMul64(r1, a2_19, b4) + r1 = addMul64(r1, a3_19, b3) + r1 = addMul64(r1, a4_19, b2) + + // r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3) + r2 := mul64(a0, b2) + r2 = addMul64(r2, a1, b1) + r2 = addMul64(r2, a2, b0) + r2 = addMul64(r2, a3_19, b4) + r2 = addMul64(r2, a4_19, b3) + + // r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4 + r3 := mul64(a0, b3) + r3 = addMul64(r3, a1, b2) + r3 = addMul64(r3, a2, b1) + r3 = addMul64(r3, a3, b0) + r3 = addMul64(r3, a4_19, b4) + + // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0 + r4 := mul64(a0, b4) + r4 = addMul64(r4, a1, b3) + r4 = addMul64(r4, a2, b2) + r4 = addMul64(r4, a3, b1) + r4 = addMul64(r4, a4, b0) + + // After the multiplication, we need to reduce (carry) the five coefficients + // to obtain a result with limbs that are at most slightly larger than 2⁵¹, + // to respect the Element invariant. + // + // Overall, the reduction works the same as carryPropagate, except with + // wider inputs: we take the carry for each coefficient by shifting it right + // by 51, and add it to the limb above it. The top carry is multiplied by 19 + // according to the reduction identity and added to the lowest limb. + // + // The largest coefficient (r0) will be at most 111 bits, which guarantees + // that all carries are at most 111 - 51 = 60 bits, which fits in a uint64. + // + // r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1) + // r0 < 2⁵²×2⁵² + 19×(2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵² + 2⁵²×2⁵²) + // r0 < (1 + 19 × 4) × 2⁵² × 2⁵² + // r0 < 2⁷ × 2⁵² × 2⁵² + // r0 < 2¹¹¹ + // + // Moreover, the top coefficient (r4) is at most 107 bits, so c4 is at most + // 56 bits, and c4 * 19 is at most 61 bits, which again fits in a uint64 and + // allows us to easily apply the reduction identity. + // + // r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0 + // r4 < 5 × 2⁵² × 2⁵² + // r4 < 2¹⁰⁷ + // + + c0 := shiftRightBy51(r0) + c1 := shiftRightBy51(r1) + c2 := shiftRightBy51(r2) + c3 := shiftRightBy51(r3) + c4 := shiftRightBy51(r4) + + rr0 := r0.lo&maskLow51Bits + c4*19 + rr1 := r1.lo&maskLow51Bits + c0 + rr2 := r2.lo&maskLow51Bits + c1 + rr3 := r3.lo&maskLow51Bits + c2 + rr4 := r4.lo&maskLow51Bits + c3 + + // Now all coefficients fit into 64-bit registers but are still too large to + // be passed around as a Element. We therefore do one last carry chain, + // where the carries will be small enough to fit in the wiggle room above 2⁵¹. + *v = Element{rr0, rr1, rr2, rr3, rr4} + v.carryPropagate() +} + +func feSquareGeneric(v, a *Element) { + l0 := a.l0 + l1 := a.l1 + l2 := a.l2 + l3 := a.l3 + l4 := a.l4 + + // Squaring works precisely like multiplication above, but thanks to its + // symmetry we get to group a few terms together. + // + // l4 l3 l2 l1 l0 x + // l4 l3 l2 l1 l0 = + // ------------------------ + // l4l0 l3l0 l2l0 l1l0 l0l0 + + // l4l1 l3l1 l2l1 l1l1 l0l1 + + // l4l2 l3l2 l2l2 l1l2 l0l2 + + // l4l3 l3l3 l2l3 l1l3 l0l3 + + // l4l4 l3l4 l2l4 l1l4 l0l4 = + // ---------------------------------------------- + // r8 r7 r6 r5 r4 r3 r2 r1 r0 + // + // l4l0 l3l0 l2l0 l1l0 l0l0 + + // l3l1 l2l1 l1l1 l0l1 19×l4l1 + + // l2l2 l1l2 l0l2 19×l4l2 19×l3l2 + + // l1l3 l0l3 19×l4l3 19×l3l3 19×l2l3 + + // l0l4 19×l4l4 19×l3l4 19×l2l4 19×l1l4 = + // -------------------------------------- + // r4 r3 r2 r1 r0 + // + // With precomputed 2×, 19×, and 2×19× terms, we can compute each limb with + // only three Mul64 and four Add64, instead of five and eight. + + l0_2 := l0 * 2 + l1_2 := l1 * 2 + + l1_38 := l1 * 38 + l2_38 := l2 * 38 + l3_38 := l3 * 38 + + l3_19 := l3 * 19 + l4_19 := l4 * 19 + + // r0 = l0×l0 + 19×(l1×l4 + l2×l3 + l3×l2 + l4×l1) = l0×l0 + 19×2×(l1×l4 + l2×l3) + r0 := mul64(l0, l0) + r0 = addMul64(r0, l1_38, l4) + r0 = addMul64(r0, l2_38, l3) + + // r1 = l0×l1 + l1×l0 + 19×(l2×l4 + l3×l3 + l4×l2) = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3 + r1 := mul64(l0_2, l1) + r1 = addMul64(r1, l2_38, l4) + r1 = addMul64(r1, l3_19, l3) + + // r2 = l0×l2 + l1×l1 + l2×l0 + 19×(l3×l4 + l4×l3) = 2×l0×l2 + l1×l1 + 19×2×l3×l4 + r2 := mul64(l0_2, l2) + r2 = addMul64(r2, l1, l1) + r2 = addMul64(r2, l3_38, l4) + + // r3 = l0×l3 + l1×l2 + l2×l1 + l3×l0 + 19×l4×l4 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4 + r3 := mul64(l0_2, l3) + r3 = addMul64(r3, l1_2, l2) + r3 = addMul64(r3, l4_19, l4) + + // r4 = l0×l4 + l1×l3 + l2×l2 + l3×l1 + l4×l0 = 2×l0×l4 + 2×l1×l3 + l2×l2 + r4 := mul64(l0_2, l4) + r4 = addMul64(r4, l1_2, l3) + r4 = addMul64(r4, l2, l2) + + c0 := shiftRightBy51(r0) + c1 := shiftRightBy51(r1) + c2 := shiftRightBy51(r2) + c3 := shiftRightBy51(r3) + c4 := shiftRightBy51(r4) + + rr0 := r0.lo&maskLow51Bits + c4*19 + rr1 := r1.lo&maskLow51Bits + c0 + rr2 := r2.lo&maskLow51Bits + c1 + rr3 := r3.lo&maskLow51Bits + c2 + rr4 := r4.lo&maskLow51Bits + c3 + + *v = Element{rr0, rr1, rr2, rr3, rr4} + v.carryPropagate() +} + +// carryPropagate brings the limbs below 52 bits by applying the reduction +// identity (a * 2²⁵⁵ + b = a * 19 + b) to the l4 carry. TODO inline +func (v *Element) carryPropagateGeneric() *Element { + c0 := v.l0 >> 51 + c1 := v.l1 >> 51 + c2 := v.l2 >> 51 + c3 := v.l3 >> 51 + c4 := v.l4 >> 51 + + v.l0 = v.l0&maskLow51Bits + c4*19 + v.l1 = v.l1&maskLow51Bits + c0 + v.l2 = v.l2&maskLow51Bits + c1 + v.l3 = v.l3&maskLow51Bits + c2 + v.l4 = v.l4&maskLow51Bits + c3 + + return v +} diff --git a/openpgp/internal/ecc/curve25519/field/fe_test.go b/openpgp/internal/ecc/curve25519/field/fe_test.go new file mode 100644 index 00000000..b484459f --- /dev/null +++ b/openpgp/internal/ecc/curve25519/field/fe_test.go @@ -0,0 +1,558 @@ +// Copyright (c) 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package field + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "io" + "math/big" + "math/bits" + mathrand "math/rand" + "reflect" + "testing" + "testing/quick" +) + +func (v Element) String() string { + return hex.EncodeToString(v.Bytes()) +} + +// quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks) +// times. The default value of -quickchecks is 100. +var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10} + +func generateFieldElement(rand *mathrand.Rand) Element { + const maskLow52Bits = (1 << 52) - 1 + return Element{ + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + rand.Uint64() & maskLow52Bits, + } +} + +// weirdLimbs can be combined to generate a range of edge-case field elements. +// 0 and -1 are intentionally more weighted, as they combine well. +var ( + weirdLimbs51 = []uint64{ + 0, 0, 0, 0, + 1, + 19 - 1, + 19, + 0x2aaaaaaaaaaaa, + 0x5555555555555, + (1 << 51) - 20, + (1 << 51) - 19, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + } + weirdLimbs52 = []uint64{ + 0, 0, 0, 0, 0, 0, + 1, + 19 - 1, + 19, + 0x2aaaaaaaaaaaa, + 0x5555555555555, + (1 << 51) - 20, + (1 << 51) - 19, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + (1 << 51) - 1, (1 << 51) - 1, + 1 << 51, + (1 << 51) + 1, + (1 << 52) - 19, + (1 << 52) - 1, + } +) + +func generateWeirdFieldElement(rand *mathrand.Rand) Element { + return Element{ + weirdLimbs52[rand.Intn(len(weirdLimbs52))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + weirdLimbs51[rand.Intn(len(weirdLimbs51))], + } +} + +func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value { + if rand.Intn(2) == 0 { + return reflect.ValueOf(generateWeirdFieldElement(rand)) + } + return reflect.ValueOf(generateFieldElement(rand)) +} + +// isInBounds returns whether the element is within the expected bit size bounds +// after a light reduction. +func isInBounds(x *Element) bool { + return bits.Len64(x.l0) <= 52 && + bits.Len64(x.l1) <= 52 && + bits.Len64(x.l2) <= 52 && + bits.Len64(x.l3) <= 52 && + bits.Len64(x.l4) <= 52 +} + +func TestMultiplyDistributesOverAdd(t *testing.T) { + multiplyDistributesOverAdd := func(x, y, z Element) bool { + // Compute t1 = (x+y)*z + t1 := new(Element) + t1.Add(&x, &y) + t1.Multiply(t1, &z) + + // Compute t2 = x*z + y*z + t2 := new(Element) + t3 := new(Element) + t2.Multiply(&x, &z) + t3.Multiply(&y, &z) + t2.Add(t2, t3) + + return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) + } + + if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestMul64to128(t *testing.T) { + a := uint64(5) + b := uint64(5) + r := mul64(a, b) + if r.lo != 0x19 || r.hi != 0 { + t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) + } + + a = uint64(18014398509481983) // 2^54 - 1 + b = uint64(18014398509481983) // 2^54 - 1 + r = mul64(a, b) + if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff { + t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) + } + + a = uint64(1125899906842661) + b = uint64(2097155) + r = mul64(a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + r = addMul64(r, a, b) + if r.lo != 16888498990613035 || r.hi != 640 { + t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi) + } +} + +func TestSetBytesRoundTrip(t *testing.T) { + f1 := func(in [32]byte, fe Element) bool { + fe.SetBytes(in[:]) + + // Mask the most significant bit as it's ignored by SetBytes. (Now + // instead of earlier so we check the masking in SetBytes is working.) + in[len(in)-1] &= (1 << 7) - 1 + + return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe) + } + if err := quick.Check(f1, nil); err != nil { + t.Errorf("failed bytes->FE->bytes round-trip: %v", err) + } + + f2 := func(fe, r Element) bool { + r.SetBytes(fe.Bytes()) + + // Intentionally not using Equal not to go through Bytes again. + // Calling reduce because both Generate and SetBytes can produce + // non-canonical representations. + fe.reduce() + r.reduce() + return fe == r + } + if err := quick.Check(f2, nil); err != nil { + t.Errorf("failed FE->bytes->FE round-trip: %v", err) + } + + // Check some fixed vectors from dalek + type feRTTest struct { + fe Element + b []byte + } + var tests = []feRTTest{ + { + fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}, + b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31}, + }, + { + fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}, + b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122}, + }, + } + + for _, tt := range tests { + b := tt.fe.Bytes() + if !bytes.Equal(b, tt.b) || new(Element).SetBytes(tt.b).Equal(&tt.fe) != 1 { + t.Errorf("Failed fixed roundtrip: %v", tt) + } + } +} + +func swapEndianness(buf []byte) []byte { + for i := 0; i < len(buf)/2; i++ { + buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i] + } + return buf +} + +func TestBytesBigEquivalence(t *testing.T) { + f1 := func(in [32]byte, fe, fe1 Element) bool { + fe.SetBytes(in[:]) + + in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit + b := new(big.Int).SetBytes(swapEndianness(in[:])) + fe1.fromBig(b) + + if fe != fe1 { + return false + } + + buf := make([]byte, 32) // pad with zeroes + copy(buf, swapEndianness(fe1.toBig().Bytes())) + + return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1) + } + if err := quick.Check(f1, nil); err != nil { + t.Error(err) + } +} + +// fromBig sets v = n, and returns v. The bit length of n must not exceed 256. +func (v *Element) fromBig(n *big.Int) *Element { + if n.BitLen() > 32*8 { + panic("edwards25519: invalid field element input size") + } + + buf := make([]byte, 0, 32) + for _, word := range n.Bits() { + for i := 0; i < bits.UintSize; i += 8 { + if len(buf) >= cap(buf) { + break + } + buf = append(buf, byte(word)) + word >>= 8 + } + } + + return v.SetBytes(buf[:32]) +} + +func (v *Element) fromDecimal(s string) *Element { + n, ok := new(big.Int).SetString(s, 10) + if !ok { + panic("not a valid decimal: " + s) + } + return v.fromBig(n) +} + +// toBig returns v as a big.Int. +func (v *Element) toBig() *big.Int { + buf := v.Bytes() + + words := make([]big.Word, 32*8/bits.UintSize) + for n := range words { + for i := 0; i < bits.UintSize; i += 8 { + if len(buf) == 0 { + break + } + words[n] |= big.Word(buf[0]) << big.Word(i) + buf = buf[1:] + } + } + + return new(big.Int).SetBits(words) +} + +func TestDecimalConstants(t *testing.T) { + sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752" + if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 { + t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp) + } + // d is in the parent package, and we don't want to expose d or fromDecimal. + // dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555" + // if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 { + // t.Errorf("d is %v, expected %v", d, exp) + // } +} + +func TestSetBytesRoundTripEdgeCases(t *testing.T) { + // TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1, + // and between 2^255 and 2^256-1. Test both the documented SetBytes + // behavior, and that Bytes reduces them. +} + +// Tests self-consistency between Multiply and Square. +func TestConsistency(t *testing.T) { + var x Element + var x2, x2sq Element + + x = Element{1, 1, 1, 1, 1} + x2.Multiply(&x, &x) + x2sq.Square(&x) + + if x2 != x2sq { + t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) + } + + var bytes [32]byte + + _, err := io.ReadFull(rand.Reader, bytes[:]) + if err != nil { + t.Fatal(err) + } + x.SetBytes(bytes[:]) + + x2.Multiply(&x, &x) + x2sq.Square(&x) + + if x2 != x2sq { + t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) + } +} + +func TestEqual(t *testing.T) { + x := Element{1, 1, 1, 1, 1} + y := Element{5, 4, 3, 2, 1} + + eq := x.Equal(&x) + if eq != 1 { + t.Errorf("wrong about equality") + } + + eq = x.Equal(&y) + if eq != 0 { + t.Errorf("wrong about inequality") + } +} + +func TestInvert(t *testing.T) { + x := Element{1, 1, 1, 1, 1} + one := Element{1, 0, 0, 0, 0} + var xinv, r Element + + xinv.Invert(&x) + r.Multiply(&x, &xinv) + r.reduce() + + if one != r { + t.Errorf("inversion identity failed, got: %x", r) + } + + var bytes [32]byte + + _, err := io.ReadFull(rand.Reader, bytes[:]) + if err != nil { + t.Fatal(err) + } + x.SetBytes(bytes[:]) + + xinv.Invert(&x) + r.Multiply(&x, &xinv) + r.reduce() + + if one != r { + t.Errorf("random inversion identity failed, got: %x for field element %x", r, x) + } + + zero := Element{} + x.Set(&zero) + if xx := xinv.Invert(&x); xx != &xinv { + t.Errorf("inverting zero did not return the receiver") + } else if xinv.Equal(&zero) != 1 { + t.Errorf("inverting zero did not return zero") + } +} + +func TestSelectSwap(t *testing.T) { + a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676} + b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972} + + var c, d Element + + c.Select(&a, &b, 1) + d.Select(&a, &b, 0) + + if c.Equal(&a) != 1 || d.Equal(&b) != 1 { + t.Errorf("Select failed") + } + + c.Swap(&d, 0) + + if c.Equal(&a) != 1 || d.Equal(&b) != 1 { + t.Errorf("Swap failed") + } + + c.Swap(&d, 1) + + if c.Equal(&b) != 1 || d.Equal(&a) != 1 { + t.Errorf("Swap failed") + } +} + +func TestMult32(t *testing.T) { + mult32EquivalentToMul := func(x Element, y uint32) bool { + t1 := new(Element) + for i := 0; i < 100; i++ { + t1.Mult32(&x, y) + } + + ty := new(Element) + ty.l0 = uint64(y) + + t2 := new(Element) + for i := 0; i < 100; i++ { + t2.Multiply(&x, ty) + } + + return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) + } + + if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestSqrtRatio(t *testing.T) { + // From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4. + type test struct { + u, v string + wasSquare int + r string + } + var tests = []test{ + // If u is 0, the function is defined to return (0, TRUE), even if v + // is zero. Note that where used in this package, the denominator v + // is never zero. + { + "0000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + 1, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // 0/1 == 0² + { + "0000000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 1, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // If u is non-zero and v is zero, defined to return (0, FALSE). + { + "0100000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + 0, "0000000000000000000000000000000000000000000000000000000000000000", + }, + // 2/1 is not square in this field. + { + "0200000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54", + }, + // 4/1 == 2² + { + "0400000000000000000000000000000000000000000000000000000000000000", + "0100000000000000000000000000000000000000000000000000000000000000", + 1, "0200000000000000000000000000000000000000000000000000000000000000", + }, + // 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem + { + "0100000000000000000000000000000000000000000000000000000000000000", + "0400000000000000000000000000000000000000000000000000000000000000", + 1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f", + }, + } + + for i, tt := range tests { + u := new(Element).SetBytes(decodeHex(tt.u)) + v := new(Element).SetBytes(decodeHex(tt.v)) + want := new(Element).SetBytes(decodeHex(tt.r)) + got, wasSquare := new(Element).SqrtRatio(u, v) + if got.Equal(want) == 0 || wasSquare != tt.wasSquare { + t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare) + } + } +} + +func TestCarryPropagate(t *testing.T) { + asmLikeGeneric := func(a [5]uint64) bool { + t1 := &Element{a[0], a[1], a[2], a[3], a[4]} + t2 := &Element{a[0], a[1], a[2], a[3], a[4]} + + t1.carryPropagate() + t2.carryPropagateGeneric() + + if *t1 != *t2 { + t.Logf("got: %#v,\nexpected: %#v", t1, t2) + } + + return *t1 == *t2 && isInBounds(t2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } + + if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) { + t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}") + } +} + +func TestFeSquare(t *testing.T) { + asmLikeGeneric := func(a Element) bool { + t1 := a + t2 := a + + feSquareGeneric(&t1, &t1) + feSquare(&t2, &t2) + + if t1 != t2 { + t.Logf("got: %#v,\nexpected: %#v", t1, t2) + } + + return t1 == t2 && isInBounds(&t2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func TestFeMul(t *testing.T) { + asmLikeGeneric := func(a, b Element) bool { + a1 := a + a2 := a + b1 := b + b2 := b + + feMulGeneric(&a1, &a1, &b1) + feMul(&a2, &a2, &b2) + + if a1 != a2 || b1 != b2 { + t.Logf("got: %#v,\nexpected: %#v", a1, a2) + t.Logf("got: %#v,\nexpected: %#v", b1, b2) + } + + return a1 == a2 && isInBounds(&a2) && + b1 == b2 && isInBounds(&b2) + } + + if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { + t.Error(err) + } +} + +func decodeHex(s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return b +} diff --git a/openpgp/internal/encoding/short_byte_string.go b/openpgp/internal/encoding/short_byte_string.go new file mode 100644 index 00000000..0c3b9123 --- /dev/null +++ b/openpgp/internal/encoding/short_byte_string.go @@ -0,0 +1,50 @@ +package encoding + +import ( + "io" +) + +type ShortByteString struct { + length uint8 + data []byte +} + +func NewShortByteString(data []byte) *ShortByteString { + byteLength := uint8(len(data)) + + return &ShortByteString{byteLength, data} +} + +func (byteString *ShortByteString) Bytes() []byte { + return byteString.data +} + +func (byteString *ShortByteString) BitLength() uint16 { + return uint16(byteString.length) * 8 +} + +func (byteString *ShortByteString) EncodedBytes() []byte { + encodedLength := [1]byte{ + uint8(byteString.length), + } + return append(encodedLength[:], byteString.data...) +} + +func (byteString *ShortByteString) EncodedLength() uint16 { + return uint16(byteString.length) + 1 +} + +func (byteString *ShortByteString) ReadFrom(r io.Reader) (int64, error) { + var lengthBytes [1]byte + if n, err := io.ReadFull(r, lengthBytes[:]); err != nil { + return int64(n), err + } + + byteString.length = uint8(lengthBytes[0]) + + byteString.data = make([]byte, byteString.length) + if n, err := io.ReadFull(r, byteString.data); err != nil { + return int64(n + 1), err + } + return int64(byteString.length + 1), nil +} diff --git a/openpgp/internal/encoding/short_byte_string_test.go b/openpgp/internal/encoding/short_byte_string_test.go new file mode 100644 index 00000000..6544b4ec --- /dev/null +++ b/openpgp/internal/encoding/short_byte_string_test.go @@ -0,0 +1,61 @@ +package encoding + +import ( + "testing" + "bytes" +) + +var octetStreamTests = []struct { + data []byte +} { + { + data: []byte{0x0, 0x0, 0x0}, + }, + { + data: []byte {0x1, 0x2, 0x03}, + }, + { + data: make([]byte, 255), + }, +} + +func TestShortByteString(t *testing.T) { + for i, test := range octetStreamTests { + octetStream := NewShortByteString(test.data) + + if b := octetStream.Bytes(); !bytes.Equal(b, test.data) { + t.Errorf("#%d: bad creation got:%x want:%x", i, b, test.data) + } + + expectedBitLength := uint16(len(test.data)) * 8 + if bitLength := octetStream.BitLength(); bitLength != expectedBitLength { + t.Errorf("#%d: bad bit length got:%d want :%d", i, bitLength, expectedBitLength) + } + + expectedEncodedLength := uint16(len(test.data)) + 1 + if encodedLength := octetStream.EncodedLength(); encodedLength != expectedEncodedLength { + t.Errorf("#%d: bad encoded length got:%d want:%d", i, encodedLength, expectedEncodedLength) + } + + encodedBytes := octetStream.EncodedBytes() + if !bytes.Equal(encodedBytes[1:], test.data) { + t.Errorf("#%d: bad encoded bytes got:%x want:%x", i, encodedBytes[1:], test.data) + } + + encodedLength := int(encodedBytes[0]) + if encodedLength != len(test.data) { + t.Errorf("#%d: bad encoded length got:%d want%d", i, encodedLength, len(test.data)) + } + + newStream := new(ShortByteString) + newStream.ReadFrom(bytes.NewReader(encodedBytes)) + + if !checkEquality(newStream, octetStream) { + t.Errorf("#%d: bad parsing of encoded octet stream", i) + } + } +} + +func checkEquality (left *ShortByteString, right *ShortByteString) bool { + return (left.length == right.length) && (bytes.Equal(left.data, right.data)) +} diff --git a/openpgp/key_generation.go b/openpgp/key_generation.go index c9502c25..f5687495 100644 --- a/openpgp/key_generation.go +++ b/openpgp/key_generation.go @@ -22,6 +22,7 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc" "github.com/ProtonMail/go-crypto/openpgp/packet" + "github.com/ProtonMail/go-crypto/openpgp/symmetric" "github.com/ProtonMail/go-crypto/openpgp/x25519" "github.com/ProtonMail/go-crypto/openpgp/x448" ) @@ -314,6 +315,9 @@ func newSigner(config *packet.Config) (signer interface{}, err error) { return nil, err } return priv, nil + case packet.ExperimentalPubKeyAlgoHMAC: + hash := algorithm.HashById[hashToHashId(config.Hash())] + return symmetric.HMACGenerateKey(config.Random(), hash) default: return nil, errors.InvalidArgumentError("unsupported public key algorithm") } @@ -356,6 +360,9 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) { return x25519.GenerateKey(config.Random()) case packet.PubKeyAlgoEd448, packet.PubKeyAlgoX448: // When passing Ed448, we generate an x448 subkey return x448.GenerateKey(config.Random()) + case packet.ExperimentalPubKeyAlgoAEAD: + cipher := algorithm.CipherFunction(config.Cipher()) + return symmetric.AEADGenerateKey(config.Random(), cipher) default: return nil, errors.InvalidArgumentError("unsupported public key algorithm") } diff --git a/openpgp/keys.go b/openpgp/keys.go index a071353e..bbcc95d9 100644 --- a/openpgp/keys.go +++ b/openpgp/keys.go @@ -371,7 +371,7 @@ func (el EntityList) KeysByIdUsage(id uint64, requiredUsage byte) (keys []Key) { func (el EntityList) DecryptionKeys() (keys []Key) { for _, e := range el { for _, subKey := range e.Subkeys { - if subKey.PrivateKey != nil && subKey.Sig.FlagsValid && (subKey.Sig.FlagEncryptStorage || subKey.Sig.FlagEncryptCommunications) { + if subKey.PrivateKey != nil && subKey.Sig.FlagsValid && (subKey.Sig.FlagEncryptStorage || subKey.Sig.FlagEncryptCommunications || subKey.Sig.FlagForward) { keys = append(keys, Key{e, subKey.PublicKey, subKey.PrivateKey, subKey.Sig, subKey.Revocations}) } } @@ -761,6 +761,10 @@ func (e *Entity) serializePrivate(w io.Writer, config *packet.Config, reSign boo // Serialize writes the public part of the given Entity to w, including // signatures from other entities. No private key material will be output. func (e *Entity) Serialize(w io.Writer) error { + if e.PrimaryKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoHMAC || + e.PrimaryKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD { + return errors.InvalidArgumentError("Can't serialize symmetric primary key") + } err := e.PrimaryKey.Serialize(w) if err != nil { return err @@ -790,6 +794,16 @@ func (e *Entity) Serialize(w io.Writer) error { } } for _, subkey := range e.Subkeys { + // The types of keys below are only useful as private keys. Thus, the + // public key packets contain no meaningful information and do not need + // to be serialized. + // Prevent public key export for forwarding keys, see forwarding section 4.1. + if subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoHMAC || + subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD || + subkey.Sig.FlagForward { + continue + } + err = subkey.PublicKey.Serialize(w) if err != nil { return err diff --git a/openpgp/keys_test.go b/openpgp/keys_test.go index 3cb4ac00..184325a6 100644 --- a/openpgp/keys_test.go +++ b/openpgp/keys_test.go @@ -22,6 +22,7 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/packet" "github.com/ProtonMail/go-crypto/openpgp/s2k" + "github.com/ProtonMail/go-crypto/openpgp/symmetric" ) var hashes = []crypto.Hash{ @@ -1169,6 +1170,191 @@ func TestAddSubkeySerialized(t *testing.T) { } } +func TestAddHMACSubkey(t *testing.T) { + c := &packet.Config{ + RSABits: 512, + Algorithm: packet.ExperimentalPubKeyAlgoHMAC, + } + + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddSigningSubkey(c) + if err != nil { + t.Fatal(err) + } + + buf := bytes.NewBuffer(nil) + w, _ := armor.Encode(buf, "PGP PRIVATE KEY BLOCK", nil) + if err := entity.SerializePrivate(w, nil); err != nil { + t.Errorf("failed to serialize entity: %s", err) + } + w.Close() + + key, err := ReadArmoredKeyRing(buf) + if err != nil { + t.Error("could not read keyring", err) + } + + generatedPrivateKey := entity.Subkeys[1].PrivateKey.PrivateKey.(*symmetric.HMACPrivateKey) + parsedPrivateKey := key[0].Subkeys[1].PrivateKey.PrivateKey.(*symmetric.HMACPrivateKey) + + generatedPublicKey := entity.Subkeys[1].PublicKey.PublicKey.(*symmetric.HMACPublicKey) + parsedPublicKey := key[0].Subkeys[1].PublicKey.PublicKey.(*symmetric.HMACPublicKey) + + if !bytes.Equal(parsedPrivateKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key") + } + if !bytes.Equal(parsedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key in public part") + } + if !bytes.Equal(generatedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("generated Public and Private Key differ") + } + + if !bytes.Equal(parsedPrivateKey.HashSeed[:], generatedPrivateKey.HashSeed[:]) { + t.Error("parsed wrong hash seed") + } + + if parsedPrivateKey.PublicKey.Hash != generatedPrivateKey.PublicKey.Hash { + t.Error("parsed wrong cipher id") + } + if !bytes.Equal(parsedPrivateKey.PublicKey.BindingHash[:], generatedPrivateKey.PublicKey.BindingHash[:]) { + t.Error("parsed wrong binding hash") + } +} + +func TestSerializeSymmetricSubkeyError(t *testing.T) { + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + buf := bytes.NewBuffer(nil) + w, _ := armor.Encode(buf, "PGP PRIVATE KEY BLOCK", nil) + + entity.PrimaryKey.PubKeyAlgo = 100 + err = entity.Serialize(w) + if err == nil { + t.Fatal(err) + } + + entity.PrimaryKey.PubKeyAlgo = 101 + err = entity.Serialize(w) + if err == nil { + t.Fatal(err) + } +} + +func TestAddAEADSubkey(t *testing.T) { + c := &packet.Config{ + RSABits: 512, + Algorithm: packet.ExperimentalPubKeyAlgoAEAD, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddEncryptionSubkey(c) + if err != nil { + t.Fatal(err) + } + + generatedPrivateKey := entity.Subkeys[1].PrivateKey.PrivateKey.(*symmetric.AEADPrivateKey) + + buf := bytes.NewBuffer(nil) + w, _ := armor.Encode(buf, "PGP PRIVATE KEY BLOCK", nil) + if err := entity.SerializePrivate(w, nil); err != nil { + t.Errorf("failed to serialize entity: %s", err) + } + w.Close() + + key, err := ReadArmoredKeyRing(buf) + if err != nil { + t.Error("could not read keyring", err) + } + + parsedPrivateKey := key[0].Subkeys[1].PrivateKey.PrivateKey.(*symmetric.AEADPrivateKey) + + generatedPublicKey := entity.Subkeys[1].PublicKey.PublicKey.(*symmetric.AEADPublicKey) + parsedPublicKey := key[0].Subkeys[1].PublicKey.PublicKey.(*symmetric.AEADPublicKey) + + if !bytes.Equal(parsedPrivateKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key") + } + if !bytes.Equal(parsedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key in public part") + } + if !bytes.Equal(generatedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("generated Public and Private Key differ") + } + + if !bytes.Equal(parsedPrivateKey.HashSeed[:], generatedPrivateKey.HashSeed[:]) { + t.Error("parsed wrong hash seed") + } + + if parsedPrivateKey.PublicKey.Cipher.Id() != generatedPrivateKey.PublicKey.Cipher.Id() { + t.Error("parsed wrong cipher id") + } + if !bytes.Equal(parsedPrivateKey.PublicKey.BindingHash[:], generatedPrivateKey.PublicKey.BindingHash[:]) { + t.Error("parsed wrong binding hash") + } +} + +func TestNoSymmetricKeySerialized(t *testing.T) { + aeadConfig := &packet.Config{ + RSABits: 512, + DefaultHash: crypto.SHA512, + Algorithm: packet.ExperimentalPubKeyAlgoAEAD, + DefaultCipher: packet.CipherAES256, + } + hmacConfig := &packet.Config{ + RSABits: 512, + DefaultHash: crypto.SHA512, + Algorithm: packet.ExperimentalPubKeyAlgoHMAC, + DefaultCipher: packet.CipherAES256, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddEncryptionSubkey(aeadConfig) + if err != nil { + t.Fatal(err) + } + err = entity.AddSigningSubkey(hmacConfig) + if err != nil { + t.Fatal(err) + } + + w := bytes.NewBuffer(nil) + entity.Serialize(w) + + firstSymKey := entity.Subkeys[1].PrivateKey.PrivateKey.(*symmetric.AEADPrivateKey).Key + i := bytes.Index(w.Bytes(), firstSymKey) + + secondSymKey := entity.Subkeys[2].PrivateKey.PrivateKey.(*symmetric.HMACPrivateKey).Key + k := bytes.Index(w.Bytes(), secondSymKey) + + if (i > 0) || (k > 0) { + t.Error("Private key was serialized with public") + } + + firstBindingHash := entity.Subkeys[1].PublicKey.PublicKey.(*symmetric.AEADPublicKey).BindingHash + i = bytes.Index(w.Bytes(), firstBindingHash[:]) + + secondBindingHash := entity.Subkeys[2].PublicKey.PublicKey.(*symmetric.HMACPublicKey).BindingHash + k = bytes.Index(w.Bytes(), secondBindingHash[:]) + if (i > 0) || (k > 0) { + t.Errorf("Symmetric public key metadata exported %d %d", i, k) + } + +} + func TestAddSubkeyWithConfig(t *testing.T) { c := &packet.Config{ DefaultHash: crypto.SHA512, @@ -1865,3 +2051,37 @@ mQ00BF00000BCAD0000000000000000000000000000000000000000000000000 000000000000000000000000000000000000ABE000G0Dn000000000000000000iQ00BB0BAgAGBCG00000` ReadArmoredKeyRing(strings.NewReader(data)) } + +func TestSymmetricKeys(t *testing.T) { + data := `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xWoEYs7w5mUIcFvlmkuricX26x138uvHGlwIaxWIbRnx1+ggPcveTcwA4zSZ +n6XcD0Q5aLe6dTEBwCyfUecZ/nA0W8Pl9xBHfjIjQuxcUBnIqxZ061RZPjef +D/XIQga1ftLDelhylQwL7R3TzQ1TeW1tZXRyaWMgS2V5wmkEEGUIAB0FAmLO +8OYECwkHCAMVCAoEFgACAQIZAQIbAwIeAQAhCRCRTKq2ObiQKxYhBMHTTXXF +ULQ2M2bYNJFMqrY5uJArIawgJ+5RSsN8VNuZTKJbG88TIedU05wwKjW3wqvT +X6Z7yfbHagRizvDmZAluL/kJo6hZ1kFENpQkWD/Kfv1vAG3nbxhsVEzBQ6a1 +OAD24BaKJz6gWgj4lASUNK5OuXnLc3J79Bt1iRGkSbiPzRs/bplB4TwbILeC +ZLeDy9kngZDosgsIk5sBgGEqS9y5HiHCVQQYZQgACQUCYs7w5gIbDAAhCRCR +TKq2ObiQKxYhBMHTTXXFULQ2M2bYNJFMqrY5uJArENkgL0Bc+OI/1na0XWqB +TxGVotQ4A/0u0VbOMEUfnrI8Fms= +=RdCW +-----END PGP PRIVATE KEY BLOCK----- +` + keys, err := ReadArmoredKeyRing(strings.NewReader(data)) + if err != nil { + t.Fatal(err) + } + if len(keys) != 1 { + t.Errorf("Expected 1 symmetric key, got %d", len(keys)) + } + if keys[0].PrivateKey.PubKeyAlgo != packet.ExperimentalPubKeyAlgoHMAC { + t.Errorf("Expected HMAC primary key") + } + if len(keys[0].Subkeys) != 1 { + t.Errorf("Expected 1 symmetric subkey, got %d", len(keys[0].Subkeys)) + } + if keys[0].Subkeys[0].PrivateKey.PubKeyAlgo != packet.ExperimentalPubKeyAlgoAEAD { + t.Errorf("Expected AEAD subkey") + } +} diff --git a/openpgp/packet/encrypted_key.go b/openpgp/packet/encrypted_key.go index 58340945..a84cf6b4 100644 --- a/openpgp/packet/encrypted_key.go +++ b/openpgp/packet/encrypted_key.go @@ -17,7 +17,9 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/ecdh" "github.com/ProtonMail/go-crypto/openpgp/elgamal" "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding" + "github.com/ProtonMail/go-crypto/openpgp/symmetric" "github.com/ProtonMail/go-crypto/openpgp/x25519" "github.com/ProtonMail/go-crypto/openpgp/x448" ) @@ -37,6 +39,9 @@ type EncryptedKey struct { ephemeralPublicX25519 *x25519.PublicKey // used for x25519 ephemeralPublicX448 *x448.PublicKey // used for x448 encryptedSession []byte // used for x25519 and x448 + + nonce []byte + aeadMode algorithm.AEADMode } func (e *EncryptedKey) parse(r io.Reader) (err error) { @@ -133,6 +138,21 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) { if err != nil { return } + case ExperimentalPubKeyAlgoAEAD: + var aeadMode [1]byte + if _, err = readFull(r, aeadMode[:]); err != nil { + return + } + e.aeadMode = algorithm.AEADMode(aeadMode[0]) + nonceLength := e.aeadMode.NonceLength() + e.nonce = make([]byte, nonceLength) + if _, err = readFull(r, e.nonce); err != nil { + return + } + e.encryptedMPI1 = new(encoding.ShortByteString) + if _, err = e.encryptedMPI1.ReadFrom(r); err != nil { + return + } } if e.Version < 6 { switch e.Algo { @@ -191,6 +211,9 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error { b, err = x25519.Decrypt(priv.PrivateKey.(*x25519.PrivateKey), e.ephemeralPublicX25519, e.encryptedSession) case PubKeyAlgoX448: b, err = x448.Decrypt(priv.PrivateKey.(*x448.PrivateKey), e.ephemeralPublicX448, e.encryptedSession) + case ExperimentalPubKeyAlgoAEAD: + priv := priv.PrivateKey.(*symmetric.AEADPrivateKey) + b, err = priv.Decrypt(e.nonce, e.encryptedMPI1.Bytes(), e.aeadMode) default: err = errors.InvalidArgumentError("cannot decrypt encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo))) } @@ -200,7 +223,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error { var key []byte switch priv.PubKeyAlgo { - case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH: + case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH, ExperimentalPubKeyAlgoAEAD: keyOffset := 0 if e.Version < 6 { e.CipherFunc = CipherFunction(b[0]) @@ -387,7 +410,7 @@ func SerializeEncryptedKeyAEADwithHiddenOption(w io.Writer, pub *PublicKey, ciph var keyBlock []byte switch pub.PubKeyAlgo { - case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH: + case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH, ExperimentalPubKeyAlgoAEAD: lenKeyBlock := len(key) + 2 if version < 6 { lenKeyBlock += 1 // cipher type included @@ -415,7 +438,9 @@ func SerializeEncryptedKeyAEADwithHiddenOption(w io.Writer, pub *PublicKey, ciph return serializeEncryptedKeyX25519(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*x25519.PublicKey), keyBlock, byte(cipherFunc), version) case PubKeyAlgoX448: return serializeEncryptedKeyX448(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*x448.PublicKey), keyBlock, byte(cipherFunc), version) - case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly: + case ExperimentalPubKeyAlgoAEAD: + return serializeEncryptedKeyAEAD(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*symmetric.AEADPublicKey), keyBlock, config.AEAD()) + case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly, ExperimentalPubKeyAlgoHMAC: return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo))) } @@ -438,6 +463,36 @@ func SerializeEncryptedKeyWithHiddenOption(w io.Writer, pub *PublicKey, cipherFu return SerializeEncryptedKeyAEADwithHiddenOption(w, pub, cipherFunc, config.AEAD() != nil, key, hidden, config) } +func (e *EncryptedKey) ProxyTransform(instance ForwardingInstance) (transformed *EncryptedKey, err error) { + if e.Algo != PubKeyAlgoECDH { + return nil, errors.InvalidArgumentError("invalid PKESK") + } + + if e.KeyId != 0 && e.KeyId != instance.GetForwarderKeyId() { + return nil, errors.InvalidArgumentError("invalid key id in PKESK") + } + + ephemeral := e.encryptedMPI1.Bytes() + transformedEphemeral, err := ecdh.ProxyTransform(ephemeral, instance.ProxyParameter) + if err != nil { + return nil, err + } + + wrappedKey := e.encryptedMPI2.Bytes() + copiedWrappedKey := make([]byte, len(wrappedKey)) + copy(copiedWrappedKey, wrappedKey) + + transformed = &EncryptedKey{ + Version: e.Version, + KeyId: instance.getForwardeeKeyIdOrZero(e.KeyId), + Algo: e.Algo, + encryptedMPI1: encoding.NewMPI(transformedEphemeral), + encryptedMPI2: encoding.NewOID(copiedWrappedKey), + } + + return transformed, nil +} + func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header []byte, pub *rsa.PublicKey, keyBlock []byte) error { cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock) if err != nil { @@ -554,6 +609,35 @@ func serializeEncryptedKeyX448(w io.Writer, rand io.Reader, header []byte, pub * return x448.EncodeFields(w, ephemeralPublicX448, ciphertext, cipherFunc, version == 6) } +func serializeEncryptedKeyAEAD(w io.Writer, rand io.Reader, header []byte, pub *symmetric.AEADPublicKey, keyBlock []byte, config *AEADConfig) error { + mode := algorithm.AEADMode(config.Mode()) + iv, ciphertextRaw, err := pub.Encrypt(rand, keyBlock, mode) + if err != nil { + return errors.InvalidArgumentError("AEAD encryption failed: " + err.Error()) + } + + ciphertextShortByteString := encoding.NewShortByteString(ciphertextRaw) + + buffer := append([]byte{byte(mode)}, iv...) + buffer = append(buffer, ciphertextShortByteString.EncodedBytes()...) + + packetLen := len(header) /* header length */ + packetLen += int(len(buffer)) + + err = serializeHeader(w, packetTypeEncryptedKey, packetLen) + if err != nil { + return err + } + + _, err = w.Write(header[:]) + if err != nil { + return err + } + + _, err = w.Write(buffer) + return err +} + func checksumKeyMaterial(key []byte) uint16 { var checksum uint16 for _, v := range key { diff --git a/openpgp/packet/encrypted_key_test.go b/openpgp/packet/encrypted_key_test.go index 787c7fec..5ed0a8ed 100644 --- a/openpgp/packet/encrypted_key_test.go +++ b/openpgp/packet/encrypted_key_test.go @@ -16,6 +16,7 @@ import ( "crypto" "crypto/rsa" + "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/x25519" "github.com/ProtonMail/go-crypto/openpgp/x448" ) @@ -338,3 +339,33 @@ func TestSerializingEncryptedKey(t *testing.T) { t.Fatalf("serialization of encrypted key differed from original. Original was %s, but reserialized as %s", encryptedKeyHex, bufHex) } } + +func TestSymmetricallyEncryptedKey(t *testing.T) { + const encryptedKeyHex = "c14f03999bd17d726446da64018cb4d628ae753c646b81f87f21269cd733df9db940896a0b0e48f4d3b26e2dfbcf59ca7d30b65ea95ebb072e643407c732c479093b9d180c2eb51c98814e1bbbc6d0a17f" + + expectedNonce := []byte{0x8c, 0xb4, 0xd6, 0x28, 0xae, 0x75, 0x3c, 0x64, 0x6b, 0x81, 0xf8, 0x7f, 0x21, 0x26, 0x9c, 0xd7} + + expectedCiphertext := []byte{0xdf, 0x9d, 0xb9, 0x40, 0x89, 0x6a, 0x0b, 0x0e, 0x48, 0xf4, 0xd3, 0xb2, 0x6e, 0x2d, 0xfb, 0xcf, 0x59, 0xca, 0x7d, 0x30, 0xb6, 0x5e, 0xa9, 0x5e, 0xbb, 0x07, 0x2e, 0x64, 0x34, 0x07, 0xc7, 0x32, 0xc4, 0x79, 0x09, 0x3b, 0x9d, 0x18, 0x0c, 0x2e, 0xb5, 0x1c, 0x98, 0x81, 0x4e, 0x1b, 0xbb, 0xc6, 0xd0, 0xa1, 0x7f} + + p, err := Read(readerFromHex(encryptedKeyHex)) + if err != nil { + t.Fatal("error reading packet") + } + + ek, ok := p.(*EncryptedKey) + if !ok { + t.Fatalf("didn't parse and EncryptedKey, got %#v", p) + } + + if ek.aeadMode != algorithm.AEADModeEAX { + t.Errorf("Parsed wrong aead mode, got %d, expected: 1", ek.aeadMode) + } + + if !bytes.Equal(expectedNonce, ek.nonce) { + t.Errorf("Parsed wrong nonce, got %x, expected %x", ek.nonce, expectedNonce) + } + + if !bytes.Equal(expectedCiphertext, ek.encryptedMPI1.Bytes()) { + t.Errorf("Parsed wrong ciphertext, got %x, expected %x", ek.encryptedMPI1.Bytes(), expectedCiphertext) + } +} diff --git a/openpgp/packet/forwarding.go b/openpgp/packet/forwarding.go new file mode 100644 index 00000000..50b4de44 --- /dev/null +++ b/openpgp/packet/forwarding.go @@ -0,0 +1,36 @@ +package packet + +import "encoding/binary" + +// ForwardingInstance represents a single forwarding instance (mapping IDs to a Proxy Param) +type ForwardingInstance struct { + KeyVersion int + ForwarderFingerprint []byte + ForwardeeFingerprint []byte + ProxyParameter []byte +} + +func (f *ForwardingInstance) GetForwarderKeyId() uint64 { + return computeForwardingKeyId(f.ForwarderFingerprint, f.KeyVersion) +} + +func (f *ForwardingInstance) GetForwardeeKeyId() uint64 { + return computeForwardingKeyId(f.ForwardeeFingerprint, f.KeyVersion) +} + +func (f *ForwardingInstance) getForwardeeKeyIdOrZero(originalKeyId uint64) uint64 { + if originalKeyId == 0 { + return 0 + } + + return f.GetForwardeeKeyId() +} + +func computeForwardingKeyId(fingerprint []byte, version int) uint64 { + switch version { + case 4: + return binary.BigEndian.Uint64(fingerprint[12:20]) + default: + panic("invalid pgp key version") + } +} \ No newline at end of file diff --git a/openpgp/packet/packet.go b/openpgp/packet/packet.go index 1e92e22c..dd4ad34c 100644 --- a/openpgp/packet/packet.go +++ b/openpgp/packet/packet.go @@ -506,6 +506,9 @@ const ( PubKeyAlgoEd25519 PublicKeyAlgorithm = 27 PubKeyAlgoEd448 PublicKeyAlgorithm = 28 + ExperimentalPubKeyAlgoAEAD PublicKeyAlgorithm = 100 + ExperimentalPubKeyAlgoHMAC PublicKeyAlgorithm = 101 + // Deprecated in RFC 4880, Section 13.5. Use key flags instead. PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2 PubKeyAlgoRSASignOnly PublicKeyAlgorithm = 3 @@ -515,7 +518,7 @@ const ( // key of the given type. func (pka PublicKeyAlgorithm) CanEncrypt() bool { switch pka { - case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH, PubKeyAlgoX25519, PubKeyAlgoX448: + case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH, PubKeyAlgoX25519, PubKeyAlgoX448, ExperimentalPubKeyAlgoAEAD: return true } return false @@ -525,7 +528,7 @@ func (pka PublicKeyAlgorithm) CanEncrypt() bool { // sign a message. func (pka PublicKeyAlgorithm) CanSign() bool { switch pka { - case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA, PubKeyAlgoEd25519, PubKeyAlgoEd448: + case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA, PubKeyAlgoEd25519, PubKeyAlgoEd448, ExperimentalPubKeyAlgoHMAC: return true } return false diff --git a/openpgp/packet/private_key.go b/openpgp/packet/private_key.go index f04e6c6b..406c56e6 100644 --- a/openpgp/packet/private_key.go +++ b/openpgp/packet/private_key.go @@ -28,6 +28,7 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding" "github.com/ProtonMail/go-crypto/openpgp/s2k" + "github.com/ProtonMail/go-crypto/openpgp/symmetric" "github.com/ProtonMail/go-crypto/openpgp/x25519" "github.com/ProtonMail/go-crypto/openpgp/x448" "golang.org/x/crypto/hkdf" @@ -166,6 +167,8 @@ func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey pk.PublicKey = *NewEd448PublicKey(creationTime, &pubkey.PublicKey) case ed448.PrivateKey: pk.PublicKey = *NewEd448PublicKey(creationTime, &pubkey.PublicKey) + case *symmetric.HMACPrivateKey: + pk.PublicKey = *NewHMACPublicKey(creationTime, &pubkey.PublicKey) default: panic("openpgp: unknown signer type in NewSignerPrivateKey") } @@ -187,6 +190,8 @@ func NewDecrypterPrivateKey(creationTime time.Time, decrypter interface{}) *Priv pk.PublicKey = *NewX25519PublicKey(creationTime, &priv.PublicKey) case *x448.PrivateKey: pk.PublicKey = *NewX448PublicKey(creationTime, &priv.PublicKey) + case *symmetric.AEADPrivateKey: + pk.PublicKey = *NewAEADPublicKey(creationTime, &priv.PublicKey) default: panic("openpgp: unknown decrypter type in NewDecrypterPrivateKey") } @@ -530,6 +535,24 @@ func serializeEd448PrivateKey(w io.Writer, priv *ed448.PrivateKey) error { return err } +func serializeAEADPrivateKey(w io.Writer, priv *symmetric.AEADPrivateKey) (err error) { + _, err = w.Write(priv.HashSeed[:]) + if err != nil { + return + } + _, err = w.Write(priv.Key) + return +} + +func serializeHMACPrivateKey(w io.Writer, priv *symmetric.HMACPrivateKey) (err error) { + _, err = w.Write(priv.HashSeed[:]) + if err != nil { + return + } + _, err = w.Write(priv.Key) + return +} + // decrypt decrypts an encrypted private key using a decryption key. func (pk *PrivateKey) decrypt(decryptionKey []byte) error { if pk.Dummy() { @@ -830,6 +853,10 @@ func (pk *PrivateKey) serializePrivateKey(w io.Writer) (err error) { err = serializeEd25519PrivateKey(w, priv) case *ed448.PrivateKey: err = serializeEd448PrivateKey(w, priv) + case *symmetric.AEADPrivateKey: + err = serializeAEADPrivateKey(w, priv) + case *symmetric.HMACPrivateKey: + err = serializeHMACPrivateKey(w, priv) default: err = errors.InvalidArgumentError("unknown private key type") } @@ -861,6 +888,10 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err error) { default: err = errors.StructuralError("unknown private key type") return + case ExperimentalPubKeyAlgoAEAD: + return pk.parseAEADPrivateKey(data) + case ExperimentalPubKeyAlgoHMAC: + return pk.parseHMACPrivateKey(data) } } @@ -1121,6 +1152,66 @@ func (pk *PrivateKey) applyHKDF(inputKey []byte) []byte { return encryptionKey } +func (pk *PrivateKey) parseAEADPrivateKey(data []byte) (err error) { + pubKey := pk.PublicKey.PublicKey.(*symmetric.AEADPublicKey) + + aeadPriv := new(symmetric.AEADPrivateKey) + aeadPriv.PublicKey = *pubKey + + copy(aeadPriv.HashSeed[:], data[:32]) + + priv := make([]byte, pubKey.Cipher.KeySize()) + copy(priv, data[32:]) + aeadPriv.Key = priv + aeadPriv.PublicKey.Key = aeadPriv.Key + + if err = validateAEADParameters(aeadPriv); err != nil { + return + } + + pk.PrivateKey = aeadPriv + pk.PublicKey.PublicKey = &aeadPriv.PublicKey + return +} + +func (pk *PrivateKey) parseHMACPrivateKey(data []byte) (err error) { + pubKey := pk.PublicKey.PublicKey.(*symmetric.HMACPublicKey) + + hmacPriv := new(symmetric.HMACPrivateKey) + hmacPriv.PublicKey = *pubKey + + copy(hmacPriv.HashSeed[:], data[:32]) + + priv := make([]byte, pubKey.Hash.Size()) + copy(priv, data[32:]) + hmacPriv.Key = data[32:] + hmacPriv.PublicKey.Key = hmacPriv.Key + + if err = validateHMACParameters(hmacPriv); err != nil { + return + } + + pk.PrivateKey = hmacPriv + pk.PublicKey.PublicKey = &hmacPriv.PublicKey + return +} + +func validateAEADParameters(priv *symmetric.AEADPrivateKey) error { + return validateCommonSymmetric(priv.HashSeed, priv.PublicKey.BindingHash) +} + +func validateHMACParameters(priv *symmetric.HMACPrivateKey) error { + return validateCommonSymmetric(priv.HashSeed, priv.PublicKey.BindingHash) +} + +func validateCommonSymmetric(seed [32]byte, bindingHash [32]byte) error { + expectedBindingHash := symmetric.ComputeBindingHash(seed) + if !bytes.Equal(expectedBindingHash, bindingHash[:]) { + return errors.KeyInvalidError("symmetric: wrong binding hash") + } + return nil +} + func validateDSAParameters(priv *dsa.PrivateKey) error { p := priv.P // group prime q := priv.Q // subgroup order diff --git a/openpgp/packet/public_key.go b/openpgp/packet/public_key.go index f8da781b..d9e2c85e 100644 --- a/openpgp/packet/public_key.go +++ b/openpgp/packet/public_key.go @@ -5,12 +5,14 @@ package packet import ( + "bytes" "crypto/dsa" "crypto/rsa" "crypto/sha1" "crypto/sha256" _ "crypto/sha512" "encoding/binary" + goerrors "errors" "fmt" "hash" "io" @@ -28,6 +30,7 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding" + "github.com/ProtonMail/go-crypto/openpgp/symmetric" "github.com/ProtonMail/go-crypto/openpgp/x25519" "github.com/ProtonMail/go-crypto/openpgp/x448" ) @@ -69,6 +72,26 @@ func (pk *PublicKey) UpgradeToV6() error { return pk.checkV6Compatibility() } +// ReplaceKDF replaces the KDF instance, and updates all necessary fields. +func (pk *PublicKey) ReplaceKDF(kdf ecdh.KDF) error { + ecdhKey, ok := pk.PublicKey.(*ecdh.PublicKey) + if !ok { + return goerrors.New("wrong forwarding sub key generation") + } + + ecdhKey.KDF = kdf + byteBuffer := new(bytes.Buffer) + err := kdf.Serialize(byteBuffer) + if err != nil { + return err + } + + pk.kdf = encoding.NewOID(byteBuffer.Bytes()[1:]) + pk.setFingerprintAndKeyId() + + return nil +} + // signingKey provides a convenient abstraction over signature verification // for v3 and v4 public keys. type signingKey interface { @@ -230,6 +253,30 @@ func NewEd448PublicKey(creationTime time.Time, pub *ed448.PublicKey) *PublicKey return pk } +func NewAEADPublicKey(creationTime time.Time, pub *symmetric.AEADPublicKey) *PublicKey { + var pk *PublicKey + pk = &PublicKey{ + Version: 4, + CreationTime: creationTime, + PubKeyAlgo: ExperimentalPubKeyAlgoAEAD, + PublicKey: pub, + } + + return pk +} + +func NewHMACPublicKey(creationTime time.Time, pub *symmetric.HMACPublicKey) *PublicKey { + var pk *PublicKey + pk = &PublicKey{ + Version: 4, + CreationTime: creationTime, + PubKeyAlgo: ExperimentalPubKeyAlgoHMAC, + PublicKey: pub, + } + + return pk +} + func (pk *PublicKey) parse(r io.Reader) (err error) { // RFC 4880, section 5.5.2 var buf [6]byte @@ -280,6 +327,10 @@ func (pk *PublicKey) parse(r io.Reader) (err error) { err = pk.parseEd25519(r) case PubKeyAlgoEd448: err = pk.parseEd448(r) + case ExperimentalPubKeyAlgoAEAD: + err = pk.parseAEAD(r) + case ExperimentalPubKeyAlgoHMAC: + err = pk.parseHMAC(r) default: err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) } @@ -474,11 +525,13 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) { return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid)) } - if kdfLen := len(pk.kdf.Bytes()); kdfLen < 3 { + kdfLen := len(pk.kdf.Bytes()) + if kdfLen < 3 { return errors.UnsupportedError("unsupported ECDH KDF length: " + strconv.Itoa(kdfLen)) } - if reserved := pk.kdf.Bytes()[0]; reserved != 0x01 { - return errors.UnsupportedError("unsupported KDF reserved field: " + strconv.Itoa(int(reserved))) + kdfVersion := int(pk.kdf.Bytes()[0]) + if kdfVersion != ecdh.KDFVersion1 && kdfVersion != ecdh.KDFVersionForwarding { + return errors.UnsupportedError("unsupported ECDH KDF version: " + strconv.Itoa(kdfVersion)) } kdfHash, ok := algorithm.HashById[pk.kdf.Bytes()[1]] if !ok { @@ -489,10 +542,23 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) { return errors.UnsupportedError("unsupported ECDH KDF cipher: " + strconv.Itoa(int(pk.kdf.Bytes()[2]))) } - ecdhKey := ecdh.NewPublicKey(c, kdfHash, kdfCipher) + kdf := ecdh.KDF{ + Version: kdfVersion, + Hash: kdfHash, + Cipher: kdfCipher, + } + + if kdfVersion == ecdh.KDFVersionForwarding { + if pk.Version != 4 || kdfLen != 23 { + return errors.UnsupportedError("unsupported ECDH KDF v2 length: " + strconv.Itoa(kdfLen)) + } + + kdf.ReplacementFingerprint = pk.kdf.Bytes()[3:23] + } + + ecdhKey := ecdh.NewPublicKey(c, kdf) err = ecdhKey.UnmarshalPoint(pk.p.Bytes()) pk.PublicKey = ecdhKey - return } @@ -594,6 +660,58 @@ func (pk *PublicKey) parseEd448(r io.Reader) (err error) { return } +func (pk *PublicKey) parseAEAD(r io.Reader) (err error) { + var cipher [1]byte + _, err = readFull(r, cipher[:]) + if err != nil { + return + } + + var bindingHash [32]byte + _, err = readFull(r, bindingHash[:]) + if err != nil { + return + } + + symmetric := &symmetric.AEADPublicKey{ + Cipher: algorithm.CipherFunction(cipher[0]), + BindingHash: bindingHash, + } + + pk.PublicKey = symmetric + return +} + +func (pk *PublicKey) parseHMAC(r io.Reader) (err error) { + var hash [1]byte + _, err = readFull(r, hash[:]) + if err != nil { + return + } + bindingHash, err := readBindingHash(r) + if err != nil { + return + } + + hmacHash, ok := algorithm.HashById[hash[0]] + if !ok { + return errors.UnsupportedError("unsupported HMAC hash: " + strconv.Itoa(int(hash[0]))) + } + + symmetric := &symmetric.HMACPublicKey{ + Hash: hmacHash, + BindingHash: bindingHash, + } + + pk.PublicKey = symmetric + return +} + +func readBindingHash(r io.Reader) (bindingHash [32]byte, err error) { + _, err = readFull(r, bindingHash[:]) + return +} + // SerializeForHash serializes the PublicKey to w with the special packet // header format needed for hashing. func (pk *PublicKey) SerializeForHash(w io.Writer) error { @@ -681,6 +799,9 @@ func (pk *PublicKey) algorithmSpecificByteCount() uint32 { length += ed25519.PublicKeySize case PubKeyAlgoEd448: length += ed448.PublicKeySize + case ExperimentalPubKeyAlgoAEAD, ExperimentalPubKeyAlgoHMAC: + length += 1 // Hash octet + length += 32 // Binding hash default: panic("unknown public key algorithm") } @@ -773,6 +894,22 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) { publicKey := pk.PublicKey.(*ed448.PublicKey) _, err = w.Write(publicKey.Point) return + case ExperimentalPubKeyAlgoAEAD: + symmKey := pk.PublicKey.(*symmetric.AEADPublicKey) + cipherOctet := [1]byte{symmKey.Cipher.Id()} + if _, err = w.Write(cipherOctet[:]); err != nil { + return + } + _, err = w.Write(symmKey.BindingHash[:]) + return + case ExperimentalPubKeyAlgoHMAC: + symmKey := pk.PublicKey.(*symmetric.HMACPublicKey) + hashOctet := [1]byte{symmKey.Hash.Id()} + if _, err = w.Write(hashOctet[:]); err != nil { + return + } + _, err = w.Write(symmKey.BindingHash[:]) + return } return errors.InvalidArgumentError("bad public-key algorithm") } @@ -859,6 +996,17 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro return errors.SignatureError("ed448 verification failure") } return nil + case ExperimentalPubKeyAlgoHMAC: + HMACKey := pk.PublicKey.(*symmetric.HMACPublicKey) + + result, err := HMACKey.Verify(hashBytes, sig.HMAC.Bytes()) + if err != nil { + return err + } + if !result { + return errors.SignatureError("HMAC verification failure") + } + return nil default: return errors.SignatureError("Unsupported public key algorithm used in signature") } @@ -929,6 +1077,13 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error } } + // Keys having this flag MUST have the forwarding KDF parameters version 2 defined in Section 5.1. + if sig.FlagForward && (signed.PubKeyAlgo != PubKeyAlgoECDH || + signed.kdf == nil || + signed.kdf.Bytes()[0] != ecdh.KDFVersionForwarding) { + return errors.StructuralError("forwarding key with wrong ecdh kdf version") + } + return nil } @@ -1080,6 +1235,8 @@ func (pk *PublicKey) BitLength() (bitLength uint16, err error) { bitLength = ed25519.PublicKeySize * 8 case PubKeyAlgoEd448: bitLength = ed448.PublicKeySize * 8 + case ExperimentalPubKeyAlgoAEAD: + bitLength = 32 default: err = errors.InvalidArgumentError("bad public-key algorithm") } diff --git a/openpgp/packet/signature.go b/openpgp/packet/signature.go index 3a4b366d..79161a38 100644 --- a/openpgp/packet/signature.go +++ b/openpgp/packet/signature.go @@ -34,7 +34,7 @@ const ( KeyFlagEncryptStorage KeyFlagSplitKey KeyFlagAuthenticate - _ + KeyFlagForward KeyFlagGroupKey ) @@ -81,6 +81,7 @@ type Signature struct { ECDSASigR, ECDSASigS encoding.Field EdDSASigR, EdDSASigS encoding.Field EdSig []byte + HMAC encoding.Field // rawSubpackets contains the unparsed subpackets, in order. rawSubpackets []outputSubpacket @@ -127,8 +128,8 @@ type Signature struct { // FlagsValid is set if any flags were given. See RFC 9580, section // 5.2.3.29 for details. - FlagsValid bool - FlagCertify, FlagSign, FlagEncryptCommunications, FlagEncryptStorage, FlagSplitKey, FlagAuthenticate, FlagGroupKey bool + FlagsValid bool + FlagCertify, FlagSign, FlagEncryptCommunications, FlagEncryptStorage, FlagSplitKey, FlagAuthenticate, FlagGroupKey, FlagForward bool // RevocationReason is set if this signature has been revoked. // See RFC 9580, section 5.2.3.31 for details. @@ -198,7 +199,7 @@ func (sig *Signature) parse(r io.Reader) (err error) { sig.SigType = SignatureType(buf[0]) sig.PubKeyAlgo = PublicKeyAlgorithm(buf[1]) switch sig.PubKeyAlgo { - case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA, PubKeyAlgoEd25519, PubKeyAlgoEd448: + case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA, PubKeyAlgoEd25519, PubKeyAlgoEd448, ExperimentalPubKeyAlgoHMAC: default: err = errors.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo))) return @@ -336,6 +337,11 @@ func (sig *Signature) parse(r io.Reader) (err error) { if err != nil { return } + case ExperimentalPubKeyAlgoHMAC: + sig.HMAC = new(encoding.ShortByteString) + if _, err = sig.HMAC.ReadFrom(r); err != nil { + return + } default: panic("unreachable") } @@ -582,6 +588,9 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r if subpacket[0]&KeyFlagAuthenticate != 0 { sig.FlagAuthenticate = true } + if subpacket[0]&KeyFlagForward != 0 { + sig.FlagForward = true + } if subpacket[0]&KeyFlagGroupKey != 0 { sig.FlagGroupKey = true } @@ -996,6 +1005,11 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey, config *Config) (err e if err == nil { sig.EdSig = signature } + case ExperimentalPubKeyAlgoHMAC: + sigdata, err := priv.PrivateKey.(crypto.Signer).Sign(config.Random(), digest, nil) + if err == nil { + sig.HMAC = encoding.NewShortByteString(sigdata) + } default: err = errors.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo))) } @@ -1113,7 +1127,7 @@ func (sig *Signature) Serialize(w io.Writer) (err error) { if len(sig.outSubpackets) == 0 { sig.outSubpackets = sig.rawSubpackets } - if sig.RSASignature == nil && sig.DSASigR == nil && sig.ECDSASigR == nil && sig.EdDSASigR == nil && sig.EdSig == nil { + if sig.RSASignature == nil && sig.DSASigR == nil && sig.ECDSASigR == nil && sig.EdDSASigR == nil && sig.EdSig == nil && sig.HMAC == nil { return errors.InvalidArgumentError("Signature: need to call Sign, SignUserId or SignKey before Serialize") } @@ -1134,6 +1148,8 @@ func (sig *Signature) Serialize(w io.Writer) (err error) { sigLength = ed25519.SignatureSize case PubKeyAlgoEd448: sigLength = ed448.SignatureSize + case ExperimentalPubKeyAlgoHMAC: + sigLength = int(sig.HMAC.EncodedLength()) default: panic("impossible") } @@ -1240,6 +1256,8 @@ func (sig *Signature) serializeBody(w io.Writer) (err error) { err = ed25519.WriteSignature(w, sig.EdSig) case PubKeyAlgoEd448: err = ed448.WriteSignature(w, sig.EdSig) + case ExperimentalPubKeyAlgoHMAC: + _, err = w.Write(sig.HMAC.EncodedBytes()) default: panic("impossible") } @@ -1352,6 +1370,9 @@ func (sig *Signature) buildSubpackets(issuer PublicKey) (subpackets []outputSubp if sig.FlagAuthenticate { flags |= KeyFlagAuthenticate } + if sig.FlagForward { + flags |= KeyFlagForward + } if sig.FlagGroupKey { flags |= KeyFlagGroupKey } diff --git a/openpgp/packet/signature_test.go b/openpgp/packet/signature_test.go index edc43114..19940387 100644 --- a/openpgp/packet/signature_test.go +++ b/openpgp/packet/signature_test.go @@ -82,6 +82,33 @@ ltm2aQaG } } +func TestSymmetricSignatureRead(t *testing.T) { + const serializedPacket = "c272040165080006050260639e4e002109107fc6eeae2d3315b1162104e29ad49f0b7d0b12bb0401407fc6eeae2d3315b13adc400ecca603da8e6f3c82727ffc3e9416bc0236c9665498dda14f1c1dd4e4acacc7725d6dac7598e0951b5f1f8789714fb7fcdda4a9f10056134a7edf9d9a4fc45d" + expectedHMAC := []byte{0x0e, 0xcc, 0xa6, 0x03, 0xda, 0x8e, 0x6f, 0x3c, 0x82, 0x72, 0x7f, 0xfc, 0x3e, 0x94, 0x16, 0xbc, 0x02, 0x36, 0xc9, 0x66, 0x54, 0x98, 0xdd, 0xa1, 0x4f, 0x1c, 0x1d, 0xd4, 0xe4, 0xac, 0xac, 0xc7, 0x72, 0x5d, 0x6d, 0xac, 0x75, 0x98, 0xe0, 0x95, 0x1b, 0x5f, 0x1f, 0x87, 0x89, 0x71, 0x4f, 0xb7, 0xfc, 0xdd, 0xa4, 0xa9, 0xf1, 0x00, 0x56, 0x13, 0x4a, 0x7e, 0xdf, 0x9d, 0x9a, 0x4f, 0xc4, 0x5d} + + packet, err := Read(readerFromHex(serializedPacket)) + if err != nil { + t.Error(err) + } + + sig, ok := packet.(*Signature) + if !ok { + t.Errorf("Did not parse a signature packet") + } + + if sig.PubKeyAlgo != ExperimentalPubKeyAlgoHMAC { + t.Error("Wrong public key algorithm") + } + + if sig.Hash != crypto.SHA256 { + t.Error("Wrong public key algorithm") + } + + if !bytes.Equal(sig.HMAC.Bytes(), expectedHMAC) { + t.Errorf("Wrong HMAC value, got: %x, expected: %x\n", sig.HMAC.Bytes(), expectedHMAC) + } +} + func TestSignatureReserialize(t *testing.T) { packet, _ := Read(readerFromHex(signatureDataHex)) sig := packet.(*Signature) diff --git a/openpgp/read.go b/openpgp/read.go index 8a69b44a..43def2c4 100644 --- a/openpgp/read.go +++ b/openpgp/read.go @@ -118,7 +118,7 @@ ParsePackets: // This packet contains the decryption key encrypted to a public key. md.EncryptedToKeyIds = append(md.EncryptedToKeyIds, p.KeyId) switch p.Algo { - case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal, packet.PubKeyAlgoECDH, packet.PubKeyAlgoX25519, packet.PubKeyAlgoX448: + case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal, packet.PubKeyAlgoECDH, packet.PubKeyAlgoX25519, packet.PubKeyAlgoX448, packet.ExperimentalPubKeyAlgoAEAD: break default: continue diff --git a/openpgp/read_test.go b/openpgp/read_test.go index 318d927e..99c390bd 100644 --- a/openpgp/read_test.go +++ b/openpgp/read_test.go @@ -28,6 +28,13 @@ func readerFromHex(s string) io.Reader { return bytes.NewBuffer(data) } +func TestReadKeyRingWithSymmetricSubkey(t *testing.T) { + _, err := ReadArmoredKeyRing(strings.NewReader(keyWithAEADSubkey)) + if err != nil { + t.Error("could not read keyring", err) + } +} + func TestReadKeyRing(t *testing.T) { kring, err := ReadKeyRing(readerFromHex(testKeys1And2Hex)) if err != nil { @@ -750,7 +757,7 @@ func TestSymmetricAeadEaxOpenPGPJsMessage(t *testing.T) { } // Decrypt with key - var edp = p.(*packet.AEADEncrypted) + edp := p.(*packet.AEADEncrypted) rc, err := edp.Decrypt(packet.CipherFunction(0), key) if err != nil { panic(err) diff --git a/openpgp/read_write_test_data.go b/openpgp/read_write_test_data.go index 670d6022..77282c0e 100644 --- a/openpgp/read_write_test_data.go +++ b/openpgp/read_write_test_data.go @@ -455,3 +455,21 @@ byVJHvLO/XErtC+GNIJeMg== =liRq -----END PGP MESSAGE----- ` + +// A key that contains a persistent AEAD subkey +const keyWithAEADSubkey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xVgEYs/4KxYJKwYBBAHaRw8BAQdA7tIsntXluwloh/H62PJMqasjP00M86fv +/Pof9A968q8AAQDYcgkPKUdWAxsDjDHJfouPS4q5Me3ks+umlo5RJdwLZw4k +zQ1TeW1tZXRyaWMgS2V5wowEEBYKAB0FAmLP+CsECwkHCAMVCAoEFgACAQIZ +AQIbAwIeAQAhCRDkNhFDvaU8vxYhBDJNoyEFquVOCf99d+Q2EUO9pTy/5XQA +/1F2YPouv0ydBDJU3EOS/4bmPt7yqvzciWzeKVEOkzYuAP9OsP7q/5ccqOPX +mmRUKwd82/cNjdzdnWZ8Tq89XMwMAMdqBGLP+CtkCfFyZxOMF0BWLwAE8pLy +RVj2n2K7k6VvrhyuTqDkFDUFALiSLrEfnmTKlsPYS3/YzsODF354ccR63q73 +3lmCrvFRyaf6AHvVrBYPbJR+VhuTjZTwZKvPPKv0zVdSqi5JDEQiocJ4BBgW +CAAJBQJiz/grAhsMACEJEOQ2EUO9pTy/FiEEMk2jIQWq5U4J/3135DYRQ72l +PL+fEQEA7RaRbfa+AtiRN7a4GuqVEDZi3qtQZ2/Qcb27/LkAD0sA/3r9drYv +jyu46h1fdHHyo0HS2MiShZDZ8u60JnDltloD +=8TxH +-----END PGP PRIVATE KEY BLOCK----- +` diff --git a/openpgp/symmetric/aead.go b/openpgp/symmetric/aead.go new file mode 100644 index 00000000..044b1394 --- /dev/null +++ b/openpgp/symmetric/aead.go @@ -0,0 +1,76 @@ +package symmetric + +import ( + "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" + "io" +) + +type AEADPublicKey struct { + Cipher algorithm.CipherFunction + BindingHash [32]byte + Key []byte +} + +type AEADPrivateKey struct { + PublicKey AEADPublicKey + HashSeed [32]byte + Key []byte +} + +func AEADGenerateKey(rand io.Reader, cipher algorithm.CipherFunction) (priv *AEADPrivateKey, err error) { + priv, err = generatePrivatePartAEAD(rand, cipher) + if err != nil { + return + } + + priv.generatePublicPartAEAD(cipher) + return +} + +func generatePrivatePartAEAD(rand io.Reader, cipher algorithm.CipherFunction) (priv *AEADPrivateKey, err error) { + priv = new(AEADPrivateKey) + var seed [32] byte + _, err = rand.Read(seed[:]) + if err != nil { + return + } + + key := make([]byte, cipher.KeySize()) + _, err = rand.Read(key) + if err != nil { + return + } + + priv.HashSeed = seed + priv.Key = key + return +} + +func (priv *AEADPrivateKey) generatePublicPartAEAD(cipher algorithm.CipherFunction) (err error) { + priv.PublicKey.Cipher = cipher + + bindingHash := ComputeBindingHash(priv.HashSeed) + + priv.PublicKey.Key = make([]byte, len(priv.Key)) + copy(priv.PublicKey.Key, priv.Key) + copy(priv.PublicKey.BindingHash[:], bindingHash) + return +} + +func (pub *AEADPublicKey) Encrypt(rand io.Reader, data []byte, mode algorithm.AEADMode) (nonce []byte, ciphertext []byte, err error) { + block := pub.Cipher.New(pub.Key) + aead := mode.New(block) + nonce = make([]byte, aead.NonceSize()) + rand.Read(nonce) + ciphertext = aead.Seal(nil, nonce, data, nil) + return +} + +func (priv *AEADPrivateKey) Decrypt(nonce []byte, ciphertext []byte, mode algorithm.AEADMode) (message []byte, err error) { + + block := priv.PublicKey.Cipher.New(priv.Key) + aead := mode.New(block) + message, err = aead.Open(nil, nonce, ciphertext, nil) + return +} + diff --git a/openpgp/symmetric/hmac.go b/openpgp/symmetric/hmac.go new file mode 100644 index 00000000..fd4a7cbb --- /dev/null +++ b/openpgp/symmetric/hmac.go @@ -0,0 +1,109 @@ +package symmetric + +import ( + "crypto" + "crypto/hmac" + "crypto/sha256" + "io" + + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" +) + +type HMACPublicKey struct { + Hash algorithm.Hash + BindingHash [32]byte + // While this is a "public" key, the symmetric key needs to be present here. + // Symmetric cryptographic operations use the same key material for + // signing and verifying, and go-crypto assumes that a public key type will + // be used for verification. Thus, this `Key` field must never be exported + // publicly. + Key []byte +} + +type HMACPrivateKey struct { + PublicKey HMACPublicKey + HashSeed [32]byte + Key []byte +} + +func HMACGenerateKey(rand io.Reader, hash algorithm.Hash) (priv *HMACPrivateKey, err error) { + priv, err = generatePrivatePartHMAC(rand, hash) + if err != nil { + return + } + + priv.generatePublicPartHMAC(hash) + return +} + +func generatePrivatePartHMAC(rand io.Reader, hash algorithm.Hash) (priv *HMACPrivateKey, err error) { + priv = new(HMACPrivateKey) + var seed [32] byte + _, err = rand.Read(seed[:]) + if err != nil { + return + } + + key := make([]byte, hash.Size()) + _, err = rand.Read(key) + if err != nil { + return + } + + priv.HashSeed = seed + priv.Key = key + return +} + +func (priv *HMACPrivateKey) generatePublicPartHMAC(hash algorithm.Hash) (err error) { + priv.PublicKey.Hash = hash + + bindingHash := ComputeBindingHash(priv.HashSeed) + copy(priv.PublicKey.BindingHash[:], bindingHash) + + priv.PublicKey.Key = make([]byte, len(priv.Key)) + copy(priv.PublicKey.Key, priv.Key) + return +} + +func ComputeBindingHash(seed [32]byte) []byte { + bindingHash := sha256.New() + bindingHash.Write(seed[:]) + + return bindingHash.Sum(nil) +} + +func (priv *HMACPrivateKey) Public() crypto.PublicKey { + return &priv.PublicKey +} + +func (priv *HMACPrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + expectedMAC, err := calculateMAC(priv.PublicKey.Hash, priv.Key, digest) + if err != nil { + return + } + signature = make([]byte, len(expectedMAC)) + copy(signature, expectedMAC) + return +} + +func (pub *HMACPublicKey) Verify(digest []byte, signature []byte) (bool, error) { + expectedMAC, err := calculateMAC(pub.Hash, pub.Key, digest) + if err != nil { + return false, err + } + return hmac.Equal(expectedMAC, signature), nil +} + +func calculateMAC(hash algorithm.Hash, key []byte, data []byte) ([]byte, error) { + hashFunc := hash.HashFunc() + if !hashFunc.Available() { + return nil, errors.UnsupportedError("hash function") + } + + mac := hmac.New(hashFunc.New, key) + mac.Write(data) + + return mac.Sum(nil), nil +} diff --git a/openpgp/v2/forwarding.go b/openpgp/v2/forwarding.go new file mode 100644 index 00000000..1306c510 --- /dev/null +++ b/openpgp/v2/forwarding.go @@ -0,0 +1,159 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package v2 + +import ( + goerrors "errors" + + "github.com/ProtonMail/go-crypto/openpgp/ecdh" + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/packet" +) + +// NewForwardingEntity generates a new forwardee key and derives the proxy parameters from the entity e. +// If strict, it will return an error if encryption-capable non-revoked subkeys with a wrong algorithm are found, +// instead of ignoring them +func (e *Entity) NewForwardingEntity( + name, comment, email string, config *packet.Config, strict bool, +) ( + forwardeeKey *Entity, instances []packet.ForwardingInstance, err error, +) { + if e.PrimaryKey.Version != 4 { + return nil, nil, errors.InvalidArgumentError("unsupported key version") + } + + now := config.Now() + + if _, err = e.VerifyPrimaryKey(now, config); err != nil { + return nil, nil, err + } + + // Generate a new Primary key for the forwardee + config.Algorithm = packet.PubKeyAlgoEdDSA + config.Curve = packet.Curve25519 + + forwardeePrimaryPrivRaw, err := newSigner(config) + if err != nil { + return nil, nil, err + } + + primary := packet.NewSignerPrivateKey(now, forwardeePrimaryPrivRaw) + + forwardeeKey = &Entity{ + PrimaryKey: &primary.PublicKey, + PrivateKey: primary, + Identities: make(map[string]*Identity), + Subkeys: []Subkey{}, + } + + keyProperties := selectKeyProperties(now, config, primary) + err = forwardeeKey.addUserId(userIdData{name, comment, email}, config, keyProperties) + if err != nil { + return nil, nil, err + } + + // Init empty instances + instances = []packet.ForwardingInstance{} + + // Handle all forwarder subkeys + for _, forwarderSubKey := range e.Subkeys { + // Filter flags + if !forwarderSubKey.PublicKey.PubKeyAlgo.CanEncrypt() { + continue + } + + forwarderSubKeySelfSig, err := forwarderSubKey.Verify(now, config) + // Filter expiration & revokal + if err != nil { + continue + } + + if forwarderSubKey.PublicKey.PubKeyAlgo != packet.PubKeyAlgoECDH { + if strict { + return nil, nil, errors.InvalidArgumentError("encryption subkey is not algorithm 18 (ECDH)") + } else { + continue + } + } + + forwarderEcdhKey, ok := forwarderSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, errors.InvalidArgumentError("malformed key") + } + + err = forwardeeKey.addEncryptionSubkey(config, now, 0) + if err != nil { + return nil, nil, err + } + + forwardeeSubKey := forwardeeKey.Subkeys[len(forwardeeKey.Subkeys)-1] + forwardeeSubKeySelfSig := forwardeeSubKey.Bindings[0].Packet + + forwardeeEcdhKey, ok := forwardeeSubKey.PrivateKey.PrivateKey.(*ecdh.PrivateKey) + if !ok { + return nil, nil, goerrors.New("wrong forwarding sub key generation") + } + + instance := packet.ForwardingInstance{ + KeyVersion: 4, + ForwarderFingerprint: forwarderSubKey.PublicKey.Fingerprint, + } + + instance.ProxyParameter, err = ecdh.DeriveProxyParam(forwarderEcdhKey, forwardeeEcdhKey) + if err != nil { + return nil, nil, err + } + + kdf := ecdh.KDF{ + Version: ecdh.KDFVersionForwarding, + Hash: forwarderEcdhKey.KDF.Hash, + Cipher: forwarderEcdhKey.KDF.Cipher, + } + + // If deriving a forwarding key from a forwarding key + if forwarderSubKeySelfSig.FlagForward { + if forwarderEcdhKey.KDF.Version != ecdh.KDFVersionForwarding { + return nil, nil, goerrors.New("malformed forwarder key") + } + kdf.ReplacementFingerprint = forwarderEcdhKey.KDF.ReplacementFingerprint + } else { + kdf.ReplacementFingerprint = forwarderSubKey.PublicKey.Fingerprint + } + + err = forwardeeSubKey.PublicKey.ReplaceKDF(kdf) + if err != nil { + return nil, nil, err + } + + // Extract fingerprint after changing the KDF + instance.ForwardeeFingerprint = forwardeeSubKey.PublicKey.Fingerprint + + // 0x04 - This key may be used to encrypt communications. + forwardeeSubKeySelfSig.FlagEncryptCommunications = false + + // 0x08 - This key may be used to encrypt storage. + forwardeeSubKeySelfSig.FlagEncryptStorage = false + + // 0x10 - The private component of this key may have been split by a secret-sharing mechanism. + forwardeeSubKeySelfSig.FlagSplitKey = true + + // 0x40 - This key may be used for forwarded communications. + forwardeeSubKeySelfSig.FlagForward = true + + err = forwardeeSubKeySelfSig.SignKey(forwardeeSubKey.PublicKey, forwardeeKey.PrivateKey, config) + if err != nil { + return nil, nil, err + } + + // Append each valid instance to the list + instances = append(instances, instance) + } + + if len(instances) == 0 { + return nil, nil, errors.InvalidArgumentError("no valid subkey found") + } + + return forwardeeKey, instances, nil +} diff --git a/openpgp/v2/forwarding_test.go b/openpgp/v2/forwarding_test.go new file mode 100644 index 00000000..21c8be0e --- /dev/null +++ b/openpgp/v2/forwarding_test.go @@ -0,0 +1,223 @@ +package v2 + +import ( + "bytes" + "crypto/rand" + goerrors "errors" + "io" + "io/ioutil" + "strings" + "testing" + + "github.com/ProtonMail/go-crypto/openpgp/armor" + "github.com/ProtonMail/go-crypto/openpgp/packet" +) + +const forwardeeKey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xVgEZQRXoxYJKwYBBAHaRw8BAQdAhxdzZ8ZP1M4UcauXSGbts38KhhAZxHNRcChs +9H7danMAAQC4tHykQmFpnlvhLYJDDc4MJm68mUB9qUls34GgKkqKNw6FzRtjaGFy +bGVzIDxjaGFybGVzQHByb3Rvbi5tZT7CiwQTFggAPQUCZQRXowkQizX+kwlYIwMW +IQTYm4qmQoyzTnG0eZKLNf6TCVgjAwIbAwIeAQIZAQILBwIVCAIWAAMnBwIAAMsQ +AQD9UHMIU418Z10UQrymhbjkGq/PHCytaaneaq5oycpN/QD/UiK3aA4+HxWhX/F2 +VrvEKL5a2xyd1AKKQ2DInF3xUg3HcQRlBFejEgorBgEEAZdVAQUBAQdAep7x8ncL +ShzEgKL6h9MAJbgX2z3BBgSLeAdg/rczKngX/woJjSg9O4DzqQOtAvdhYkDoOCNf +QgUAAP9OMqK0IwNmshCtktDy1/RTeyPKT8ItHDFAZ1ReKMA5CA63wngEGBYIACoF +AmUEV6MJEIs1/pMJWCMDFiEE2JuKpkKMs05xtHmSizX+kwlYIwMCG1wAAC5EAP9s +AbYBf9NGv1NxJvU0n0K++k3UIGkw9xgGJa3VFHFKvwEAx0DZpTVpCkJmiOFAOcfu +cSvjlMyQwsC/hAAzQpcqvwE= +=8LJg +-----END PGP PRIVATE KEY BLOCK-----` + +const forwardedMessage = `-----BEGIN PGP MESSAGE----- + +wV4DKsXbtIU9/JMSAQdA/6+foCjeUhS7Xto3fimUi6pfMQ/Ft3caHkK/1i767isw +NvG8xRbjQ0sAE1IZVGE1MBcVhCIbHhqp0h2J479Zmfn/iP7hfomYxrkJ/6UMnlEo +0kABKyyfO3QVAzBBNeq6hH27uqzwLgjWVrpgY7dmWPv0goSSaqHUda0lm+8JNUuF +wssOJTwrSwQrX3ezy5D/h/E6 +=okS+ +-----END PGP MESSAGE-----` + +const forwardedPlaintext = "Message for Bob" + +func TestForwardingStatic(t *testing.T) { + charlesKey, err := ReadArmoredKeyRing(bytes.NewBufferString(forwardeeKey)) + if err != nil { + t.Error(err) + return + } + + ciphertext, err := armor.Decode(strings.NewReader(forwardedMessage)) + if err != nil { + t.Error(err) + return + } + + m, err := ReadMessage(ciphertext.Body, charlesKey, nil, nil) + if err != nil { + t.Fatal(err) + } + + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, []byte(forwardedPlaintext)) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func TestForwardingFull(t *testing.T) { + keyConfig := &packet.Config{ + Algorithm: packet.PubKeyAlgoEdDSA, + Curve: packet.Curve25519, + } + + plaintext := make([]byte, 1024) + rand.Read(plaintext) + + bobEntity, err := NewEntity("bob", "", "bob@proton.me", keyConfig) + if err != nil { + t.Fatal(err) + } + + charlesEntity, instances, err := bobEntity.NewForwardingEntity("charles", "", "charles@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + charlesEntity = serializeAndParseForwardeeKey(t, charlesEntity) + + if len(instances) != 1 { + t.Fatalf("invalid number of instances, expected 1 got %d", len(instances)) + } + + if !bytes.Equal(instances[0].ForwarderFingerprint, bobEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwarder key ID, expected: %x, got: %x", bobEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwarderFingerprint) + } + + if !bytes.Equal(instances[0].ForwardeeFingerprint, charlesEntity.Subkeys[0].PublicKey.Fingerprint) { + t.Fatalf("invalid forwardee key ID, expected: %x, got: %x", charlesEntity.Subkeys[0].PublicKey.Fingerprint, instances[0].ForwardeeFingerprint) + } + + // Encrypt message + buf := bytes.NewBuffer(nil) + w, err := Encrypt(buf, []*Entity{bobEntity}, nil, nil, nil, nil) + if err != nil { + t.Fatal(err) + } + + _, err = w.Write(plaintext) + if err != nil { + t.Fatal(err) + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + encrypted := buf.Bytes() + + // Decrypt message for Bob + m, err := ReadMessage(bytes.NewBuffer(encrypted), EntityList([]*Entity{bobEntity}), nil, nil) + if err != nil { + t.Fatal(err) + } + dec, err := ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("decrypted does not match original") + } + + // Forward message + transformed := transformTestMessage(t, encrypted, instances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(transformed), EntityList([]*Entity{charlesEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } + + // Setup further forwarding + danielEntity, secondForwardInstances, err := charlesEntity.NewForwardingEntity("Daniel", "", "daniel@proton.me", keyConfig, true) + if err != nil { + t.Fatal(err) + } + + danielEntity = serializeAndParseForwardeeKey(t, danielEntity) + + secondTransformed := transformTestMessage(t, transformed, secondForwardInstances[0]) + + // Decrypt forwarded message for Charles + m, err = ReadMessage(bytes.NewBuffer(secondTransformed), EntityList([]*Entity{danielEntity}), nil /* no prompt */, nil) + if err != nil { + t.Fatal(err) + } + + dec, err = ioutil.ReadAll(m.decrypted) + + if !bytes.Equal(dec, plaintext) { + t.Fatal("forwarded decrypted does not match original") + } +} + +func transformTestMessage(t *testing.T, encrypted []byte, instance packet.ForwardingInstance) []byte { + bytesReader := bytes.NewReader(encrypted) + packets := packet.NewReader(bytesReader) + splitPoint := int64(0) + transformedEncryptedKey := bytes.NewBuffer(nil) + +Loop: + for { + p, err := packets.Next() + if goerrors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatalf("error in parsing message: %s", err) + } + switch p := p.(type) { + case *packet.EncryptedKey: + tp, err := p.ProxyTransform(instance) + if err != nil { + t.Fatalf("error transforming PKESK: %s", err) + } + + splitPoint = bytesReader.Size() - int64(bytesReader.Len()) + + err = tp.Serialize(transformedEncryptedKey) + if err != nil { + t.Fatalf("error serializing transformed PKESK: %s", err) + } + break Loop + } + } + + transformed := transformedEncryptedKey.Bytes() + transformed = append(transformed, encrypted[splitPoint:]...) + + return transformed +} + +func serializeAndParseForwardeeKey(t *testing.T, key *Entity) *Entity { + serializedEntity := bytes.NewBuffer(nil) + err := key.SerializePrivateWithoutSigning(serializedEntity, nil) + if err != nil { + t.Fatalf("Error in serializing forwardee key: %s", err) + } + el, err := ReadKeyRing(serializedEntity) + if err != nil { + t.Fatalf("Error in reading forwardee key: %s", err) + } + + if len(el) != 1 { + t.Fatalf("Wrong number of entities in parsing, expected 1, got %d", len(el)) + } + + return el[0] +} diff --git a/openpgp/v2/key_generation.go b/openpgp/v2/key_generation.go index 1617d48b..aee9d61b 100644 --- a/openpgp/v2/key_generation.go +++ b/openpgp/v2/key_generation.go @@ -22,6 +22,7 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc" "github.com/ProtonMail/go-crypto/openpgp/packet" + "github.com/ProtonMail/go-crypto/openpgp/symmetric" "github.com/ProtonMail/go-crypto/openpgp/x25519" "github.com/ProtonMail/go-crypto/openpgp/x448" ) @@ -393,6 +394,9 @@ func newSigner(config *packet.Config) (signer interface{}, err error) { return nil, err } return priv, nil + case packet.ExperimentalPubKeyAlgoHMAC: + hash := algorithm.HashById[hashToHashId(config.Hash())] + return symmetric.HMACGenerateKey(config.Random(), hash) default: return nil, errors.InvalidArgumentError("unsupported public key algorithm") } @@ -435,6 +439,9 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) { return x25519.GenerateKey(config.Random()) case packet.PubKeyAlgoEd448, packet.PubKeyAlgoX448: // When passing Ed448, we generate an x448 subkey return x448.GenerateKey(config.Random()) + case packet.ExperimentalPubKeyAlgoHMAC, packet.ExperimentalPubKeyAlgoAEAD: // When passing HMAC, we generate an AEAD subkey + cipher := algorithm.CipherFunction(config.Cipher()) + return symmetric.AEADGenerateKey(config.Random(), cipher) default: return nil, errors.InvalidArgumentError("unsupported public key algorithm") } diff --git a/openpgp/v2/keys.go b/openpgp/v2/keys.go index b4a7cc1e..192ebbaf 100644 --- a/openpgp/v2/keys.go +++ b/openpgp/v2/keys.go @@ -61,7 +61,7 @@ func (e *Entity) PrimaryIdentity(date time.Time, config *packet.Config) (*packet var primaryIdentityCandidatesSelfSigs []*packet.Signature for _, identity := range e.Identities { selfSig, err := identity.Verify(date, config) // identity must be valid at date - if err == nil { // verification is successful + if err == nil { // verification is successful primaryIdentityCandidates = append(primaryIdentityCandidates, identity) primaryIdentityCandidatesSelfSigs = append(primaryIdentityCandidatesSelfSigs, selfSig) } @@ -163,12 +163,12 @@ func (e *Entity) DecryptionKeys(id uint64, date time.Time, config *packet.Config for _, subkey := range e.Subkeys { subkeySelfSig, err := subkey.LatestValidBindingSignature(date, config) if err == nil && - isValidEncryptionKey(subkeySelfSig, subkey.PublicKey.PubKeyAlgo) && + isValidDecryptionKey(subkeySelfSig, subkey.PublicKey.PubKeyAlgo) && (id == 0 || subkey.PublicKey.KeyId == id) { keys = append(keys, Key{subkey.Primary, primarySelfSignature, subkey.PublicKey, subkey.PrivateKey, subkeySelfSig}) } } - if isValidEncryptionKey(primarySelfSignature, e.PrimaryKey.PubKeyAlgo) { + if isValidDecryptionKey(primarySelfSignature, e.PrimaryKey.PubKeyAlgo) { keys = append(keys, Key{e, primarySelfSignature, e.PrimaryKey, e.PrivateKey, primarySelfSignature}) } return @@ -608,6 +608,10 @@ func (e *Entity) serializePrivate(w io.Writer, config *packet.Config, reSign boo // Serialize writes the public part of the given Entity to w, including // signatures from other entities. No private key material will be output. func (e *Entity) Serialize(w io.Writer) error { + if e.PrimaryKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoHMAC || + e.PrimaryKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD { + return errors.InvalidArgumentError("Can't serialize symmetric primary key") + } if err := e.PrimaryKey.Serialize(w); err != nil { return err } @@ -628,6 +632,16 @@ func (e *Entity) Serialize(w io.Writer) error { } } for _, subkey := range e.Subkeys { + // The types of keys below are only useful as private keys. Thus, the + // public key packets contain no meaningful information and do not need + // to be serialized. + // Prevent public key export for forwarding keys, see forwarding section 4.1. + subKeySelfSig, err := subkey.LatestValidBindingSignature(time.Time{}, nil) + if subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoHMAC || + subkey.PublicKey.PubKeyAlgo == packet.ExperimentalPubKeyAlgoAEAD || + (err == nil && subKeySelfSig.FlagForward) { + continue + } if err := subkey.Serialize(w, false); err != nil { return err } @@ -786,3 +800,9 @@ func isValidEncryptionKey(signature *packet.Signature, algo packet.PublicKeyAlgo signature.FlagsValid && (signature.FlagEncryptCommunications || signature.FlagEncryptStorage) } + +func isValidDecryptionKey(signature *packet.Signature, algo packet.PublicKeyAlgorithm) bool { + return algo.CanEncrypt() && + signature.FlagsValid && + (signature.FlagEncryptCommunications || signature.FlagForward || signature.FlagEncryptStorage) +} diff --git a/openpgp/v2/keys_test.go b/openpgp/v2/keys_test.go index 0b276c23..c9d27734 100644 --- a/openpgp/v2/keys_test.go +++ b/openpgp/v2/keys_test.go @@ -22,6 +22,7 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/packet" "github.com/ProtonMail/go-crypto/openpgp/s2k" + "github.com/ProtonMail/go-crypto/openpgp/symmetric" ) var hashes = []crypto.Hash{ @@ -2022,3 +2023,222 @@ NciH07RTRuMS/aRhRg4OB8PQROmTnZ+iZS0= t.Fatal(err) } } + +func TestAddHMACSubkey(t *testing.T) { + c := &packet.Config{ + RSABits: 512, + Algorithm: packet.ExperimentalPubKeyAlgoHMAC, + } + + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddSigningSubkey(c) + if err != nil { + t.Fatal(err) + } + + buf := bytes.NewBuffer(nil) + w, _ := armor.Encode(buf, "PGP PRIVATE KEY BLOCK", nil) + if err := entity.SerializePrivate(w, nil); err != nil { + t.Errorf("failed to serialize entity: %s", err) + } + w.Close() + + key, err := ReadArmoredKeyRing(buf) + if err != nil { + t.Error("could not read keyring", err) + } + + generatedPrivateKey := entity.Subkeys[1].PrivateKey.PrivateKey.(*symmetric.HMACPrivateKey) + parsedPrivateKey := key[0].Subkeys[1].PrivateKey.PrivateKey.(*symmetric.HMACPrivateKey) + + generatedPublicKey := entity.Subkeys[1].PublicKey.PublicKey.(*symmetric.HMACPublicKey) + parsedPublicKey := key[0].Subkeys[1].PublicKey.PublicKey.(*symmetric.HMACPublicKey) + + if !bytes.Equal(parsedPrivateKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key") + } + if !bytes.Equal(parsedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key in public part") + } + if !bytes.Equal(generatedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("generated Public and Private Key differ") + } + + if !bytes.Equal(parsedPrivateKey.HashSeed[:], generatedPrivateKey.HashSeed[:]) { + t.Error("parsed wrong hash seed") + } + + if parsedPrivateKey.PublicKey.Hash != generatedPrivateKey.PublicKey.Hash { + t.Error("parsed wrong cipher id") + } + if !bytes.Equal(parsedPrivateKey.PublicKey.BindingHash[:], generatedPrivateKey.PublicKey.BindingHash[:]) { + t.Error("parsed wrong binding hash") + } +} + +func TestSerializeSymmetricSubkeyError(t *testing.T) { + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + buf := bytes.NewBuffer(nil) + w, _ := armor.Encode(buf, "PGP PRIVATE KEY BLOCK", nil) + + entity.PrimaryKey.PubKeyAlgo = 100 + err = entity.Serialize(w) + if err == nil { + t.Fatal(err) + } + + entity.PrimaryKey.PubKeyAlgo = 101 + err = entity.Serialize(w) + if err == nil { + t.Fatal(err) + } +} + +func TestAddAEADSubkey(t *testing.T) { + c := &packet.Config{ + RSABits: 512, + Algorithm: packet.ExperimentalPubKeyAlgoAEAD, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddEncryptionSubkey(c) + if err != nil { + t.Fatal(err) + } + + generatedPrivateKey := entity.Subkeys[1].PrivateKey.PrivateKey.(*symmetric.AEADPrivateKey) + + buf := bytes.NewBuffer(nil) + w, _ := armor.Encode(buf, "PGP PRIVATE KEY BLOCK", nil) + if err := entity.SerializePrivate(w, nil); err != nil { + t.Errorf("failed to serialize entity: %s", err) + } + w.Close() + + key, err := ReadArmoredKeyRing(buf) + if err != nil { + t.Error("could not read keyring", err) + } + + parsedPrivateKey := key[0].Subkeys[1].PrivateKey.PrivateKey.(*symmetric.AEADPrivateKey) + + generatedPublicKey := entity.Subkeys[1].PublicKey.PublicKey.(*symmetric.AEADPublicKey) + parsedPublicKey := key[0].Subkeys[1].PublicKey.PublicKey.(*symmetric.AEADPublicKey) + + if !bytes.Equal(parsedPrivateKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key") + } + if !bytes.Equal(parsedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("parsed wrong key in public part") + } + if !bytes.Equal(generatedPublicKey.Key, generatedPrivateKey.Key) { + t.Error("generated Public and Private Key differ") + } + + if !bytes.Equal(parsedPrivateKey.HashSeed[:], generatedPrivateKey.HashSeed[:]) { + t.Error("parsed wrong hash seed") + } + + if parsedPrivateKey.PublicKey.Cipher.Id() != generatedPrivateKey.PublicKey.Cipher.Id() { + t.Error("parsed wrong cipher id") + } + if !bytes.Equal(parsedPrivateKey.PublicKey.BindingHash[:], generatedPrivateKey.PublicKey.BindingHash[:]) { + t.Error("parsed wrong binding hash") + } +} + +func TestNoSymmetricKeySerialized(t *testing.T) { + aeadConfig := &packet.Config{ + RSABits: 512, + DefaultHash: crypto.SHA512, + Algorithm: packet.ExperimentalPubKeyAlgoAEAD, + DefaultCipher: packet.CipherAES256, + } + hmacConfig := &packet.Config{ + RSABits: 512, + DefaultHash: crypto.SHA512, + Algorithm: packet.ExperimentalPubKeyAlgoHMAC, + DefaultCipher: packet.CipherAES256, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddEncryptionSubkey(aeadConfig) + if err != nil { + t.Fatal(err) + } + err = entity.AddSigningSubkey(hmacConfig) + if err != nil { + t.Fatal(err) + } + + w := bytes.NewBuffer(nil) + entity.Serialize(w) + + firstSymKey := entity.Subkeys[1].PrivateKey.PrivateKey.(*symmetric.AEADPrivateKey).Key + i := bytes.Index(w.Bytes(), firstSymKey) + + secondSymKey := entity.Subkeys[2].PrivateKey.PrivateKey.(*symmetric.HMACPrivateKey).Key + k := bytes.Index(w.Bytes(), secondSymKey) + + if (i > 0) || (k > 0) { + t.Error("Private key was serialized with public") + } + + firstBindingHash := entity.Subkeys[1].PublicKey.PublicKey.(*symmetric.AEADPublicKey).BindingHash + i = bytes.Index(w.Bytes(), firstBindingHash[:]) + + secondBindingHash := entity.Subkeys[2].PublicKey.PublicKey.(*symmetric.HMACPublicKey).BindingHash + k = bytes.Index(w.Bytes(), secondBindingHash[:]) + if (i > 0) || (k > 0) { + t.Errorf("Symmetric public key metadata exported %d %d", i, k) + } + +} + +func TestSymmetricKeys(t *testing.T) { + data := `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xWoEYs7w5mUIcFvlmkuricX26x138uvHGlwIaxWIbRnx1+ggPcveTcwA4zSZ +n6XcD0Q5aLe6dTEBwCyfUecZ/nA0W8Pl9xBHfjIjQuxcUBnIqxZ061RZPjef +D/XIQga1ftLDelhylQwL7R3TzQ1TeW1tZXRyaWMgS2V5wmkEEGUIAB0FAmLO +8OYECwkHCAMVCAoEFgACAQIZAQIbAwIeAQAhCRCRTKq2ObiQKxYhBMHTTXXF +ULQ2M2bYNJFMqrY5uJArIawgJ+5RSsN8VNuZTKJbG88TIedU05wwKjW3wqvT +X6Z7yfbHagRizvDmZAluL/kJo6hZ1kFENpQkWD/Kfv1vAG3nbxhsVEzBQ6a1 +OAD24BaKJz6gWgj4lASUNK5OuXnLc3J79Bt1iRGkSbiPzRs/bplB4TwbILeC +ZLeDy9kngZDosgsIk5sBgGEqS9y5HiHCVQQYZQgACQUCYs7w5gIbDAAhCRCR +TKq2ObiQKxYhBMHTTXXFULQ2M2bYNJFMqrY5uJArENkgL0Bc+OI/1na0XWqB +TxGVotQ4A/0u0VbOMEUfnrI8Fms= +=RdCW +-----END PGP PRIVATE KEY BLOCK----- +` + keys, err := ReadArmoredKeyRing(strings.NewReader(data)) + if err != nil { + t.Fatal(err) + } + if len(keys) != 1 { + t.Errorf("Expected 1 symmetric key, got %d", len(keys)) + } + if keys[0].PrivateKey.PubKeyAlgo != packet.ExperimentalPubKeyAlgoHMAC { + t.Errorf("Expected HMAC primary key") + } + if len(keys[0].Subkeys) != 1 { + t.Errorf("Expected 1 symmetric subkey, got %d", len(keys[0].Subkeys)) + } + if keys[0].Subkeys[0].PrivateKey.PubKeyAlgo != packet.ExperimentalPubKeyAlgoAEAD { + t.Errorf("Expected AEAD subkey") + } +} diff --git a/openpgp/v2/read.go b/openpgp/v2/read.go index 24c8c8f0..b275130d 100644 --- a/openpgp/v2/read.go +++ b/openpgp/v2/read.go @@ -138,7 +138,7 @@ ParsePackets: switch p.Algo { case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal, packet.PubKeyAlgoECDH, - packet.PubKeyAlgoX25519, packet.PubKeyAlgoX448: + packet.PubKeyAlgoX25519, packet.PubKeyAlgoX448, packet.ExperimentalPubKeyAlgoAEAD: break default: continue diff --git a/openpgp/v2/read_test.go b/openpgp/v2/read_test.go index 2feaf392..d7084b8a 100644 --- a/openpgp/v2/read_test.go +++ b/openpgp/v2/read_test.go @@ -1002,3 +1002,10 @@ func testMalformedMessage(t *testing.T, keyring EntityList, message string) { return } } + +func TestReadKeyRingWithSymmetricSubkey(t *testing.T) { + _, err := ReadArmoredKeyRing(strings.NewReader(keyWithAEADSubkey)) + if err != nil { + t.Error("could not read keyring", err) + } +} diff --git a/openpgp/v2/read_write_test_data.go b/openpgp/v2/read_write_test_data.go index 2f0efc22..bb383b19 100644 --- a/openpgp/v2/read_write_test_data.go +++ b/openpgp/v2/read_write_test_data.go @@ -740,3 +740,21 @@ NVniEke6hM3CNBXYPAMhQBMWhCulcoz+0lxi8L34rMN+Dsbma96psdUrn7uLaB91 xqAY9Bwizt4FWgXuLm1a4+So4V9j1TRCXd12Uc2l2RNmgDE= =miES -----END PGP PRIVATE KEY BLOCK-----` + +// A key that contains a persistent AEAD subkey +const keyWithAEADSubkey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xVgEYs/4KxYJKwYBBAHaRw8BAQdA7tIsntXluwloh/H62PJMqasjP00M86fv +/Pof9A968q8AAQDYcgkPKUdWAxsDjDHJfouPS4q5Me3ks+umlo5RJdwLZw4k +zQ1TeW1tZXRyaWMgS2V5wowEEBYKAB0FAmLP+CsECwkHCAMVCAoEFgACAQIZ +AQIbAwIeAQAhCRDkNhFDvaU8vxYhBDJNoyEFquVOCf99d+Q2EUO9pTy/5XQA +/1F2YPouv0ydBDJU3EOS/4bmPt7yqvzciWzeKVEOkzYuAP9OsP7q/5ccqOPX +mmRUKwd82/cNjdzdnWZ8Tq89XMwMAMdqBGLP+CtkCfFyZxOMF0BWLwAE8pLy +RVj2n2K7k6VvrhyuTqDkFDUFALiSLrEfnmTKlsPYS3/YzsODF354ccR63q73 +3lmCrvFRyaf6AHvVrBYPbJR+VhuTjZTwZKvPPKv0zVdSqi5JDEQiocJ4BBgW +CAAJBQJiz/grAhsMACEJEOQ2EUO9pTy/FiEEMk2jIQWq5U4J/3135DYRQ72l +PL+fEQEA7RaRbfa+AtiRN7a4GuqVEDZi3qtQZ2/Qcb27/LkAD0sA/3r9drYv +jyu46h1fdHHyo0HS2MiShZDZ8u60JnDltloD +=8TxH +-----END PGP PRIVATE KEY BLOCK----- +` diff --git a/openpgp/v2/write_test.go b/openpgp/v2/write_test.go index f3c4f9da..28550862 100644 --- a/openpgp/v2/write_test.go +++ b/openpgp/v2/write_test.go @@ -6,6 +6,7 @@ package v2 import ( "bytes" + "crypto" "crypto/rand" "io" mathrand "math/rand" @@ -997,3 +998,90 @@ FindKey: } return nil } + +func TestEncryptWithAEAD(t *testing.T) { + c := &packet.Config{ + MinRSABits: 1024, + Algorithm: packet.ExperimentalPubKeyAlgoAEAD, + DefaultCipher: packet.CipherAES256, + AEADConfig: &packet.AEADConfig{ + DefaultMode: packet.AEADMode(1), + }, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddEncryptionSubkey(c) + if err != nil { + t.Fatal(err) + } + + list := make([]*Entity, 1) + list[0] = entity + entityList := EntityList(list) + buf := bytes.NewBuffer(nil) + w, err := Encrypt(buf, entityList[:], nil, nil, nil, c) + if err != nil { + t.Fatal(err) + } + + const message = "test" + _, err = w.Write([]byte(message)) + if err != nil { + t.Fatal(err) + } + err = w.Close() + if err != nil { + t.Fatal(err) + } + + m, err := ReadMessage(buf, entityList, nil /* no prompt */, c) + if err != nil { + t.Fatal(err) + } + dec, err := io.ReadAll(m.decrypted) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(dec, []byte(message)) { + t.Error("decrypted does not match original") + } +} + +func TestSignWithHMAC(t *testing.T) { + c := &packet.Config{ + MinRSABits: 1024, + Algorithm: packet.ExperimentalPubKeyAlgoHMAC, + DefaultHash: crypto.SHA512, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddSigningSubkey(c) + if err != nil { + t.Fatal(err) + } + list := make([]*Entity, 1) + list[0] = entity + entityList := EntityList(list) + + msgBytes := []byte("message") + msg := bytes.NewBuffer(msgBytes) + sig := bytes.NewBuffer(nil) + + err = DetachSign(sig, []*Entity{entity}, msg, c) + if err != nil { + t.Fatal(err) + } + + msg = bytes.NewBuffer(msgBytes) + _, _, err = VerifyDetachedSignature(entityList, msg, sig, c) + if err != nil { + t.Fatal(err) + } +} diff --git a/openpgp/write_test.go b/openpgp/write_test.go index c928236b..315e7323 100644 --- a/openpgp/write_test.go +++ b/openpgp/write_test.go @@ -6,6 +6,7 @@ package openpgp import ( "bytes" + "crypto" "crypto/rand" "io" mathrand "math/rand" @@ -263,6 +264,91 @@ func TestNewEntity(t *testing.T) { } } +func TestEncryptWithAEAD(t *testing.T) { + c := &packet.Config{ + Algorithm: packet.ExperimentalPubKeyAlgoAEAD, + DefaultCipher: packet.CipherAES256, + AEADConfig: &packet.AEADConfig{ + DefaultMode: packet.AEADMode(1), + }, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddEncryptionSubkey(c) + if err != nil { + t.Fatal(err) + } + + list := make([]*Entity, 1) + list[0] = entity + entityList := EntityList(list) + buf := bytes.NewBuffer(nil) + w, err := Encrypt(buf, entityList[:], nil, nil, c) + if err != nil { + t.Fatal(err) + } + + const message = "test" + _, err = w.Write([]byte(message)) + if err != nil { + t.Fatal(err) + } + err = w.Close() + if err != nil { + t.Fatal(err) + } + + m, err := ReadMessage(buf, entityList, nil /* no prompt */, c) + if err != nil { + t.Fatal(err) + } + dec, err := io.ReadAll(m.decrypted) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(dec, []byte(message)) { + t.Error("decrypted does not match original") + } +} + +func TestSignWithHMAC(t *testing.T) { + c := &packet.Config{ + Algorithm: packet.ExperimentalPubKeyAlgoHMAC, + DefaultHash: crypto.SHA512, + } + entity, err := NewEntity("Golang Gopher", "Test Key", "no-reply@golang.com", &packet.Config{RSABits: 1024}) + if err != nil { + t.Fatal(err) + } + + err = entity.AddSigningSubkey(c) + if err != nil { + t.Fatal(err) + } + list := make([]*Entity, 1) + list[0] = entity + entityList := EntityList(list) + + msgBytes := []byte("message") + msg := bytes.NewBuffer(msgBytes) + sig := bytes.NewBuffer(nil) + + err = DetachSign(sig, entity, msg, nil) + if err != nil { + t.Fatal(err) + } + + msg = bytes.NewBuffer(msgBytes) + _, err = CheckDetachedSignature(entityList, msg, sig, nil) + if err != nil { + t.Fatal(err) + } +} + func TestEncryptWithCompression(t *testing.T) { kring, _ := ReadKeyRing(readerFromHex(testKeys1And2PrivateHex)) passphrase := []byte("passphrase")