diff --git a/neqo-crypto/src/aead_null.rs b/neqo-crypto/src/aead_null.rs index bc0174492b..28dfc204b7 100644 --- a/neqo-crypto/src/aead_null.rs +++ b/neqo-crypto/src/aead_null.rs @@ -42,8 +42,8 @@ impl AeadNull { ) -> Res<&'a [u8]> { let l = input.len(); output[..l].copy_from_slice(input); - output[l..l + 16].copy_from_slice(AEAD_NULL_TAG); - Ok(&output[..l + 16]) + output[l..l + self.expansion()].copy_from_slice(AEAD_NULL_TAG); + Ok(&output[..l + self.expansion()]) } #[allow(clippy::missing_errors_doc)] @@ -52,24 +52,18 @@ impl AeadNull { _count: u64, _aad: Range, input: Range, - _data: &mut [u8], + data: &mut [u8], ) -> Res { - Ok(input.len() + 16) + data[input.end..input.end + self.expansion()].copy_from_slice(AEAD_NULL_TAG); + Ok(input.len() + self.expansion()) } - #[allow(clippy::missing_errors_doc)] - pub fn decrypt<'a>( - &self, - _count: u64, - _aad: &[u8], - input: &[u8], - output: &'a mut [u8], - ) -> Res<&'a [u8]> { - if input.len() < AEAD_NULL_TAG.len() { + fn decrypt_check(&self, _count: u64, _aad: &[u8], input: &[u8]) -> Res { + if input.len() < self.expansion() { return Err(Error::from(SEC_ERROR_BAD_DATA)); } - let len_encrypted = input.len() - AEAD_NULL_TAG.len(); + let len_encrypted = input.len() - self.expansion(); // Check that: // 1) expansion is all zeros and // 2) if the encrypted data is also supplied that at least some values are no zero @@ -77,12 +71,43 @@ impl AeadNull { if &input[len_encrypted..] == AEAD_NULL_TAG && (len_encrypted == 0 || input[..len_encrypted].iter().any(|x| *x != 0x0)) { - output[..len_encrypted].copy_from_slice(&input[..len_encrypted]); - Ok(&output[..len_encrypted]) + Ok(len_encrypted) } else { Err(Error::from(SEC_ERROR_BAD_DATA)) } } + + #[allow(clippy::missing_errors_doc)] + pub fn decrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + self.decrypt_check(count, aad, input).map(|len| { + output[..len].copy_from_slice(&input[..len]); + &output[..len] + }) + } + + #[allow(clippy::missing_errors_doc)] + pub fn decrypt_in_place<'a>( + &self, + count: u64, + aad: Range, + input: Range, + data: &'a mut [u8], + ) -> Res<&'a mut [u8]> { + let aad = data + .get(aad) + .ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?; + let inp = data + .get(input.clone()) + .ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?; + self.decrypt_check(count, aad, inp) + .map(move |len| &mut data[input.start..input.start + len]) + } } impl fmt::Debug for AeadNull { diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 45955bd181..24a1b2d285 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -788,6 +788,7 @@ impl<'a> PublicPacket<'a> { self.data.len() } + #[must_use] pub fn data(&self) -> &[u8] { self.data }