diff --git a/src/common/SSLUtil.cc b/src/common/SSLUtil.cc index 2c9f35b6c95..60607d4b340 100644 --- a/src/common/SSLUtil.cc +++ b/src/common/SSLUtil.cc @@ -254,12 +254,12 @@ namespace ssl_util int rsa_public_encrypt(const std::string& in, std::string& out) { - static RSA * rsa = nullptr; + static EVP_PKEY_CTX *ctx = nullptr; static std::mutex m; std::lock_guard lock(m); - if ( rsa == nullptr) //initialize RSA structure + if ( ctx == nullptr) //initialize RSA structure { FILE * fp = fopen(pubk_path.c_str(), "r"); @@ -268,30 +268,61 @@ namespace ssl_util return -1; } - rsa = PEM_read_RSAPublicKey(fp, &rsa, nullptr, nullptr); + EVP_PKEY* pub_key = PEM_read_PUBKEY(fp, nullptr, nullptr, nullptr); - if ( rsa == nullptr ) + fclose(fp); + + if ( pub_key == nullptr ) { return -1; } - fclose(fp); - } + if (EVP_PKEY_base_id(pub_key) != EVP_PKEY_RSA) + { + EVP_PKEY_free(pub_key); + return -1; + } - char * out_c = (char *) malloc(sizeof(char) * RSA_size(rsa)); + ctx = EVP_PKEY_CTX_new(pub_key, nullptr); + if (!ctx) + return -1; - int rc = RSA_public_encrypt(in.length(), (const unsigned char *) in.c_str(), - (unsigned char *) out_c, rsa, RSA_PKCS1_PADDING); + if (EVP_PKEY_encrypt_init(ctx) < 1) + { + EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(pub_key); + ctx = nullptr; + return -1; + } + + if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_PADDING) < 1) + { + EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(pub_key); + ctx = nullptr; + return -1; + } - if ( rc != -1 ) - { - out.assign(out_c, rc); - rc = 0; } - free(out_c); + size_t out_len; + if (EVP_PKEY_encrypt(ctx, nullptr, &out_len, (const unsigned char *)in.c_str(), in.size()) <= 0) + return -1; + + // OpenSSL documentation states that if EVP_PKEY_encrypt is called with out as nullptr, + // it will write maximum size of the output buffer to out_len. + // If it's not nullptr it will write the actual size of buffer, so we call resize here twice. + std::string result; + result.resize(out_len); + + if (EVP_PKEY_encrypt(ctx, (unsigned char*) result.data(), &out_len, (const unsigned char *)in.c_str(), in.size()) <= 0) + return -1; - return rc; + result.resize(out_len); + + out = std::move(result); + + return 0; } /* -------------------------------------------------------------------------- */ @@ -299,12 +330,14 @@ namespace ssl_util int rsa_private_decrypt(const std::string& in, std::string& out) { - static RSA * rsa = nullptr; + static EVP_PKEY_CTX* ctx = nullptr; static std::mutex m; std::lock_guard lock(m); - if ( rsa == nullptr) //initialize RSA structure + static std::size_t key_size = 0; + + if ( ctx == nullptr) //initialize RSA structure { FILE * fp = fopen(prik_path.c_str(), "r"); @@ -313,48 +346,70 @@ namespace ssl_util return -1; } - rsa = PEM_read_RSAPrivateKey(fp, &rsa, nullptr, nullptr); + EVP_PKEY* priv_key = PEM_read_PrivateKey(fp, nullptr, nullptr, nullptr); + + fclose(fp); + + if ( priv_key == nullptr ) + { + return -1; + } + + if (EVP_PKEY_base_id(priv_key) != EVP_PKEY_RSA) + { + EVP_PKEY_free(priv_key); + return -1; + } + + ctx = EVP_PKEY_CTX_new(priv_key, nullptr); + if (!ctx) + return -1; - if ( rsa == nullptr ) + if (EVP_PKEY_decrypt_init(ctx) < 1 || + EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_PADDING) < 1) { + EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(priv_key); + ctx = nullptr; return -1; } - fclose(fp); - } + int tmp = EVP_PKEY_get_size(priv_key); - std::string result; + if (tmp < 0) + { + EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(priv_key); + ctx = nullptr; + return -1; + } - int key_size = RSA_size(rsa); - int in_size = in.length(); - char * out_c = (char *) malloc(sizeof(char) * RSA_size(rsa)); + key_size = static_cast(tmp); + } - const char * in_c = in.c_str(); + const auto in_size = in.size(); - for (int index = 0; index < in_size; index += key_size) + std::size_t index = 0; + std::vector out_c(key_size, 0); + std::string result; + + while (index < in_size) { - int block_size = key_size; + auto step = (in_size - index) < key_size ? in_size - index : key_size; - if ( index + key_size > in_size ) - { - block_size = in_size - index; - } + const unsigned char* in_p = (unsigned char*)in.data() + index; - int rc = RSA_private_decrypt(block_size, (const unsigned char *) - in_c + index, (unsigned char *) out_c, rsa, RSA_PKCS1_PADDING); + std::size_t out_size = key_size; - if ( rc != -1 ) - { - result.append(out_c, rc); - } - else + if (EVP_PKEY_decrypt(ctx, out_c.data(), &out_size, in_p, step) < 1) { - free(out_c); return -1; } - } - free(out_c); + result.append((char*)out_c.data(), out_size); + + index +=step; + } out = std::move(result);