Skip to content

Commit

Permalink
fixup! Implement cmp KEM combiner and cmp KEM encaps/decaps fns
Browse files Browse the repository at this point in the history
Signed-off-by: Pravek Sharma <[email protected]>
  • Loading branch information
praveksharma committed Nov 18, 2024
1 parent a2af9b1 commit 8779bd3
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 96 deletions.
192 changes: 125 additions & 67 deletions oqsprov/oqs_hyb_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*
*/

#include <openssl/asn1.h>
static OSSL_FUNC_kem_encapsulate_fn oqs_hyb_kem_encaps;
static OSSL_FUNC_kem_decapsulate_fn oqs_hyb_kem_decaps;

Expand Down Expand Up @@ -34,17 +35,21 @@ static int oqs_evp_kem_encaps_keyslot(void *vpkemctx, unsigned char *ct,
pubkey_kexlen = evp_ctx->evp_info->length_public_key;
kexDeriveLen = evp_ctx->evp_info->kex_length_secret;

*ctlen = pubkey_kexlen;
*secretlen = kexDeriveLen;
if (keytype == EVP_PKEY_RSA) {
*ctlen = evp_ctx->evp_info->kex_length_secret;
*secretlen = 32;
if (ct == NULL || secret == NULL) {
OQS_KEM_PRINTF3("EVP KEM returning lengths %ld and %ld\n", *ctlen,
*secretlen);
return 1;
}

if (ct == NULL || secret == NULL) {
OQS_KEM_PRINTF3("EVP KEM returning lengths %ld and %ld\n", *ctlen,
*secretlen);
return 1;
}
pkey =
d2i_PublicKey(keytype, NULL, (const unsigned char **)&pubkey_kex,
pubkey_kexlen);
ON_ERR_SET_GOTO(!pkey, ret, -1, err);

if (keytype == EVP_PKEY_RSA) {
ctx = EVP_PKEY_CTX_new(evp_ctx->keyParam, NULL);
ctx = EVP_PKEY_CTX_new(pkey, NULL);
ON_ERR_SET_GOTO(!ctx, ret, -1, err);

ret = EVP_PKEY_encrypt_init(ctx);
Expand All @@ -60,8 +65,7 @@ static int oqs_evp_kem_encaps_keyslot(void *vpkemctx, unsigned char *ct,
ON_ERR_SET_GOTO(ret <= 0, ret, -1, err);

// set pSourceFunc to empty string for pSpecifiedEmptyIdentifier
unsigned char empty_string[] = "";
ret = EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, empty_string, 0);
ret = EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, NULL, 0);
ON_ERR_SET_GOTO(ret <= 0, ret, -1, err);

// generate random secret, 256 bits = 32 bytes
Expand All @@ -74,9 +78,15 @@ static int oqs_evp_kem_encaps_keyslot(void *vpkemctx, unsigned char *ct,
ret = EVP_PKEY_encrypt(ctx, ct, &outlen, secret, 32);
ON_ERR_SET_GOTO(ret <= 0, ret, -1, err);

*ctlen = outlen;
*secretlen = 32; // 256 bits
} else {
*ctlen = pubkey_kexlen;
*secretlen = kexDeriveLen;
if (ct == NULL || secret == NULL) {
OQS_KEM_PRINTF3("EVP KEM returning lengths %ld and %ld\n", *ctlen,
*secretlen);
return 1;
}

peerpk = EVP_PKEY_new();
ON_ERR_SET_GOTO(!peerpk, ret, -1, err);

Expand Down Expand Up @@ -145,11 +155,14 @@ static int oqs_evp_kem_decaps_keyslot(void *vpkemctx, unsigned char *secret,
EVP_PKEY_CTX *ctx = NULL;
EVP_PKEY *pkey = NULL, *peerpkey = NULL;

*secretlen = kexDeriveLen;
if (secret == NULL)
return 1;

if (keytype == EVP_PKEY_RSA) {
*secretlen = 32;
size_t outlen = 32; // expected secret length (256 bits)
if (secret == NULL) {
OQS_KEM_PRINTF2("EVP KEM returning lengths %ld\n", *secretlen);
return 1;
}

pkey =
d2i_PrivateKey(keytype, NULL, (const unsigned char **)&privkey_kex,
privkey_kexlen);
Expand All @@ -171,16 +184,20 @@ static int oqs_evp_kem_decaps_keyslot(void *vpkemctx, unsigned char *secret,
ON_ERR_SET_GOTO(ret <= 0, ret, -6, err);

// expect pSourceFunc to be pSpecifiedEmptyIdentifier
unsigned char empty_string[] = "";
ret = EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, empty_string, 0);
ret = EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, NULL, 0);
ON_ERR_SET_GOTO(ret <= 0, ret, -7, err);

size_t outlen = 32; // expected secret length (256 bits)
ret = EVP_PKEY_decrypt(ctx, NULL, &outlen, ct, ctlen);
ret = EVP_PKEY_decrypt(ctx, secret, &outlen, ct, ctlen);
ON_ERR_SET_GOTO(ret <= 0, ret, -8, err);

*secretlen = outlen;
} else {
*secretlen = kexDeriveLen;
if (secret == NULL) {
OQS_KEM_PRINTF2("EVP KEM returning lengths %ld\n", *secretlen);
return 1;
}

if (evp_ctx->evp_info->raw_key_support) {
pkey = EVP_PKEY_new_raw_private_key(
evp_ctx->evp_info->keytype, NULL, privkey_kex, privkey_kexlen);
Expand Down Expand Up @@ -339,9 +356,9 @@ static int oqs_hyb_kem_decaps(void *vpkemctx, unsigned char *secret,

// Composite KEM functions

static int oqs_cmp_kem_encaps(void *vpkemctx, unsigned char *ct, size_t ctlen,
static int oqs_cmp_kem_encaps(void *vpkemctx, unsigned char *ct, size_t *ctlen,
unsigned char *secret, size_t *secretlen) {
int ret = OQS_SUCCESS, ret2 = 0;
int ret = 1, ret2 = 0;
PROV_OQSKEM_CTX *pkemctx = (PROV_OQSKEM_CTX *)vpkemctx;
const OQS_KEM *qs_kem = pkemctx->kem->oqsx_provider_ctx.oqsx_qs_ctx.kem;
const OQSX_EVP_INFO *evp_info =
Expand All @@ -357,56 +374,97 @@ static int oqs_cmp_kem_encaps(void *vpkemctx, unsigned char *ct, size_t ctlen,

ret2 = oqs_qs_kem_encaps_keyslot(vpkemctx, NULL, &ctLen0, NULL, &secretLen0,
0);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err);
secret0 = OPENSSL_malloc(secretLen0);
ON_ERR_SET_GOTO(!secret0, ret, OQS_ERROR, err);
ct0 = OPENSSL_malloc(ctLen0);
ON_ERR_SET_GOTO(!ct0, ret, OQS_ERROR, err_alloc0);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err);
secret0 = OPENSSL_zalloc(secretLen0);
ON_ERR_SET_GOTO(!secret0, ret, 0, err_secret0);
ct0 = OPENSSL_zalloc(ctLen0); // do not free ct0, freed later with cmpCT
ON_ERR_SET_GOTO(!ct0, ret, 0, err_ct0);

ret2 = oqs_evp_kem_encaps_keyslot(vpkemctx, NULL, &ctLen1, NULL,
&secretLen1, 1);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err_alloc1);
secret1 = OPENSSL_malloc(secretLen1);
ON_ERR_SET_GOTO(!secret1, ret, OQS_ERROR, err_alloc1);
ct1 = OPENSSL_malloc(ctLen1);
ON_ERR_SET_GOTO(!ct1, ret, OQS_ERROR, err_alloc2);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_ct0);
secret1 = OPENSSL_zalloc(secretLen1);
ON_ERR_SET_GOTO(!secret1, ret, 0, err_secret1);
ct1 = OPENSSL_zalloc(ctLen1);
ON_ERR_SET_GOTO(!ct1, ret, 0, err_ct1);

cmpCT = CompositeCiphertext_new();
ON_ERR_SET_GOTO(!cmpCT, ret, OQS_ERROR, err_cmpct);
cmpCT->ct1 = ASN1_OCTET_STRING_new();
ON_ERR_SET_GOTO(!(cmpCT->ct1), ret, OQS_ERROR, err_cmpct);
cmpCT->ct2 = ASN1_OCTET_STRING_new();
ON_ERR_SET_GOTO(!(cmpCT->ct2), ret, OQS_ERROR, err_cmpct);

if (ct == NULL || secret == NULL) {
unsigned char *temp = NULL;

ret2 = ASN1_STRING_set(cmpCT->ct1, ct0, ctLen0);
if (!ret2){
OPENSSL_free(temp);
ON_ERR_SET_GOTO(1, ret, 0, err_cmpct);
}
cmpCT->ct1->flags = 8; // do not check for unused bits

ret2 = ASN1_STRING_set(cmpCT->ct2, ct1, ctLen1);
if (!ret2){
OPENSSL_free(temp);
ON_ERR_SET_GOTO(1, ret, 0, err_cmpct);
}
cmpCT->ct2->flags = 8; // do not check for unused bits

*ctlen = (size_t)i2d_CompositeCiphertext(cmpCT, &temp);
if (ctlen <= 0) {
OPENSSL_free(temp);
ON_ERR_SET_GOTO(1, ret, 0, err_cmpct);
}

ret2 = oqs_kem_combiner(pkemctx, NULL, secretLen1, NULL, secretLen0, NULL, ctLen1, NULL, pkemctx->kem->pubkeylen_cmp[1], NULL, secretlen);
if (!ret2){
OPENSSL_free(temp);
ON_ERR_SET_GOTO(1, ret, 0, err_cmpct);
}

OPENSSL_free(temp);
ON_ERR_SET_GOTO(1, ret, 1, err_cmpct);
}

ret2 = oqs_qs_kem_encaps_keyslot(vpkemctx, ct0, &ctLen0, secret0,
&secretLen0, 0);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err_alloc3);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_cmpct);

ret2 = oqs_evp_kem_encaps_keyslot(vpkemctx, ct1, &ctLen1, secret1,
&secretLen1, 1);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err_alloc3);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_cmpct);

cmpCT = CompositeCiphertext_new();
ON_ERR_SET_GOTO(!cmpCT, ret, OQS_ERROR, err_alloc3);

cmpCT->ct1->data = ct0;
cmpCT->ct1->length = ctLen0;
ret2 = ASN1_STRING_set(cmpCT->ct1, ct0, ctLen0);
ON_ERR_SET_GOTO(!ret2, ret, 0, err_cmpct);
cmpCT->ct1->flags = 8; // do not check for unused bits

cmpCT->ct2->data = ct1;
cmpCT->ct2->length = ctLen1;
ret2 = ASN1_STRING_set(cmpCT->ct2, ct1, ctLen1);
cmpCT->ct2->flags = 8; // do not check for unused bits
ON_ERR_SET_GOTO(!ret2, ret, 0, err_cmpct);

ctlen = i2d_CompositeCiphertext(cmpCT, &p);
ON_ERR_SET_GOTO(!ctlen, ret, OQS_ERROR, err_cmpct);
*ctlen = (size_t)i2d_CompositeCiphertext(cmpCT, &p);
ON_ERR_SET_GOTO(ctlen <= 0, ret, 0, err_cmpct);

ret2 = oqs_kem_combiner(pkemctx, secret1, secretLen1, secret0, secretLen0,
ct1, ctLen1, pkemctx->kem->comp_pubkey[1],
pkemctx->kem->pubkeylen_cmp[1], secret, secretlen);
ON_ERR_SET_GOTO(!ret2, ret, OQS_ERROR, err_cmpct);
ON_ERR_SET_GOTO(!ret2, ret, 0, err_cmpct);

err_cmpct:
CompositeCiphertext_free(cmpCT);
err_alloc3:
CompositeCiphertext_free(cmpCT); // ct0 and ct1 also freed here
OPENSSL_clear_free(secret1, secretLen1);
OPENSSL_clear_free(secret0, secretLen0);
return ret;

err_ct1:
OPENSSL_free(ct1);
err_alloc2:
err_secret1:
OPENSSL_clear_free(secret1, secretLen1);
err_alloc1:
err_ct0:
OPENSSL_free(ct0);
err_alloc0:
err_secret0:
OPENSSL_clear_free(secret0, secretLen0);
err:
return ret;
Expand All @@ -415,7 +473,7 @@ static int oqs_cmp_kem_encaps(void *vpkemctx, unsigned char *ct, size_t ctlen,
static int oqs_cmp_kem_decaps(void *vpkemctx, unsigned char *secret,
size_t *secretlen, const unsigned char *ct,
size_t ctlen) {
int ret = OQS_SUCCESS, ret2 = 0;
int ret = 1, ret2 = 0;
PROV_OQSKEM_CTX *pkemctx = (PROV_OQSKEM_CTX *)vpkemctx;
const OQS_KEM *qs_kem = pkemctx->kem->oqsx_provider_ctx.oqsx_qs_ctx.kem;
const OQSX_EVP_INFO *evp_info =
Expand All @@ -429,41 +487,41 @@ static int oqs_cmp_kem_decaps(void *vpkemctx, unsigned char *secret,
CompositeCiphertext *cmpCT;
const unsigned char *p = ct; // temp ptr because d2i_* may move input ct ptr

cmpCT = d2i_CompositeCiphertext(NULL, (const unsigned char **)&p, ctlen);
ON_ERR_SET_GOTO(!cmpCT, ret, OQS_ERROR, err);
cmpCT = d2i_CompositeCiphertext(&cmpCT, (const unsigned char **)&p, ctlen);
ON_ERR_SET_GOTO(!cmpCT, ret, 0, err);

ct0 = cmpCT->ct1->data;
ctLen0 = cmpCT->ct1->length;
ct1 = cmpCT->ct2->data;
ctLen1 = cmpCT->ct2->length;
ON_ERR_SET_GOTO(!ct0 || !ct1, ret, OQS_ERROR, err_cmpct);
ct0 = ASN1_STRING_get0_data(cmpCT->ct1);
ctLen0 = ASN1_STRING_length(cmpCT->ct1);
ct1 = ASN1_STRING_get0_data(cmpCT->ct2);
ctLen1 = ASN1_STRING_length(cmpCT->ct2);
ON_ERR_SET_GOTO(!ct0 || !ct1, ret, 0, err_cmpct);

ret2 = oqs_qs_kem_decaps_keyslot(vpkemctx, NULL, &secretLen0, NULL, 0, 0);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err_cmpct);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_cmpct);
secret0 = OPENSSL_malloc(secretLen0);
ON_ERR_SET_GOTO(!secret0, ret, OQS_ERROR, err_cmpct);
ON_ERR_SET_GOTO(!secret0, ret, 0, err_secret0);

ret2 = oqs_evp_kem_decaps_keyslot(vpkemctx, NULL, &secretLen1, NULL, 0, 1);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err_alloc0);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_secret0);
secret1 = OPENSSL_malloc(secretLen1);
ON_ERR_SET_GOTO(!secret1, ret, OQS_ERROR, err_alloc0);
ON_ERR_SET_GOTO(!secret1, ret, 0, err_secret1);

ret2 = oqs_qs_kem_decaps_keyslot(vpkemctx, secret0, &secretLen0, ct0,
ctLen0, 0);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err_alloc1);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_secret1);

ret2 = oqs_evp_kem_decaps_keyslot(vpkemctx, secret1, &secretLen1, ct1,
ctLen1, 1);
ON_ERR_SET_GOTO(ret2 <= 0, ret, OQS_ERROR, err_alloc1);
ON_ERR_SET_GOTO(ret2 <= 0, ret, 0, err_secret1);

ret2 = oqs_kem_combiner(pkemctx, secret1, secretLen1, secret0, secretLen0,
ct1, ctLen1, pkemctx->kem->comp_pubkey[1],
pkemctx->kem->pubkeylen_cmp[1], secret, secretlen);
ON_ERR_SET_GOTO(!ret2, ret, OQS_ERROR, err_alloc1);
ON_ERR_SET_GOTO(!ret2, ret, 0, err_secret1);

err_alloc1:
err_secret1:
OPENSSL_clear_free(secret1, secretLen1);
err_alloc0:
err_secret0:
OPENSSL_clear_free(secret0, secretLen0);
err_cmpct:
CompositeCiphertext_free(cmpCT);
Expand Down
Loading

0 comments on commit 8779bd3

Please sign in to comment.