From 4fdda82adacc83d8272a1031a724a718fef26143 Mon Sep 17 00:00:00 2001 From: Sean Parkinson Date: Fri, 13 Dec 2024 14:44:13 +1000 Subject: [PATCH] DTLS decryption threaded Add support for decryption in threads for DTLS. --- src/dtls13.c | 5 +- src/internal.c | 618 ++++++++++++++++++++++++++++----------------- src/ssl.c | 68 ++++- src/tls13.c | 7 +- tests/api.c | 2 + wolfssl/internal.h | 12 +- wolfssl/ssl.h | 6 + 7 files changed, 472 insertions(+), 246 deletions(-) diff --git a/src/dtls13.c b/src/dtls13.c index 5011f7d85b..0ca7138aa1 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -2666,9 +2666,6 @@ int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize, const byte* ackMessage; w64wrapper epoch, seq; word16 length; -#ifndef WOLFSSL_RW_THREADED - int ret; -#endif int i; if (inputSize < OPAQUE16_LEN) @@ -2702,7 +2699,7 @@ int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize, #ifndef WOLFSSL_RW_THREADED if (ssl->dtls13WaitKeyUpdateAck) { - ret = DoDtls13KeyUpdateAck(ssl); + int ret = DoDtls13KeyUpdateAck(ssl); if (ret != 0) return ret; } diff --git a/src/internal.c b/src/internal.c index 756f2812fc..ef53cf8d1a 100644 --- a/src/internal.c +++ b/src/internal.c @@ -7350,6 +7350,9 @@ int InitSSL(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup) for (i = 0; i < WOLFSSL_THREADED_CRYPT_CNT; i++) { ssl->buffers.encrypt[i].avail = 1; } + for (i = 0; i < WOLFSSL_THREADED_CRYPT_CNT; i++) { + ssl->buffers.decrypt[i].avail = 1; + } } #endif @@ -8234,7 +8237,21 @@ void wolfSSL_ResourceFree(WOLFSSL* ssl) bufferStatic* buff = &ssl->buffers.encrypt[i].buffer; ssl->buffers.encrypt[i].stop = 1; - FreeCiphersSide(&ssl->buffers.encrypt[i].encrypt, ssl->heap); + FreeCiphersSide(&ssl->buffers.encrypt[i].cipher, ssl->heap); + if (buff->dynamicFlag) { + XFREE(buff->buffer - buff->offset, ssl->heap, + DYNAMIC_TYPE_OUT_BUFFER); + buff->buffer = buff->staticBuffer; + buff->bufferSize = STATIC_BUFFER_LEN; + buff->offset = 0; + buff->dynamicFlag = 0; + } + } + for (i = 0; i < WOLFSSL_THREADED_CRYPT_CNT; i++) { + bufferStatic* buff = &ssl->buffers.decrypt[i].buffer; + + ssl->buffers.decrypt[i].stop = 1; + FreeCiphersSide(&ssl->buffers.decrypt[i].cipher, ssl->heap); if (buff->dynamicFlag) { XFREE(buff->buffer - buff->offset, ssl->heap, DYNAMIC_TYPE_OUT_BUFFER); @@ -10761,7 +10778,6 @@ int SendBuffered(WOLFSSL* ssl) return 0; } -#ifdef WOLFSSL_THREADED_CRYPT static WC_INLINE int GrowAnOutputBuffer(WOLFSSL* ssl, bufferStatic* outputBuffer, int size) { @@ -10773,6 +10789,7 @@ static WC_INLINE int GrowAnOutputBuffer(WOLFSSL* ssl, #else const byte align = WOLFSSL_GENERAL_ALIGNMENT; #endif + word32 newSz; #if WOLFSSL_GENERAL_ALIGNMENT > 0 /* the encrypted data will be offset from the front of the buffer by @@ -10783,8 +10800,14 @@ static WC_INLINE int GrowAnOutputBuffer(WOLFSSL* ssl, align *= 2; #endif - tmp = (byte*)XMALLOC(size + outputBuffer->length + align, - ssl->heap, DYNAMIC_TYPE_OUT_BUFFER); + if (! WC_SAFE_SUM_WORD32(outputBuffer->idx, outputBuffer->length, newSz)) + return BUFFER_E; + if (! WC_SAFE_SUM_WORD32(newSz, (word32)size, newSz)) + return BUFFER_E; + if (! WC_SAFE_SUM_WORD32(newSz, align, newSz)) + return BUFFER_E; + tmp = (byte*)XMALLOC(newSz, ssl->heap, DYNAMIC_TYPE_OUT_BUFFER); + newSz -= align; WOLFSSL_MSG("growing output buffer"); if (tmp == NULL) @@ -10798,7 +10821,7 @@ static WC_INLINE int GrowAnOutputBuffer(WOLFSSL* ssl, #ifdef WOLFSSL_STATIC_MEMORY /* can be from IO memory pool which does not need copy if same buffer */ if (outputBuffer->length && tmp == outputBuffer->buffer) { - outputBuffer->bufferSize = size + outputBuffer->length; + outputBuffer->bufferSize = newSz; return 0; } #endif @@ -10810,6 +10833,7 @@ static WC_INLINE int GrowAnOutputBuffer(WOLFSSL* ssl, XFREE(outputBuffer->buffer - outputBuffer->offset, ssl->heap, DYNAMIC_TYPE_OUT_BUFFER); } + outputBuffer->dynamicFlag = 1; #if WOLFSSL_GENERAL_ALIGNMENT > 0 if (align) @@ -10819,11 +10843,9 @@ static WC_INLINE int GrowAnOutputBuffer(WOLFSSL* ssl, outputBuffer->offset = 0; outputBuffer->buffer = tmp; - outputBuffer->dynamicFlag = 1; - outputBuffer->bufferSize = size + outputBuffer->length; + outputBuffer->bufferSize = newSz; return 0; } -#endif /* returns the current location in the output buffer to start writing to */ byte* GetOutputBuffer(WOLFSSL* ssl) @@ -10836,80 +10858,13 @@ byte* GetOutputBuffer(WOLFSSL* ssl) /* Grow the output buffer */ static WC_INLINE int GrowOutputBuffer(WOLFSSL* ssl, int size) { - byte* tmp; -#if WOLFSSL_GENERAL_ALIGNMENT > 0 - byte hdrSz = ssl->options.dtls ? DTLS_RECORD_HEADER_SZ : - RECORD_HEADER_SZ; - byte align = WOLFSSL_GENERAL_ALIGNMENT; -#else - const byte align = WOLFSSL_GENERAL_ALIGNMENT; -#endif - word32 newSz; - -#if WOLFSSL_GENERAL_ALIGNMENT > 0 - /* the encrypted data will be offset from the front of the buffer by - the header, if the user wants encrypted alignment they need - to define their alignment requirement */ - - while (align < hdrSz) - align *= 2; -#endif - - if (! WC_SAFE_SUM_WORD32(ssl->buffers.outputBuffer.idx, - ssl->buffers.outputBuffer.length, newSz)) - return BUFFER_E; - if (! WC_SAFE_SUM_WORD32(newSz, (word32)size, newSz)) - return BUFFER_E; - if (! WC_SAFE_SUM_WORD32(newSz, align, newSz)) - return BUFFER_E; - tmp = (byte*)XMALLOC(newSz, ssl->heap, DYNAMIC_TYPE_OUT_BUFFER); - newSz -= align; - WOLFSSL_MSG("growing output buffer"); - - if (tmp == NULL) - return MEMORY_E; - -#if WOLFSSL_GENERAL_ALIGNMENT > 0 - if (align) - tmp += align - hdrSz; -#endif - -#ifdef WOLFSSL_STATIC_MEMORY - /* can be from IO memory pool which does not need copy if same buffer */ - if (ssl->buffers.outputBuffer.length && - tmp == ssl->buffers.outputBuffer.buffer) { - ssl->buffers.outputBuffer.bufferSize = newSz; - return 0; - } -#endif - - if (ssl->buffers.outputBuffer.length) - XMEMCPY(tmp, ssl->buffers.outputBuffer.buffer, - ssl->buffers.outputBuffer.idx + - ssl->buffers.outputBuffer.length); - - if (ssl->buffers.outputBuffer.dynamicFlag) { - XFREE(ssl->buffers.outputBuffer.buffer - - ssl->buffers.outputBuffer.offset, ssl->heap, - DYNAMIC_TYPE_OUT_BUFFER); - } - ssl->buffers.outputBuffer.dynamicFlag = 1; - -#if WOLFSSL_GENERAL_ALIGNMENT > 0 - if (align) - ssl->buffers.outputBuffer.offset = align - hdrSz; - else -#endif - ssl->buffers.outputBuffer.offset = 0; - - ssl->buffers.outputBuffer.buffer = tmp; - ssl->buffers.outputBuffer.bufferSize = newSz; - return 0; + return GrowAnOutputBuffer(ssl, &ssl->buffers.outputBuffer, size); } /* Grow the input buffer, should only be to read cert or big app data */ -int GrowInputBuffer(WOLFSSL* ssl, int size, int usedLength) +static int GrowAnInputBuffer(WOLFSSL* ssl, bufferStatic* inputBuffer, int size, + int usedLength) { byte* tmp; #if defined(WOLFSSL_DTLS) || WOLFSSL_GENERAL_ALIGNMENT > 0 @@ -10950,43 +10905,46 @@ int GrowInputBuffer(WOLFSSL* ssl, int size, int usedLength) #ifdef WOLFSSL_STATIC_MEMORY /* can be from IO memory pool which does not need copy if same buffer */ - if (usedLength && tmp == ssl->buffers.inputBuffer.buffer) { - ssl->buffers.inputBuffer.bufferSize = size + usedLength; - ssl->buffers.inputBuffer.idx = 0; - ssl->buffers.inputBuffer.length = usedLength; + if (usedLength && tmp == inputBuffer->buffer) { + inputBuffer->bufferSize = size + usedLength; + inputBuffer->idx = 0; + inputBuffer->length = usedLength; return 0; } #endif if (usedLength) - XMEMCPY(tmp, ssl->buffers.inputBuffer.buffer + - ssl->buffers.inputBuffer.idx, usedLength); + XMEMCPY(tmp, inputBuffer->buffer + inputBuffer->idx, usedLength); - if (ssl->buffers.inputBuffer.dynamicFlag) { + if (inputBuffer->dynamicFlag) { if (IsEncryptionOn(ssl, 1)) { - ForceZero(ssl->buffers.inputBuffer.buffer, - ssl->buffers.inputBuffer.length); + ForceZero(inputBuffer->buffer, inputBuffer->length); } - XFREE(ssl->buffers.inputBuffer.buffer - ssl->buffers.inputBuffer.offset, - ssl->heap, DYNAMIC_TYPE_IN_BUFFER); + XFREE(inputBuffer->buffer - inputBuffer->offset, ssl->heap, + DYNAMIC_TYPE_IN_BUFFER); } - ssl->buffers.inputBuffer.dynamicFlag = 1; + inputBuffer->dynamicFlag = 1; #if defined(WOLFSSL_DTLS) || WOLFSSL_GENERAL_ALIGNMENT > 0 if (align) - ssl->buffers.inputBuffer.offset = align - hdrSz; + inputBuffer->offset = align - hdrSz; else #endif - ssl->buffers.inputBuffer.offset = 0; + inputBuffer->offset = 0; - ssl->buffers.inputBuffer.buffer = tmp; - ssl->buffers.inputBuffer.bufferSize = (word32)(size + usedLength); - ssl->buffers.inputBuffer.idx = 0; - ssl->buffers.inputBuffer.length = (word32)usedLength; + inputBuffer->buffer = tmp; + inputBuffer->bufferSize = (word32)(size + usedLength); + inputBuffer->idx = 0; + inputBuffer->length = (word32)usedLength; return 0; } +/* Grow the input buffer, should only be to read cert or big app data */ +int GrowInputBuffer(WOLFSSL* ssl, int size, int usedLength) +{ + return GrowAnInputBuffer(ssl, &ssl->buffers.inputBuffer, size, usedLength); +} /* Check available size into output buffer, make room if needed. * This function needs to be called before anything gets put @@ -21331,6 +21289,77 @@ static int removeMsgInnerPadding(WOLFSSL* ssl) } #endif +#ifdef WOLFSSL_THREADED_CRYPT +static int ReceiveAsyncData(WOLFSSL* ssl) +{ + int i; + int cnt; + + do { + cnt = 0; + for (i = 0; i < WOLFSSL_THREADED_CRYPT_CNT; i++) { + ThreadCrypt* decrypt = &ssl->buffers.decrypt[i]; + + if (decrypt->avail) { + cnt++; + } + else if (decrypt->done) { + int ret; + int error; + + + /* Parse record header again. */ + GrowInputBuffer(ssl, decrypt->recordHdrLen, 0); + XMEMCPY(ssl->buffers.inputBuffer.buffer, decrypt->recordHdr, + decrypt->recordHdrLen); + ssl->buffers.inputBuffer.length = decrypt->recordHdrLen; + ssl->buffers.inputBuffer.idx = 0; + ret = GetRecordHeader(ssl, &ssl->buffers.inputBuffer.idx, + &ssl->curRL, &ssl->curSize); + if (ret != 0) { + return ret; + } + + GrowInputBuffer(ssl, decrypt->buffer.length, 0); + XMEMCPY(ssl->buffers.inputBuffer.buffer, decrypt->buffer.buffer, + decrypt->buffer.length); + ssl->buffers.inputBuffer.length = decrypt->buffer.length; + ssl->buffers.inputBuffer.idx = RECORD_HEADER_SZ; + + decrypt->done = 0; + decrypt->avail = 1; + + ssl->curStartIdx = ssl->buffers.inputBuffer.idx; + ssl->options.processReply = runProcessingOneRecord; + if ((error = ProcessReply(ssl)) < 0) { + if (error == WC_NO_ERR_TRACE(ZERO_RETURN)) { + ssl->error = error; + WOLFSSL_MSG("Zero return, no more data coming"); + return 0; /* no more data coming */ + } + if (error == WC_NO_ERR_TRACE(SOCKET_ERROR_E)) { + if (ssl->options.connReset || ssl->options.isClosed) { + WOLFSSL_MSG( + "Peer reset or closed, connection done"); + error = SOCKET_PEER_CLOSED_E; + ssl->error = error; + WOLFSSL_ERROR(error); + return 0; /* peer reset or closed */ + } + } + ssl->error = error; + WOLFSSL_ERROR(error); + return error; + } + } + } + } + while (cnt == 0); + + return 0; +} +#endif + int ProcessReply(WOLFSSL* ssl) { return ProcessReplyEx(ssl, 0); @@ -21347,6 +21376,13 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) #if defined(WOLFSSL_DTLS) int used; #endif +#ifdef WOLFSSL_THREADED_CRYPT + int recordHdrIdx; + byte recordHdr[DTLS_RECORD_HEADER_MAX_SZ]; + int recordHdrLen = 0; + byte origRecordHdr[DTLS_RECORD_HEADER_MAX_SZ + 16]; + int origRecordHdrLen = 0; +#endif #ifdef ATOMIC_USER if (ssl->ctx->DecryptVerifyCb) @@ -21530,6 +21566,15 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) /* get the record layer header */ case getRecordLayerHeader: +#ifdef WOLFSSL_THREADED_CRYPT + recordHdrIdx = ssl->buffers.inputBuffer.idx; + origRecordHdrLen = min(ssl->buffers.inputBuffer.length, + DTLS_RECORD_HEADER_MAX_SZ + 16); + XMEMCPY(origRecordHdr, + ssl->buffers.inputBuffer.buffer + recordHdrIdx, + origRecordHdrLen); +#endif + /* DTLSv1.3 record numbers in the header are encrypted, and AAD * uses the unencrypted form. Because of this we need to modify the * header, decrypting the numbers inside @@ -21538,6 +21583,15 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) ret = GetRecordHeader(ssl, &ssl->buffers.inputBuffer.idx, &ssl->curRL, &ssl->curSize); +#ifdef WOLFSSL_THREADED_CRYPT + if (ret == 0) { + recordHdrLen = ssl->buffers.inputBuffer.idx - recordHdrIdx; + XMEMCPY(recordHdr, + ssl->buffers.inputBuffer.buffer + recordHdrIdx, + recordHdrLen); + } +#endif + #ifdef WOLFSSL_DTLS if (ssl->options.dtls && DtlsShouldDrop(ssl, ret)) { ssl->options.processReply = doProcessInit; @@ -21697,6 +21751,75 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) return ret; } +#ifdef WOLFSSL_THREADED_CRYPT + if (ssl->options.handShakeDone && + ssl->buffers.decryptSignalRegistered) { + int i; + ThreadCrypt* decrypt = NULL; + byte *aad = (byte*)&ssl->curRL; + word16 aad_size = RECORD_HEADER_SZ; + #ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + /* aad now points to the record header */ + aad = ssl->dtls13CurRL; + aad_size = ssl->dtls13CurRlLength; + } + #endif /* WOLFSSL_DTLS13 */ + + WOLFSSL_MSG("Not decrypting\n"); + + for (i = 0; i < WOLFSSL_THREADED_CRYPT_CNT; i++) { + if (ssl->buffers.decrypt[i].avail) { + decrypt = &ssl->buffers.decrypt[i]; + break; + } + } + decrypt->done = 0; + decrypt->avail = 0; + + GrowAnInputBuffer(ssl, &decrypt->buffer, + recordHdrLen + ssl->curSize, 0); + + XMEMCPY(decrypt->buffer.buffer, recordHdr, recordHdrLen); + decrypt->offset = recordHdrLen; + XMEMCPY(decrypt->recordHdr, origRecordHdr, + origRecordHdrLen); + decrypt->recordHdrLen = origRecordHdrLen; + XMEMCPY(decrypt->buffer.buffer + decrypt->offset, + in->buffer + in->idx, ssl->curSize); + decrypt->buffer.length = recordHdrLen + ssl->curSize; + decrypt->cryptLen = ssl->curSize; + + if (!decrypt->init) { + SetKeys(NULL, &decrypt->cipher, &ssl->keys, &ssl->specs, + ssl->options.side, ssl->heap, ssl->devId, ssl->rng, + ssl->options.tls1_3); + decrypt->init = 1; + } + #ifdef HAVE_TRUNCATED_HMAC + if (ssl->truncated_hmac) { + decrypt->cryptLen -= min(TRUNCATED_HMAC_SZ, + ssl->specs.hash_size); + } + else + #endif + { + decrypt->cryptLen -= ssl->specs.aead_mac_size; + } + + XMEMCPY(decrypt->additional, aad, aad_size); + BuildTls13Nonce(ssl, decrypt->nonce, + ssl->keys.aead_dec_imp_IV, PEER_ORDER); + + if (decrypt->signal != NULL) { + decrypt->signal(decrypt->signalCtx, ssl); + } + + ssl->options.processReply = doProcessInit; + in->idx += ssl->curSize; + return ret; + } +#endif if (atomicUser) { #ifdef ATOMIC_USER #if defined(HAVE_ENCRYPT_THEN_MAC) && !defined(WOLFSSL_AEAD_ONLY) @@ -23292,51 +23415,51 @@ int BuildMessage(WOLFSSL* ssl, byte* output, int outSz, const byte* input, } #endif -#ifdef WOLFSSL_THREADED_CRYPT - if (asyncOkay) { - WOLFSSL_MSG("Not encrypting\n"); - /* make sure build message state is reset */ - ssl->options.buildMsgState = BUILD_MSG_BEGIN; + #ifdef WOLFSSL_THREADED_CRYPT + if (asyncOkay && ssl->buffers.encryptSignalRegistered) { + WOLFSSL_MSG("Not encrypting\n"); + /* make sure build message state is reset */ + ssl->options.buildMsgState = BUILD_MSG_BEGIN; - /* return sz on success */ - if (ret == 0) { - ret = args->sz; - } - else { - WOLFSSL_ERROR_VERBOSE(ret); - } + /* return sz on success */ + if (ret == 0) { + ret = args->sz; + } + else { + WOLFSSL_ERROR_VERBOSE(ret); + } - /* Final cleanup */ - FreeBuildMsgArgs(ssl, args); + /* Final cleanup */ + FreeBuildMsgArgs(ssl, args); - return ret; - } - else -#endif - { - #if defined(HAVE_ENCRYPT_THEN_MAC) && !defined(WOLFSSL_AEAD_ONLY) - if (ssl->options.startedETMWrite) { - ret = Encrypt(ssl, output + args->headerSz, - output + args->headerSz, - (word16)(args->size - args->digestSz), - asyncOkay, args->type); + return ret; } else - #endif + #endif { - ret = Encrypt(ssl, output + args->headerSz, - output + args->headerSz, args->size, asyncOkay, - args->type); - } - #if defined(HAVE_SECURE_RENEGOTIATION) && defined(WOLFSSL_DTLS) - /* Restore sequence numbers */ - if (swap_seq) { - ssl->keys.dtls_epoch = dtls_epoch; - ssl->keys.dtls_sequence_number_hi = dtls_sequence_number_hi; - ssl->keys.dtls_sequence_number_lo = dtls_sequence_number_lo; + #if defined(HAVE_ENCRYPT_THEN_MAC) && !defined(WOLFSSL_AEAD_ONLY) + if (ssl->options.startedETMWrite) { + ret = Encrypt(ssl, output + args->headerSz, + output + args->headerSz, + (word16)(args->size - args->digestSz), + asyncOkay, args->type); + } + else + #endif + { + ret = Encrypt(ssl, output + args->headerSz, + output + args->headerSz, args->size, + asyncOkay, args->type); + } + #if defined(HAVE_SECURE_RENEGOTIATION) && defined(WOLFSSL_DTLS) + /* Restore sequence numbers */ + if (swap_seq) { + ssl->keys.dtls_epoch = dtls_epoch; + ssl->keys.dtls_sequence_number_hi = dtls_sequence_number_hi; + ssl->keys.dtls_sequence_number_lo = dtls_sequence_number_lo; + } + #endif } - #endif - } } if (ret != 0) { @@ -24982,15 +25105,17 @@ int SendData(WOLFSSL* ssl, const void* data, int sz) } #ifdef WOLFSSL_THREADED_CRYPT - ret = SendAsyncData(ssl); - if (ret != 0) { - ssl->error = ret; - return WOLFSSL_FATAL_ERROR; - } - if (ssl->dtls13WaitKeyUpdateAck) { - ret = DoDtls13KeyUpdateAck(ssl); - if (ret != 0) - return ret; + if (ssl->buffers.encryptSignalRegistered) { + ret = SendAsyncData(ssl); + if (ret != 0) { + ssl->error = ret; + return WOLFSSL_FATAL_ERROR; + } + if (ssl->dtls13WaitKeyUpdateAck) { + ret = DoDtls13KeyUpdateAck(ssl); + if (ret != 0) + return ret; + } } #endif @@ -25093,30 +25218,34 @@ int SendData(WOLFSSL* ssl, const void* data, int sz) return (ssl->error = ret); /* get output buffer */ -#ifndef WOLFSSL_THREADED_CRYPT - out = GetOutputBuffer(ssl); -#else - do { - for (i = 0; i < WOLFSSL_THREADED_CRYPT_CNT; i++) { - if (ssl->buffers.encrypt[i].avail) { - encrypt = &ssl->buffers.encrypt[i]; - break; +#ifdef WOLFSSL_THREADED_CRYPT + if (ssl->buffers.encryptSignalRegistered) { + do { + for (i = 0; i < WOLFSSL_THREADED_CRYPT_CNT; i++) { + if (ssl->buffers.encrypt[i].avail) { + encrypt = &ssl->buffers.encrypt[i]; + break; + } } - } - if (encrypt == NULL) { - ret = SendAsyncData(ssl); - if (ret != 0) { - ssl->error = ret; - return WOLFSSL_FATAL_ERROR; + if (encrypt == NULL) { + ret = SendAsyncData(ssl); + if (ret != 0) { + ssl->error = ret; + return WOLFSSL_FATAL_ERROR; + } } } + while (encrypt == NULL); + encrypt->done = 0; + encrypt->avail = 0; + GrowAnOutputBuffer(ssl, &encrypt->buffer, outputSz); + out = encrypt->buffer.buffer; } - while (encrypt == NULL); - encrypt->done = 0; - encrypt->avail = 0; - GrowAnOutputBuffer(ssl, &encrypt->buffer, outputSz); - out = encrypt->buffer.buffer; + else #endif + { + out = GetOutputBuffer(ssl); + } #ifdef HAVE_LIBZ if (ssl->options.usingCompression) { @@ -25161,79 +25290,86 @@ int SendData(WOLFSSL* ssl, const void* data, int sz) FreeAsyncCtx(ssl, 0); #endif #ifdef WOLFSSL_THREADED_CRYPT - if (!encrypt->init) { - SetKeys(&encrypt->encrypt, NULL, &ssl->keys, &ssl->specs, - ssl->options.side, ssl->heap, ssl->devId, ssl->rng, - ssl->options.tls1_3); - encrypt->init = 1; - } - encrypt->buffer.length = sendSz; - encrypt->offset = RECORD_HEADER_SZ; - if (ssl->options.dtls) { - encrypt->offset += DTLS_RECORD_EXTRA; - } - encrypt->cryptLen = outputSz - encrypt->offset; - #ifdef HAVE_TRUNCATED_HMAC - if (ssl->truncated_hmac) { - encrypt->cryptLen -= min(TRUNCATED_HMAC_SZ, ssl->specs.hash_size); - } - else - #endif - { - encrypt->cryptLen -= ssl->specs.hash_size; - } - -#if !defined(NO_PUBLIC_GCM_SET_IV) && \ - ((defined(HAVE_FIPS) || defined(HAVE_SELFTEST)) && \ - (!defined(HAVE_FIPS_VERSION) || (HAVE_FIPS_VERSION < 2))) - XMEMCPY(encrypt->nonce, ssl->keys.aead_enc_imp_IV, AESGCM_IMP_IV_SZ); - XMEMCPY(encrypt->nonce + AESGCM_IMP_IV_SZ, ssl->keys.aead_exp_IV, - AESGCM_EXP_IV_SZ); -#endif - XMEMSET(encrypt->additional, 0, AEAD_AUTH_DATA_SZ); - WriteSEQ(ssl, CUR_ORDER, encrypt->additional); - XMEMCPY(encrypt->additional + AEAD_TYPE_OFFSET, encrypt->buffer.buffer, - 3); - c16toa(sendSz - encrypt->offset - AESGCM_EXP_IV_SZ - - ssl->specs.aead_mac_size, encrypt->additional + AEAD_LEN_OFFSET); + if (ssl->buffers.encryptSignalRegistered) { + if (!encrypt->init) { + SetKeys(&encrypt->cipher, NULL, &ssl->keys, &ssl->specs, + ssl->options.side, ssl->heap, ssl->devId, ssl->rng, + ssl->options.tls1_3); + encrypt->init = 1; + } + encrypt->buffer.length = sendSz; + encrypt->offset = RECORD_HEADER_SZ; + if (ssl->options.dtls) { + encrypt->offset += DTLS_RECORD_EXTRA; + } + encrypt->cryptLen = outputSz - encrypt->offset; + #ifdef HAVE_TRUNCATED_HMAC + if (ssl->truncated_hmac) { + encrypt->cryptLen -= min(TRUNCATED_HMAC_SZ, + ssl->specs.hash_size); + } + else + #endif + { + encrypt->cryptLen -= ssl->specs.hash_size; + } - #ifdef WOLFSSL_DTLS - if (ssl->options.dtls) - DtlsSEQIncrement(ssl, CUR_ORDER); + #if !defined(NO_PUBLIC_GCM_SET_IV) && \ + ((defined(HAVE_FIPS) || defined(HAVE_SELFTEST)) && \ + (!defined(HAVE_FIPS_VERSION) || (HAVE_FIPS_VERSION < 2))) + XMEMCPY(encrypt->nonce, ssl->keys.aead_enc_imp_IV, + AESGCM_IMP_IV_SZ); + XMEMCPY(encrypt->nonce + AESGCM_IMP_IV_SZ, ssl->keys.aead_exp_IV, + AESGCM_EXP_IV_SZ); #endif + XMEMSET(encrypt->additional, 0, AEAD_AUTH_DATA_SZ); + WriteSEQ(ssl, CUR_ORDER, encrypt->additional); + XMEMCPY(encrypt->additional + AEAD_TYPE_OFFSET, + encrypt->buffer.buffer, 3); + c16toa(sendSz - encrypt->offset - AESGCM_EXP_IV_SZ - + ssl->specs.aead_mac_size, + encrypt->additional + AEAD_LEN_OFFSET); - if (encrypt->signal != NULL) { - encrypt->signal(encrypt->signalCtx, ssl); + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif + + if (encrypt->signal != NULL) { + encrypt->signal(encrypt->signalCtx, ssl); + } + return sendSz; } - return sendSz; -#else - ssl->buffers.outputBuffer.length += (word32)sendSz; + else +#endif + { + ssl->buffers.outputBuffer.length += (word32)sendSz; - if ( (error = SendBuffered(ssl)) < 0) { - ssl->error = error; - WOLFSSL_ERROR(error); - /* store for next call if WANT_WRITE or user embedSend() that - doesn't present like WANT_WRITE */ - ssl->buffers.plainSz = buffSz; - ssl->buffers.prevSent = sent; - if (error == WC_NO_ERR_TRACE(SOCKET_ERROR_E) && - (ssl->options.connReset || ssl->options.isClosed)) { - error = SOCKET_PEER_CLOSED_E; - ssl->error = SOCKET_PEER_CLOSED_E; + if ( (error = SendBuffered(ssl)) < 0) { + ssl->error = error; WOLFSSL_ERROR(error); - return 0; /* peer reset or closed */ + /* store for next call if WANT_WRITE or user embedSend() that + doesn't present like WANT_WRITE */ + ssl->buffers.plainSz = buffSz; + ssl->buffers.prevSent = sent; + if (error == WC_NO_ERR_TRACE(SOCKET_ERROR_E) && + (ssl->options.connReset || ssl->options.isClosed)) { + error = SOCKET_PEER_CLOSED_E; + ssl->error = SOCKET_PEER_CLOSED_E; + WOLFSSL_ERROR(error); + return 0; /* peer reset or closed */ + } + return error; } - return error; - } - sent += buffSz; + sent += buffSz; - /* only one message per attempt */ - if (ssl->options.partialWrite == 1) { - WOLFSSL_MSG("Partial Write on, only sending one record"); - break; + /* only one message per attempt */ + if (ssl->options.partialWrite == 1) { + WOLFSSL_MSG("Partial Write on, only sending one record"); + break; + } } -#endif } return sent; @@ -25286,19 +25422,17 @@ int ReceiveData(WOLFSSL* ssl, byte* output, int sz, int peek) } else #endif - { - if (ssl_in_handshake(ssl, 0)) { - int err; - WOLFSSL_MSG("Handshake not complete, trying to finish"); - if ( (err = wolfSSL_negotiate(ssl)) != WOLFSSL_SUCCESS) { - #ifdef WOLFSSL_ASYNC_CRYPT - /* if async would block return WANT_WRITE */ - if (ssl->error == WC_NO_ERR_TRACE(WC_PENDING_E)) { - return WOLFSSL_CBIO_ERR_WANT_READ; - } - #endif - return err; + if (ssl_in_handshake(ssl, 0)) { + int err; + WOLFSSL_MSG("Handshake not complete, trying to finish"); + if ( (err = wolfSSL_negotiate(ssl)) != WOLFSSL_SUCCESS) { + #ifdef WOLFSSL_ASYNC_CRYPT + /* if async would block return WANT_WRITE */ + if (ssl->error == WC_NO_ERR_TRACE(WC_PENDING_E)) { + return WOLFSSL_CBIO_ERR_WANT_READ; } + #endif + return err; } } @@ -25315,6 +25449,16 @@ int ReceiveData(WOLFSSL* ssl, byte* output, int sz, int peek) #endif while (ssl->buffers.clearOutputBuffer.length == 0) { + #ifdef WOLFSSL_THREADED_CRYPT + if (ssl->buffers.decryptSignalRegistered && !ssl_in_handshake(ssl, 0)) { + error = ReceiveAsyncData(ssl); + if (error != 0) { + ssl->error = error; + WOLFSSL_ERROR(error); + return WOLFSSL_FATAL_ERROR; + } + } + #endif if ( (error = ProcessReply(ssl)) < 0) { if (error == WC_NO_ERR_TRACE(ZERO_RETURN)) { ssl->error = error; diff --git a/src/ssl.c b/src/ssl.c index b11ed59a7e..0afddab226 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -23930,7 +23930,7 @@ int wolfSSL_AsyncEncrypt(WOLFSSL* ssl, int idx) #else wc_AesGcmEncrypt #endif - (encrypt->encrypt.aes, + (encrypt->cipher.aes, out + AESGCM_EXP_IV_SZ, input + AESGCM_EXP_IV_SZ, encSz - AESGCM_EXP_IV_SZ - ssl->specs.aead_mac_size, encrypt->nonce, AESGCM_NONCE_SZ, @@ -23959,6 +23959,72 @@ int wolfSSL_AsyncEncryptSetSignal(WOLFSSL* ssl, int idx, else { ssl->buffers.encrypt[idx].signal = signal; ssl->buffers.encrypt[idx].signalCtx = ctx; + + ssl->buffers.encryptSignalRegistered = 1; + } + + return ret; +} + + +int wolfSSL_AsyncDecryptReady(WOLFSSL* ssl, int idx) +{ + ThreadCrypt* decrypt; + + if (ssl == NULL) { + return 0; + } + + decrypt = &ssl->buffers.decrypt[idx]; + return (decrypt->avail == 0) && (decrypt->done == 0); +} + +int wolfSSL_AsyncDecryptStop(WOLFSSL* ssl, int idx) +{ + ThreadCrypt* decrypt; + + if (ssl == NULL) { + return 1; + } + + decrypt = &ssl->buffers.decrypt[idx]; + return decrypt->stop; +} + +int wolfSSL_AsyncDecrypt(WOLFSSL* ssl, int idx) +{ + int ret = WC_NO_ERR_TRACE(NOT_COMPILED_IN); + ThreadCrypt* decrypt = &ssl->buffers.decrypt[idx]; + + if (ssl->specs.bulk_cipher_algorithm == wolfssl_aes_gcm) { + unsigned char* out = decrypt->buffer.buffer + decrypt->offset; + unsigned char* input = decrypt->buffer.buffer + decrypt->offset; + unsigned char* tag = input + decrypt->cryptLen; + + ret = wc_AesGcmDecrypt(decrypt->cipher.aes, out, input, + decrypt->cryptLen, + decrypt->nonce, AESGCM_NONCE_SZ, + tag, ssl->specs.aead_mac_size, + decrypt->additional, RECORD_HEADER_SZ); + decrypt->done = 1; + } + + return ret; +} + +int wolfSSL_AsyncDecryptSetSignal(WOLFSSL* ssl, int idx, + WOLFSSL_THREAD_SIGNAL signal, void* ctx) +{ + int ret = 0; + + if (ssl == NULL) { + ret = BAD_FUNC_ARG; + } + else { + ssl->buffers.decrypt[idx].signal = signal; + ssl->buffers.decrypt[idx].signalCtx = ctx; + + ssl->buffers.decryptSignalRegistered = 1; } return ret; diff --git a/src/tls13.c b/src/tls13.c index 0d5a8b9365..96c9ae9c48 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -2425,8 +2425,10 @@ static WC_INLINE void WriteSEQTls13(WOLFSSL* ssl, int verifyOrder, byte* out) * iv The derived IV. * order The side on which the message is to be or was sent. */ -static WC_INLINE void BuildTls13Nonce(WOLFSSL* ssl, byte* nonce, const byte* iv, - int order) +#ifndef WOLFSSL_THREADED_CRYPT +static WC_INLINE +#endif +void BuildTls13Nonce(WOLFSSL* ssl, byte* nonce, const byte* iv, int order) { int i; @@ -2655,7 +2657,6 @@ static int EncryptTls13(WOLFSSL* ssl, byte* output, const byte* input, if (ret == WC_NO_ERR_TRACE(NOT_COMPILED_IN)) #endif { - #if ((defined(HAVE_FIPS) || defined(HAVE_SELFTEST)) && \ (!defined(HAVE_FIPS_VERSION) || (HAVE_FIPS_VERSION < 2))) ret = wc_AesGcmEncrypt(ssl->encrypt.aes, output, input, diff --git a/tests/api.c b/tests/api.c index e7b64124a6..ec4087a056 100644 --- a/tests/api.c +++ b/tests/api.c @@ -90683,6 +90683,7 @@ static void test_AEAD_limit_client(WOLFSSL* ssl) /* Key update should be sent and negotiated */ ret = wolfSSL_read(ssl, msgBuf, sizeof(msgBuf)); AssertIntGT(ret, 0); + ret = wolfSSL_write(ssl, msgBuf, sizeof(msgBuf)); /* Epoch after one key update is 4 */ if (w64Equal(ssl->dtls13PeerEpoch, w64From32(0, 4)) && w64Equal(Dtls13GetEpoch(ssl, ssl->dtls13PeerEpoch)->dropCount, counter)) { @@ -90706,6 +90707,7 @@ static void test_AEAD_limit_client(WOLFSSL* ssl) /* Key update should be sent and negotiated */ ret = wolfSSL_read(ssl, msgBuf, sizeof(msgBuf)); AssertIntGT(ret, 0); + ret = wolfSSL_write(ssl, msgBuf, sizeof(msgBuf)); /* Epoch after another key update is 5 */ if (w64Equal(ssl->dtls13Epoch, w64From32(0, 5)) && w64Equal(Dtls13GetEpoch(ssl, ssl->dtls13Epoch)->dropCount, counter)) { diff --git a/wolfssl/internal.h b/wolfssl/internal.h index f5ce5b02ef..dc9afda97b 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -4743,10 +4743,12 @@ enum AcceptStateTls13 { #include typedef struct ThreadCrypt { - Ciphers encrypt; + Ciphers cipher; bufferStatic buffer; unsigned char nonce[AESGCM_NONCE_SZ]; unsigned char additional[AEAD_AUTH_DATA_SZ]; + unsigned char recordHdr[DTLS_RECORD_HEADER_MAX_SZ + 16]; + word32 recordHdrLen; int init; int offset; int cryptLen; @@ -4764,7 +4766,10 @@ typedef struct Buffers { bufferStatic inputBuffer; bufferStatic outputBuffer; #ifdef WOLFSSL_THREADED_CRYPT + int encryptSignalRegistered; ThreadCrypt encrypt[WOLFSSL_THREADED_CRYPT_CNT]; + int decryptSignalRegistered; + ThreadCrypt decrypt[WOLFSSL_THREADED_CRYPT_CNT]; #endif buffer domainName; /* for client check */ buffer clearOutputBuffer; @@ -6867,6 +6872,11 @@ WOLFSSL_LOCAL int BuildMessage(WOLFSSL* ssl, byte* output, int outSz, int sizeOnly, int asyncOkay, int epochOrder); #ifdef WOLFSSL_TLS13 +#ifdef WOLFSSL_THREADED_CRYPT +WOLFSSL_LOCAL void BuildTls13Nonce(WOLFSSL* ssl, byte* nonce, const byte* iv, + int order); +#endif + /* Use WOLFSSL_API to use this function in tests/api.c */ WOLFSSL_API int BuildTls13Message(WOLFSSL* ssl, byte* output, int outSz, const byte* input, int inSz, int type, int hashOutput, int sizeOnly, int asyncOkay); diff --git a/wolfssl/ssl.h b/wolfssl/ssl.h index 8989f52044..07f4ed1598 100644 --- a/wolfssl/ssl.h +++ b/wolfssl/ssl.h @@ -3718,6 +3718,12 @@ WOLFSSL_API int wolfSSL_AsyncEncryptStop(WOLFSSL* ssl, int idx); WOLFSSL_API int wolfSSL_AsyncEncrypt(WOLFSSL* ssl, int idx); WOLFSSL_API int wolfSSL_AsyncEncryptSetSignal(WOLFSSL* ssl, int idx, WOLFSSL_THREAD_SIGNAL signal, void* ctx); + +WOLFSSL_API int wolfSSL_AsyncDecryptReady(WOLFSSL* ssl, int idx); +WOLFSSL_API int wolfSSL_AsyncDecryptStop(WOLFSSL* ssl, int idx); +WOLFSSL_API int wolfSSL_AsyncDecrypt(WOLFSSL* ssl, int idx); +WOLFSSL_API int wolfSSL_AsyncDecryptSetSignal(WOLFSSL* ssl, int idx, + WOLFSSL_THREAD_SIGNAL signal, void* ctx); #endif