Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ACVP ML-KEM testing #1840

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions util/fipstools/acvp/acvptool/subprocess/ml_kem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

package subprocess

import (
"encoding/json"
"fmt"
"strings"
)

type mlKem struct{}

func (*mlKem) Process(vectorSet []byte, m Transactable) (interface{}, error) {
var vs struct {
Mode string `json:"mode"`
TestGroups json.RawMessage `json:"testGroups"`
}

if err := json.Unmarshal(vectorSet, &vs); err != nil {
return nil, err
}

switch {
case strings.EqualFold(vs.Mode, "keyGen"):
return processMlKemKeyGen(vs.TestGroups, m)
case strings.EqualFold(vs.Mode, "encapDecap"):
return processMlKemEncapDecap(vs.TestGroups, m)
}

return nil, fmt.Errorf("unknown ML-KEM mode: %v", vs.Mode)
}

type mlKemKeyGenTestGroup struct {
ID uint64 `json:"tgId"`
Type string `json:"testType"`
ParameterSet string `json:"parameterSet"`
Tests []struct {
ID uint64 `json:"tcId"`
D hexEncodedByteString `json:"d"`
Z hexEncodedByteString `json:"z"`
}
}

type mlKemKeyGenTestGroupResponse struct {
ID uint64 `json:"tgId"`
Tests []mlKemKeyGenTestCaseResponse `json:"tests"`
}

type mlKemKeyGenTestCaseResponse struct {
ID uint64 `json:"tcId"`
EK hexEncodedByteString `json:"ek"`
DK hexEncodedByteString `json:"dk"`
}

func processMlKemKeyGen(vectors json.RawMessage, m Transactable) (interface{}, error) {
var groups []mlKemKeyGenTestGroup

if err := json.Unmarshal(vectors, &groups); err != nil {
return nil, err
}

var responses []mlKemKeyGenTestGroupResponse

for _, group := range groups {
if !strings.EqualFold(group.Type, "AFT") {
return nil, fmt.Errorf("unsupported keyGen test type: %v", group.Type)
}

response := mlKemKeyGenTestGroupResponse{
ID: group.ID,
}

for _, test := range group.Tests {
results, err := m.Transact("ML-KEM/"+group.ParameterSet+"/keyGen", 2, test.D, test.Z)
if err != nil {
return nil, err
}

ek := results[0]
dk := results[1]

response.Tests = append(response.Tests, mlKemKeyGenTestCaseResponse{
ID: test.ID,
EK: ek,
DK: dk,
})
}

responses = append(responses, response)
}

return responses, nil
}

type mlKemEncapDecapTestGroup struct {
ID uint64 `json:"tgId"`
Type string `json:"testType"`
ParameterSet string `json:"parameterSet"`
Function string `json:"function"`
DK hexEncodedByteString `json:"dk"`
Tests []struct {
ID uint64 `json:"tcId"`
EK hexEncodedByteString `json:"ek"`
M hexEncodedByteString `json:"m"`
C hexEncodedByteString `json:"c"`
}
}

type mlKemEncDecapTestGroupResponse struct {
ID uint64 `json:"tgId"`
Tests []mlKemEncDecapTestCaseResponse `json:"tests"`
}

type mlKemEncDecapTestCaseResponse struct {
ID uint64 `json:"tcId"`
C hexEncodedByteString `json:"c,omitempty"`
K hexEncodedByteString `json:"k,omitempty"`
}

func processMlKemEncapDecap(vectors json.RawMessage, m Transactable) (interface{}, error) {
var groups []mlKemEncapDecapTestGroup

if err := json.Unmarshal(vectors, &groups); err != nil {
return nil, err
}

var responses []mlKemEncDecapTestGroupResponse

for _, group := range groups {
if (strings.EqualFold(group.Function, "encapsulation") && !strings.EqualFold(group.Type, "AFT")) ||
(strings.EqualFold(group.Function, "decapsulation") && !strings.EqualFold(group.Type, "VAL")) {
return nil, fmt.Errorf("unsupported encapDecap function and test group type pair: (%v, %v)", group.Function, group.Type)
}

response := mlKemEncDecapTestGroupResponse{
ID: group.ID,
}

for _, test := range group.Tests {
var (
err error
testResponse mlKemEncDecapTestCaseResponse
)

switch {
case strings.EqualFold(group.Function, "encapsulation"):
testResponse, err = processMlKemEncapTestCase(test.ID, group.ParameterSet, test.EK, test.M, m)
case strings.EqualFold(group.Function, "decapsulation"):
testResponse, err = processMlKemDecapTestCase(test.ID, group.ParameterSet, group.DK, test.C, m)
default:
return nil, fmt.Errorf("unknown encDecap function: %v", group.Function)
}
if err != nil {
return nil, err
}

response.Tests = append(response.Tests, testResponse)
}

responses = append(responses, response)
}
return responses, nil
}

func processMlKemEncapTestCase(id uint64, algorithm string, ek []byte, m []byte, t Transactable) (mlKemEncDecapTestCaseResponse, error) {
results, err := t.Transact("ML-KEM/"+algorithm+"/encap", 2, ek, m)
if err != nil {
return mlKemEncDecapTestCaseResponse{}, err
}

c := results[0]
k := results[1]

return mlKemEncDecapTestCaseResponse{
ID: id,
C: c,
K: k,
}, nil
}

func processMlKemDecapTestCase(id uint64, algorithm string, dk []byte, c []byte, t Transactable) (mlKemEncDecapTestCaseResponse, error) {
results, err := t.Transact("ML-KEM/"+algorithm+"/decap", 1, dk, c)
if err != nil {
return mlKemEncDecapTestCaseResponse{}, err
}

k := results[0]

return mlKemEncDecapTestCaseResponse{
ID: id,
K: k,
}, nil
}
1 change: 1 addition & 0 deletions util/fipstools/acvp/acvptool/subprocess/subprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func NewWithIO(cmd *exec.Cmd, in io.WriteCloser, out io.ReadCloser) *Subprocess
"KAS-ECC-SSC": &kas{},
"KAS-FFC-SSC": &kasDH{},
"PBKDF": &pbkdf{},
"ML-KEM": &mlKem{},
}
m.primitives["ECDSA"] = &ecdsa{"ECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m.primitives}

Expand Down
Binary file not shown.
3 changes: 2 additions & 1 deletion util/fipstools/acvp/acvptool/test/tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@
{"Wrapper": "modulewrapper", "In": "vectors/TLS-1.2-KDF.bz2", "Out": "expected/TLS-1.2-KDF.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/PBKDF.bz2", "Out": "expected/PBKDF.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/KDA-HKDF.bz2", "Out": "expected/KDA-HKDF.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/KDA-OneStep.bz2", "Out": "expected/KDA-OneStep.bz2"}
{"Wrapper": "modulewrapper", "In": "vectors/KDA-OneStep.bz2", "Out": "expected/KDA-OneStep.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/ML-KEM.bz2", "Out": "expected/ML-KEM.bz2"}
]
Binary file not shown.
136 changes: 133 additions & 3 deletions util/fipstools/acvp/modulewrapper/modulewrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
* OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
* CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#include <signal.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include <signal.h>
#include <cstring>

#include <sstream>

Expand All @@ -41,6 +42,7 @@
#include <openssl/ecdsa.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/experimental/kem_deterministic_api.h>
#include <openssl/hkdf.h>
#include <openssl/hmac.h>
#include <openssl/kdf.h>
Expand Down Expand Up @@ -732,7 +734,7 @@ static bool GetConfig(const Span<const uint8_t> args[],
}]
}]
},)"
R"({
R"({
"sigType": "pss",
"properties": [{
"modulo": 2048,
Expand Down Expand Up @@ -988,7 +990,7 @@ static bool GetConfig(const Span<const uint8_t> args[],
}]
}]
},)"
R"({
R"({
"sigType": "pss",
"properties": [{
"modulo": 2048,
Expand Down Expand Up @@ -1307,6 +1309,19 @@ static bool GetConfig(const Span<const uint8_t> args[],
"encoding": ["concatenation"],
"z": [{"min": 224, "max": 8192, "increment": 8}],
"l": 2048
},)"
R"({
"algorithm": "ML-KEM",
"mode": "keyGen",
"revision": "FIPS203",
"parameterSets": ["ML-KEM-512", "ML-KEM-768", "ML-KEM-1024"]
},
{
"algorithm": "ML-KEM",
"mode": "encapDecap",
"revision": "FIPS203",
"parameterSets": ["ML-KEM-512", "ML-KEM-768", "ML-KEM-1024"],
"functions": ["encapsulation", "decapsulation"]
}
])";
return write_reply({Span<const uint8_t>(
Expand Down Expand Up @@ -2830,6 +2845,112 @@ static bool KBKDF_CTR_HMAC(const Span<const uint8_t> args[],
return write_reply({Span<const uint8_t>(out)});
}

template <int nid>
static bool ML_KEM_KEYGEN(const Span<const uint8_t> args[],
ReplyCallback write_reply) {
const Span<const uint8_t> d = args[0];
const Span<const uint8_t> z = args[1];

std::vector<uint8_t> seed(d.size() + z.size());
std::memcpy(seed.data(), d.data(), d.size());
std::memcpy(seed.data() + d.size(), z.data(), z.size());

EVP_PKEY *raw = NULL;
size_t seed_len = 0;

bssl::UniquePtr<EVP_PKEY_CTX> ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_KEM, nullptr));
if (!EVP_PKEY_CTX_kem_set_params(ctx.get(), nid) ||
!EVP_PKEY_keygen_init(ctx.get()) ||
!EVP_PKEY_keygen_deterministic(ctx.get(), &raw, NULL, &seed_len) ||
seed_len != seed.size() ||
!EVP_PKEY_keygen_deterministic(ctx.get(), &raw, seed.data(), &seed_len)) {
return false;
}
bssl::UniquePtr<EVP_PKEY> pkey(raw);

size_t decaps_key_size = 0;
size_t encaps_key_size = 0;

if (!EVP_PKEY_get_raw_private_key(pkey.get(), nullptr, &decaps_key_size) ||
!EVP_PKEY_get_raw_public_key(pkey.get(), nullptr, &encaps_key_size)) {
return false;
}

std::vector<uint8_t> decaps_key(decaps_key_size);
std::vector<uint8_t> encaps_key(encaps_key_size);

if (!EVP_PKEY_get_raw_private_key(pkey.get(), decaps_key.data(),
&decaps_key_size) ||
!EVP_PKEY_get_raw_public_key(pkey.get(), encaps_key.data(),
&encaps_key_size)) {
return false;
}

return write_reply({Span<const uint8_t>(encaps_key.data(), encaps_key_size),
Span<const uint8_t>(decaps_key.data(), decaps_key_size)});
}

template <int nid>
static bool ML_KEM_ENCAP(const Span<const uint8_t> args[],
ReplyCallback write_reply) {
const Span<const uint8_t> ek = args[0];
const Span<const uint8_t> m = args[1];

bssl::UniquePtr<EVP_PKEY> pkey(
EVP_PKEY_kem_new_raw_public_key(nid, ek.data(), ek.size()));
bssl::UniquePtr<EVP_PKEY_CTX> ctx(EVP_PKEY_CTX_new(pkey.get(), nullptr));

size_t ciphertext_len = 0;
size_t shared_secret_len = 0;
size_t seed_len = 0;
if (!EVP_PKEY_encapsulate_deterministic(ctx.get(), nullptr, &ciphertext_len,
nullptr, &shared_secret_len, nullptr,
&seed_len) ||
seed_len != m.size()) {
return false;
}

std::vector<uint8_t> ciphertext(ciphertext_len);
std::vector<uint8_t> shared_secret(shared_secret_len);

if (!EVP_PKEY_encapsulate_deterministic(
ctx.get(), ciphertext.data(), &ciphertext_len, shared_secret.data(),
&shared_secret_len, m.data(), &seed_len)) {
return false;
}

return write_reply(
{Span<const uint8_t>(ciphertext.data(), ciphertext_len),
Span<const uint8_t>(shared_secret.data(), shared_secret_len)});
}

template <int nid>
static bool ML_KEM_DECAP(const Span<const uint8_t> args[],
ReplyCallback write_reply) {
const Span<const uint8_t> dk = args[0];
const Span<const uint8_t> c = args[1];

bssl::UniquePtr<EVP_PKEY> pkey(
EVP_PKEY_kem_new_raw_secret_key(nid, dk.data(), dk.size()));
bssl::UniquePtr<EVP_PKEY_CTX> ctx(EVP_PKEY_CTX_new(pkey.get(), nullptr));

size_t shared_secret_len = 0;
if (!EVP_PKEY_decapsulate(ctx.get(), nullptr, &shared_secret_len, c.data(),
c.size())) {
return false;
}

std::vector<uint8_t> shared_secret(shared_secret_len);

if (!EVP_PKEY_decapsulate(ctx.get(), shared_secret.data(), &shared_secret_len,
c.data(), c.size())) {
return false;
}

return write_reply(
{Span<const uint8_t>(shared_secret.data(), shared_secret_len)});
}

static struct {
char name[kMaxNameLength + 1];
uint8_t num_expected_args;
Expand Down Expand Up @@ -3064,6 +3185,15 @@ static struct {
{"KDF/Counter/HMAC-SHA2-512", 3, KBKDF_CTR_HMAC<EVP_sha512>},
{"KDF/Counter/HMAC-SHA2-512/224", 3, KBKDF_CTR_HMAC<EVP_sha512_224>},
{"KDF/Counter/HMAC-SHA2-512/256", 3, KBKDF_CTR_HMAC<EVP_sha512_256>},
{"ML-KEM/ML-KEM-512/keyGen", 2, ML_KEM_KEYGEN<NID_MLKEM512>},
{"ML-KEM/ML-KEM-768/keyGen", 2, ML_KEM_KEYGEN<NID_MLKEM768>},
{"ML-KEM/ML-KEM-1024/keyGen", 2, ML_KEM_KEYGEN<NID_MLKEM1024>},
{"ML-KEM/ML-KEM-512/encap", 2, ML_KEM_ENCAP<NID_MLKEM512>},
{"ML-KEM/ML-KEM-768/encap", 2, ML_KEM_ENCAP<NID_MLKEM768>},
{"ML-KEM/ML-KEM-1024/encap", 2, ML_KEM_ENCAP<NID_MLKEM1024>},
{"ML-KEM/ML-KEM-512/decap", 2, ML_KEM_DECAP<NID_MLKEM512>},
{"ML-KEM/ML-KEM-768/decap", 2, ML_KEM_DECAP<NID_MLKEM768>},
{"ML-KEM/ML-KEM-1024/decap", 2, ML_KEM_DECAP<NID_MLKEM1024>},
};

Handler FindHandler(Span<const Span<const uint8_t>> args) {
Expand Down
Loading