From de60a0f00ab856701a24ba774cea57483408e01e Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Thu, 23 Jan 2025 10:59:13 +0200 Subject: [PATCH 01/19] feat: In-place crypto Only in-place encryption so far, and only for the main data path. Fixes #2246 (eventually) --- neqo-crypto/src/aead.rs | 35 +++++++++++++++++++++++++++++++- neqo-transport/src/crypto.rs | 23 ++++++++++++--------- neqo-transport/src/packet/mod.rs | 21 ++++++++++++++----- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index 059e9acb61..eb8bee0c70 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -6,7 +6,7 @@ use std::{ fmt, - ops::{Deref, DerefMut}, + ops::{Deref, DerefMut, Range}, os::raw::{c_char, c_uint}, ptr::null_mut, }; @@ -126,6 +126,39 @@ impl RealAead { Ok(&output[0..(l.try_into()?)]) } + /// Encrypt a plaintext in place. + /// + /// The space provided in `data` needs to allow `Aead::expansion` more bytes to be appended. + /// + /// # Errors + /// + /// If the input can't be protected or any input is too large for NSS. + pub fn encrypt_in_place( + &self, + count: u64, + aad: Range, + input: Range, + data: &mut [u8], + ) -> Res { + let aad = &data[aad]; + let input = &data[input]; + let mut l: c_uint = 0; + unsafe { + SSL_AeadEncrypt( + *self.ctx, + count, + aad.as_ptr(), + c_uint::try_from(aad.len())?, + input.as_ptr(), + c_uint::try_from(input.len())?, + input.as_ptr(), + &mut l, + c_uint::try_from(input.len() + self.expansion())?, + ) + }?; + Ok(l.try_into()?) + } + /// Decrypt a ciphertext. /// /// Note that NSS insists upon having extra space available for decryption, so diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index 64d0e09a3e..ec7044df31 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -40,7 +40,6 @@ use crate::{ Error, Res, }; -const MAX_AUTH_TAG: usize = 32; /// The number of invocations remaining on a write cipher before we try /// to update keys. This has to be much smaller than the number returned /// by `CryptoDxState::limit` or updates will happen too often. As we don't @@ -634,12 +633,18 @@ impl CryptoDxState { self.used_pn.end } - pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { + pub fn encrypt( + &mut self, + pn: PacketNumber, + hdr: Range, + body: Range, + data: &mut [u8], + ) -> Res { debug_assert_eq!(self.direction, CryptoDxDirection::Write); qtrace!( - "[{self}] encrypt pn={pn} hdr={} body={}", - hex(hdr), - hex(body) + "[{self}] encrypt_in_place pn={pn} hdr={} body={}", + hex(data[hdr.clone()].as_ref()), + hex(data[body.clone()].as_ref()) ); // The numbers in `Self::limit` assume a maximum packet size of `LIMIT`. @@ -653,14 +658,12 @@ impl CryptoDxState { } self.invoked()?; - let size = body.len() + MAX_AUTH_TAG; - let mut out = vec![0; size]; - let res = self.aead.encrypt(pn, hdr, body, &mut out)?; + let len = self.aead.encrypt_in_place(pn, hdr, body, data)?; - qtrace!("[{self}] encrypt ct={}", hex(res)); + qtrace!("[{self}] encrypt ct={}", hex(data)); debug_assert_eq!(pn, self.next_pn()); self.used(pn)?; - Ok(res.to_vec()) + Ok(len) } #[must_use] diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 0c11a4a869..5cdaabdb84 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -418,9 +418,22 @@ impl PacketBuilder { hex(hdr), hex(body) ); - let ciphertext = crypto.encrypt(self.pn, hdr, body)?; + + // Add space for crypto expansion. + let data_end = self.encoder.len(); + for _i in 0..crypto.expansion() { + self.encode_byte(123); + } + + let ciphertext_len = crypto.encrypt( + self.pn, + self.header.clone(), + self.header.end..data_end, + self.encoder.as_mut(), + )?; // Calculate the mask. + let ciphertext = &self.encoder.as_mut()[self.header.end..self.header.end + ciphertext_len]; let offset = SAMPLE_OFFSET - self.offsets.pn.len(); if offset + SAMPLE_SIZE > ciphertext.len() { return Err(Error::InternalError); @@ -434,9 +447,6 @@ impl PacketBuilder { self.encoder.as_mut()[j] ^= mask[i]; } - // Finally, cut off the plaintext and add back the ciphertext. - self.encoder.truncate(self.header.end); - self.encoder.encode(&ciphertext); qtrace!("Packet built {}", hex(&self.encoder)); Ok(self.encoder) } @@ -998,7 +1008,8 @@ mod tests { // The spec uses PN=1, but our crypto refuses to skip packet numbers. // So burn an encryption: - let burn = prot.encrypt(0, &[], &[]).expect("burn OK"); + let mut burn = [0; 16]; + prot.encrypt(0, 0..0, 0..0, &mut burn).expect("burn OK"); assert_eq!(burn.len(), prot.expansion()); let mut builder = PacketBuilder::long( From 03a4e154c6dadacb4bc1e4a43219dcc22e4312cc Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Thu, 23 Jan 2025 13:34:28 +0200 Subject: [PATCH 02/19] aead_null --- neqo-crypto/src/aead_null.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/neqo-crypto/src/aead_null.rs b/neqo-crypto/src/aead_null.rs index a74c89f35d..bc0174492b 100644 --- a/neqo-crypto/src/aead_null.rs +++ b/neqo-crypto/src/aead_null.rs @@ -4,7 +4,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::fmt; +use std::{fmt, ops::Range}; use crate::{ constants::{Cipher, Version}, @@ -46,6 +46,17 @@ impl AeadNull { Ok(&output[..l + 16]) } + #[allow(clippy::missing_errors_doc)] + pub fn encrypt_in_place( + &self, + _count: u64, + _aad: Range, + input: Range, + _data: &mut [u8], + ) -> Res { + Ok(input.len() + 16) + } + #[allow(clippy::missing_errors_doc)] pub fn decrypt<'a>( &self, From ff398b6fd12d5925c9d253e7d00a27eff039a3eb Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Thu, 23 Jan 2025 16:56:38 +0200 Subject: [PATCH 03/19] WIP decrypt --- neqo-common/src/datagram.rs | 12 +++- neqo-crypto/src/aead.rs | 37 +++++++++++- neqo-transport/src/connection/mod.rs | 57 ++++++++++++------- neqo-transport/src/crypto.rs | 19 +++++++ neqo-transport/src/packet/mod.rs | 85 +++++++++++++++------------- neqo-transport/src/server.rs | 29 +++++++--- 6 files changed, 167 insertions(+), 72 deletions(-) diff --git a/neqo-common/src/datagram.rs b/neqo-common/src/datagram.rs index c3b8713c69..d10118ec9f 100644 --- a/neqo-common/src/datagram.rs +++ b/neqo-common/src/datagram.rs @@ -4,7 +4,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::{net::SocketAddr, ops::Deref}; +use std::{ + net::SocketAddr, + ops::{Deref, DerefMut}, +}; use crate::{hex_with_len, IpTos}; @@ -47,7 +50,6 @@ impl> Datagram { } } -#[cfg(test)] impl + AsRef<[u8]>> AsMut<[u8]> for Datagram { fn as_mut(&mut self) -> &mut [u8] { self.d.as_mut() @@ -65,6 +67,12 @@ impl Datagram> { } } +impl + AsMut<[u8]>> DerefMut for Datagram { + fn deref_mut(&mut self) -> &mut Self::Target { + AsMut::<[u8]>::as_mut(self) + } +} + impl> Deref for Datagram { type Target = [u8]; #[must_use] diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index eb8bee0c70..c4af631301 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -126,7 +126,7 @@ impl RealAead { Ok(&output[0..(l.try_into()?)]) } - /// Encrypt a plaintext in place. + /// Encrypt `data` consisting of `aad` and plaintext `input` in place. /// /// The space provided in `data` needs to allow `Aead::expansion` more bytes to be appended. /// @@ -191,6 +191,41 @@ impl RealAead { }?; Ok(&output[0..(l.try_into()?)]) } + + /// Decrypt a ciphertext in place. + /// + /// Note that NSS insists upon having extra space available for decryption, so + /// the buffer for `output` should be the same length as `input`, even though + /// the final result will be shorter. + /// + /// # Errors + /// + /// If the input isn't authenticated or any input is too large for NSS. + pub fn decrypt_in_place( + &self, + count: u64, + aad: Range, + input: Range, + data: &mut [u8], + ) -> Res { + let aad = &data[aad]; + let input = &data[input]; + let mut l: c_uint = 0; + unsafe { + SSL_AeadDecrypt( + *self.ctx, + count, + aad.as_ptr(), + c_uint::try_from(aad.len())?, + input.as_ptr(), + c_uint::try_from(input.len())?, + input.as_ptr(), + &mut l, + c_uint::try_from(input.len())?, + ) + }?; + Ok(l.try_into()?) + } } impl fmt::Debug for RealAead { diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 023545354b..c903be694d 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -22,7 +22,7 @@ use std::{ use neqo_common::{ event::Provider as EventProvider, hex, hex_snip_middle, hrtime, qdebug, qerror, qinfo, - qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, Role, + qlog::NeqoQlog, qtrace, qwarn, Datagram, Decoder, Encoder, IpTos, Role, }; use neqo_crypto::{ agent::CertificateInfo, Agent, AntiReplay, AuthenticationStatus, Cipher, Client, Group, @@ -1020,14 +1020,14 @@ impl Connection { } /// Process new input datagrams on the connection. - pub fn process_input(&mut self, d: Datagram>, now: Instant) { + pub fn process_input(&mut self, d: Datagram + AsMut<[u8]>>, now: Instant) { self.process_multiple_input(iter::once(d), now); } /// Process new input datagrams on the connection. pub fn process_multiple_input( &mut self, - dgrams: impl IntoIterator>>, + dgrams: impl IntoIterator + AsMut<[u8]>>>, now: Instant, ) { let mut dgrams = dgrams.into_iter().peekable(); @@ -1160,7 +1160,11 @@ impl Connection { /// Process input and generate output. #[must_use = "Output of the process function must be handled"] - pub fn process(&mut self, dgram: Option>>, now: Instant) -> Output { + pub fn process( + &mut self, + dgram: Option + AsMut<[u8]>>>, + now: Instant, + ) -> Output { if let Some(d) = dgram { self.input(d, now, now); self.process_saved(now); @@ -1502,7 +1506,8 @@ impl Connection { fn postprocess_packet( &mut self, path: &PathRef, - d: &Datagram>, + tos: IpTos, + remote: SocketAddr, packet: &PublicPacket, migrate: bool, now: Instant, @@ -1510,7 +1515,7 @@ impl Connection { let space = PacketNumberSpace::from(packet.packet_type()); if let Some(space) = self.acks.get_mut(space) { let space_ecn_marks = space.ecn_marks(); - *space_ecn_marks += d.tos().into(); + *space_ecn_marks += tos.into(); self.stats.borrow_mut().ecn_rx = *space_ecn_marks; } else { qtrace!("Not tracking ECN for dropped packet number space"); @@ -1521,7 +1526,7 @@ impl Connection { } if self.state.connected() { - self.handle_migration(path, d, migrate, now); + self.handle_migration(path, remote, migrate, now); } else if self.role != Role::Client && (packet.packet_type() == PacketType::Handshake || (packet.dcid().len() >= 8 && packet.dcid() == self.local_initial_source_cid)) @@ -1534,7 +1539,12 @@ impl Connection { /// Take a datagram as input. This reports an error if the packet was bad. /// This takes two times: when the datagram was received, and the current time. - fn input(&mut self, d: Datagram>, received: Instant, now: Instant) { + fn input( + &mut self, + d: Datagram + AsMut<[u8]>>, + received: Instant, + now: Instant, + ) { // First determine the path. let path = self.paths.find_path( d.destination(), @@ -1552,19 +1562,22 @@ impl Connection { fn input_path( &mut self, path: &PathRef, - d: Datagram>, + mut d: Datagram + AsMut<[u8]>>, now: Instant, ) -> Res<()> { - let mut slc = d.as_ref(); - let mut dcid = None; - qtrace!("[{self}] {} input {}", path.borrow(), hex(&d)); + let tos = d.tos(); + let remote = d.source(); + let len = d.len(); + let mut slc = d.as_mut(); + let mut dcid = None; let pto = path.borrow().rtt().pto(self.confirmed()); // Handle each packet in the datagram. while !slc.is_empty() { self.stats.borrow_mut().packets_rx += 1; - let (packet, remainder) = + let slc_len = slc.len(); + let (mut packet, remainder) = match PublicPacket::decode(slc, self.cid_manager.decoder().as_ref()) { Ok((packet, remainder)) => (packet, remainder), Err(e) => { @@ -1592,9 +1605,9 @@ impl Connection { "-> RX", payload.packet_type(), payload.pn(), - &payload[..], - d.tos(), - d.len(), + payload.as_ref(), + tos, + len, ); #[cfg(feature = "build-fuzzing-corpus")] @@ -1607,7 +1620,7 @@ impl Connection { neqo_common::write_item_to_fuzzing_corpus(target, &payload[..]); } - qlog::packet_received(&self.qlog, &packet, &payload, now); + // qlog::packet_received(&self.qlog, &packet, &payload, now); let space = PacketNumberSpace::from(payload.packet_type()); if let Some(space) = self.acks.get_mut(space) { if space.is_duplicate(payload.pn()) { @@ -1616,7 +1629,9 @@ impl Connection { } else { match self.process_packet(path, &payload, now) { Ok(migrate) => { - self.postprocess_packet(path, &d, &packet, migrate, now); + self.postprocess_packet( + path, tos, remote, &packet, migrate, now, + ); } Err(e) => { self.ensure_error_path(path, &packet, now); @@ -1637,7 +1652,7 @@ impl Connection { Error::KeysPending(cspace) => { // This packet can't be decrypted because we don't have the keys yet. // Don't check this packet for a stateless reset, just return. - let remaining = slc.len(); + let remaining = slc_len; self.save_datagram(cspace, d, remaining, now); return Ok(()); } @@ -1980,7 +1995,7 @@ impl Connection { fn handle_migration( &mut self, path: &PathRef, - d: &Datagram>, + remote: SocketAddr, migrate: bool, now: Instant, ) { @@ -1993,7 +2008,7 @@ impl Connection { if self.ensure_permanent(path, now).is_ok() { self.paths - .handle_migration(path, d.source(), now, &mut self.stats.borrow_mut()); + .handle_migration(path, remote, now, &mut self.stats.borrow_mut()); } else { qinfo!( "[{self}] {} Peer migrated, but no connection ID available", diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index ec7044df31..2f0a685154 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -685,6 +685,25 @@ impl CryptoDxState { Ok(res.to_vec()) } + pub fn decrypt_in_place( + &mut self, + pn: PacketNumber, + hdr: Range, + body: Range, + data: &mut [u8], + ) -> Res { + debug_assert_eq!(self.direction, CryptoDxDirection::Read); + qtrace!( + "[{self}] decrypt pn={pn} hdr={} body={}", + hex(data[hdr.clone()].as_ref()), + hex(data[body.clone()].as_ref()) + ); + self.invoked()?; + let len = self.aead.decrypt_in_place(pn, hdr, body, data)?; + self.used(pn)?; + Ok(len) + } + #[cfg(all(test, not(feature = "disable-encryption")))] pub(crate) fn test_default() -> Self { // This matches the value in packet.rs diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 5cdaabdb84..b3bedd27ea 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -574,7 +574,7 @@ pub struct PublicPacket<'a> { /// Protocol version, if present in header. version: Option, /// A reference to the entire packet, including the header. - data: &'a [u8], + data: &'a mut [u8], } impl<'a> PublicPacket<'a> { @@ -621,7 +621,10 @@ impl<'a> PublicPacket<'a> { /// /// This will return an error if the packet could not be decoded. #[allow(clippy::similar_names)] - pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> { + pub fn decode( + data: &'a mut [u8], + dcid_decoder: &dyn ConnectionIdDecoder, + ) -> Res<(Self, &'a mut [u8])> { let mut decoder = Decoder::new(data); let first = Self::opt(decoder.decode_uint::())?; @@ -646,7 +649,7 @@ impl<'a> PublicPacket<'a> { version: None, data, }, - &[], + &mut [], )); } @@ -667,7 +670,7 @@ impl<'a> PublicPacket<'a> { version: None, data, }, - &[], + &mut [], )); } @@ -683,7 +686,7 @@ impl<'a> PublicPacket<'a> { version: Some(version), data, }, - &[], + &mut [], )); }; @@ -695,7 +698,7 @@ impl<'a> PublicPacket<'a> { // The type-specific code includes a token. This consumes the remainder of the packet. let (token, header_len) = Self::decode_long(&mut decoder, packet_type, version)?; let end = data.len() - decoder.remaining(); - let (data, remainder) = data.split_at(end); + let (data, remainder) = data.split_at_mut(end); Ok(( Self { packet_type, @@ -800,9 +803,9 @@ impl<'a> PublicPacket<'a> { /// Decrypt the header of the packet. fn decrypt_header( - &self, + &mut self, crypto: &CryptoDxState, - ) -> Res<(bool, PacketNumber, Vec, &'a [u8])> { + ) -> Res<(bool, PacketNumber, Range, Range)> { assert_ne!(self.packet_type, PacketType::Retry); assert_ne!(self.packet_type, PacketType::VersionNegotiation); @@ -824,23 +827,23 @@ impl<'a> PublicPacket<'a> { let first_byte = self.data[0] ^ (mask[0] & bits); // Make a copy of the header to work on. - let mut hdrbytes = self.data[..self.header_len + 4].to_vec(); - hdrbytes[0] = first_byte; + let mut hdrbytes = 0..self.header_len + 4; + self.data[0] = first_byte; // Unmask the PN. let mut pn_encoded: u64 = 0; for i in 0..MAX_PACKET_NUMBER_LEN { - hdrbytes[self.header_len + i] ^= mask[1 + i]; + self.data[self.header_len + i] ^= mask[1 + i]; pn_encoded <<= 8; - pn_encoded += u64::from(hdrbytes[self.header_len + i]); + pn_encoded += u64::from(self.data[self.header_len + i]); } // Now decode the packet number length and apply it, hopefully in constant time. let pn_len = usize::from((first_byte & 0x3) + 1); - hdrbytes.truncate(self.header_len + pn_len); + hdrbytes.end = self.header_len + pn_len; pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len); - qtrace!("unmasked hdr={}", hex(&hdrbytes)); + qtrace!("unmasked hdr={}", hex(&self.data[hdrbytes.clone()])); let key_phase = self.packet_type == PacketType::Short && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE; @@ -849,14 +852,18 @@ impl<'a> PublicPacket<'a> { key_phase, pn, hdrbytes, - &self.data[self.header_len + pn_len..], + self.header_len + pn_len..self.data.len(), )) } /// # Errors /// /// This will return an error if the packet cannot be decrypted. - pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res { + pub fn decrypt( + &mut self, + crypto: &mut CryptoStates, + release_at: Instant, + ) -> Res { let cspace: CryptoSpace = self.packet_type.into(); // When we don't have a version, the crypto code doesn't need a version // for lookup, so use the default, but fix it up if decryption succeeds. @@ -874,7 +881,7 @@ impl<'a> PublicPacket<'a> { return Err(Error::DecryptError); }; let version = rx.version(); // Version fixup; see above. - let d = rx.decrypt(pn, &header, body)?; + let len = rx.decrypt_in_place(pn, header, body, &mut self.data)?; // If this is the first packet ever successfully decrypted // using `rx`, make sure to initiate a key update. if rx.needs_update() { @@ -885,7 +892,7 @@ impl<'a> PublicPacket<'a> { version, pt: self.packet_type, pn, - data: d, + data: &self.data[..len], }) } else if crypto.rx_pending(cspace) { Err(Error::KeysPending(cspace)) @@ -925,14 +932,14 @@ impl fmt::Debug for PublicPacket<'_> { } } -pub struct DecryptedPacket { +pub struct DecryptedPacket<'a> { version: Version, pt: PacketType, pn: PacketNumber, - data: Vec, + data: &'a [u8], } -impl DecryptedPacket { +impl DecryptedPacket<'_> { #[must_use] pub const fn version(&self) -> Version { self.version @@ -949,7 +956,7 @@ impl DecryptedPacket { } } -impl Deref for DecryptedPacket { +impl Deref for DecryptedPacket<'_> { type Target = [u8]; fn deref(&self) -> &Self::Target { @@ -1031,9 +1038,9 @@ mod tests { const EXTRA: &[u8] = &[0xce; 33]; fixture_init(); - let mut padded = SAMPLE_INITIAL.to_vec(); + let mut padded = SAMPLE_INITIAL; padded.extend_from_slice(EXTRA); - let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut padded, &cid_mgr()).unwrap(); assert_eq!(packet.packet_type(), PacketType::Initial); assert_eq!(&packet.dcid()[..], &[] as &[u8]); assert_eq!(&packet.scid()[..], SERVER_CID); @@ -1055,7 +1062,7 @@ mod tests { enc.encode_vec(1, &[]); enc.encode(&[0xff; 40]); // junk - assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err()); + assert!(PublicPacket::decode(enc.as_mut(), &cid_mgr()).is_err()); } #[test] @@ -1067,7 +1074,7 @@ mod tests { enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]); enc.encode(&[0xff; 40]); // junk - assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err()); + assert!(PublicPacket::decode(enc.as_mut(), &cid_mgr()).is_err()); } const SAMPLE_SHORT: &[u8] = &[ @@ -1114,7 +1121,7 @@ mod tests { #[test] fn decode_short() { fixture_init(); - let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut SAMPLE_SHORT, &cid_mgr()).unwrap(); assert_eq!(packet.packet_type(), PacketType::Short); assert!(remainder.is_empty()); let decrypted = packet @@ -1129,7 +1136,7 @@ mod tests { fn decode_short_bad_cid() { fixture_init(); let (packet, remainder) = PublicPacket::decode( - SAMPLE_SHORT, + &mut SAMPLE_SHORT, &RandomConnectionIdGenerator::new(SERVER_CID.len() - 1), ) .unwrap(); @@ -1144,7 +1151,7 @@ mod tests { #[test] fn decode_short_long_cid() { assert!(PublicPacket::decode( - SAMPLE_SHORT, + &mut SAMPLE_SHORT, &RandomConnectionIdGenerator::new(SERVER_CID.len() + 1) ) .is_err()); @@ -1299,7 +1306,7 @@ mod tests { let retry = PacketBuilder::retry(version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap(); - let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut &retry, &cid_mgr()).unwrap(); assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); assert!(remainder.is_empty()); @@ -1348,7 +1355,7 @@ mod tests { fn decode_retry(version: Version, sample_retry: &[u8]) { fixture_init(); let (packet, remainder) = - PublicPacket::decode(sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap(); + PublicPacket::decode(&mut sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap(); assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); assert_eq!(Some(version), packet.version()); assert!(packet.dcid().is_empty()); @@ -1381,28 +1388,28 @@ mod tests { assert!(PublicPacket::decode(&[], &cid_mgr).is_err()); - let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_V1, &cid_mgr).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut SAMPLE_RETRY_V1, &cid_mgr).unwrap(); assert!(remainder.is_empty()); assert!(packet.is_valid_retry(&odcid)); let mut damaged_retry = SAMPLE_RETRY_V1.to_vec(); let last = damaged_retry.len() - 1; damaged_retry[last] ^= 66; - let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut &damaged_retry, &cid_mgr).unwrap(); assert!(remainder.is_empty()); assert!(!packet.is_valid_retry(&odcid)); damaged_retry.truncate(last); - let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut &damaged_retry, &cid_mgr).unwrap(); assert!(remainder.is_empty()); assert!(!packet.is_valid_retry(&odcid)); // An invalid token should be rejected sooner. damaged_retry.truncate(last - 4); - assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + assert!(PublicPacket::decode(&mut &damaged_retry, &cid_mgr).is_err()); damaged_retry.truncate(last - 1); - assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); + assert!(PublicPacket::decode(&mut &damaged_retry, &cid_mgr).is_err()); } const SAMPLE_VN: &[u8] = &[ @@ -1444,7 +1451,7 @@ mod tests { #[test] fn parse_vn() { let (packet, remainder) = - PublicPacket::decode(SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap(); + PublicPacket::decode(&mut SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap(); assert!(remainder.is_empty()); assert_eq!(&packet.dcid[..], SERVER_CID); assert!(packet.scid.is_some()); @@ -1465,7 +1472,7 @@ mod tests { enc.encode_uint(4, 0x5a6a_7a8a_u64); let (packet, remainder) = - PublicPacket::decode(enc.as_ref(), &EmptyConnectionIdGenerator::default()).unwrap(); + PublicPacket::decode(enc.as_mut(), &EmptyConnectionIdGenerator::default()).unwrap(); assert!(remainder.is_empty()); assert_eq!(&packet.dcid[..], BIG_DCID); assert!(packet.scid.is_some()); @@ -1501,7 +1508,7 @@ mod tests { ]; fixture_init(); let (packet, slice) = - PublicPacket::decode(PACKET, &EmptyConnectionIdGenerator::default()).unwrap(); + PublicPacket::decode(&mut PACKET, &EmptyConnectionIdGenerator::default()).unwrap(); assert!(slice.is_empty()); let decrypted = packet .decrypt(&mut CryptoStates::test_chacha(), now()) diff --git a/neqo-transport/src/server.rs b/neqo-transport/src/server.rs index 9f3056a93e..17e2c0b6b3 100644 --- a/neqo-transport/src/server.rs +++ b/neqo-transport/src/server.rs @@ -196,7 +196,7 @@ impl Server { fn handle_initial( &mut self, initial: InitialDetails, - dgram: Datagram>, + dgram: Datagram + AsMut<[u8]>>, now: Instant, ) -> Output { qdebug!("[{self}] Handle initial"); @@ -307,7 +307,7 @@ impl Server { fn accept_connection( &mut self, initial: InitialDetails, - dgram: Datagram>, + dgram: Datagram + AsMut<[u8]>>, orig_dcid: Option, now: Instant, ) -> Output { @@ -349,12 +349,19 @@ impl Server { } } - fn process_input(&mut self, dgram: Datagram>, now: Instant) -> Output { + fn process_input( + &mut self, + mut dgram: Datagram + AsMut<[u8]>>, + now: Instant, + ) -> Output { qtrace!("Process datagram: {}", hex(&dgram[..])); // This is only looking at the first packet header in the datagram. // All packets in the datagram are routed to the same connection. - let res = PublicPacket::decode(&dgram[..], self.cid_generator.borrow().as_decoder()); + let len = dgram.len(); + let destination = dgram.destination(); + let source = dgram.source(); + let res = PublicPacket::decode(&mut dgram[..], self.cid_generator.borrow().as_decoder()); let Ok((packet, _remainder)) = res else { qtrace!("[{self}] Discarding {dgram:?}"); return Output::None; @@ -383,7 +390,7 @@ impl Server { .all() .contains(&packet.version().unwrap())) { - if dgram.len() < MIN_INITIAL_PACKET_SIZE { + if len < MIN_INITIAL_PACKET_SIZE { qdebug!("[{self}] Unsupported version: too short"); return Output::None; } @@ -399,8 +406,8 @@ impl Server { "[{self}] type={:?} path:{} {}->{} {:?} len {}", PacketType::VersionNegotiation, packet.dcid(), - dgram.destination(), - dgram.source(), + destination, + source, IpTos::default(), vn.len(), ); @@ -422,7 +429,7 @@ impl Server { match packet.packet_type() { PacketType::Initial => { - if dgram.len() < MIN_INITIAL_PACKET_SIZE { + if len < MIN_INITIAL_PACKET_SIZE { qdebug!("[{self}] Drop initial: too short"); return Output::None; } @@ -470,7 +477,11 @@ impl Server { } #[must_use] - pub fn process(&mut self, dgram: Option>>, now: Instant) -> Output { + pub fn process( + &mut self, + dgram: Option + AsMut<[u8]>>>, + now: Instant, + ) -> Output { let out = dgram .map_or(Output::None, |d| self.process_input(d, now)) .or_else(|| self.process_next_output(now)); From bdd0900763b621155f3ce72084ac5daa133e1122 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Fri, 24 Jan 2025 10:02:13 +0200 Subject: [PATCH 04/19] More --- neqo-transport/src/connection/mod.rs | 15 +++++++++++---- neqo-transport/src/packet/mod.rs | 4 ++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index c903be694d..f42fa10a1d 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1242,7 +1242,12 @@ impl Connection { } } - fn is_stateless_reset(&self, path: &PathRef, d: &Datagram>) -> bool { + fn is_stateless_reset( + &self, + path: &PathRef, + d: &[u8], + // d: &Datagram> + ) -> bool { // If the datagram is too small, don't try. // If the connection is connected, then the reset token will be invalid. if d.len() < 16 || !self.state.connected() { @@ -1255,7 +1260,8 @@ impl Connection { fn check_stateless_reset( &mut self, path: &PathRef, - d: &Datagram>, + d: &[u8], + // d: &Datagram>, first: bool, now: Instant, ) -> Res<()> { @@ -1582,7 +1588,7 @@ impl Connection { Ok((packet, remainder)) => (packet, remainder), Err(e) => { qinfo!("[{self}] Garbage packet: {e}"); - qtrace!("[{self}] Garbage packet contents: {}", hex(slc)); + // qtrace!("[{self}] Garbage packet contents: {}", hex(slc)); self.stats.borrow_mut().pkt_dropped("Garbage packet"); break; } @@ -1671,9 +1677,10 @@ impl Connection { // Decryption failure, or not having keys is not fatal. // If the state isn't available, or we can't decrypt the packet, drop // the rest of the datagram on the floor, but don't generate an error. - self.check_stateless_reset(path, &d, dcid.is_none(), now)?; + self.check_stateless_reset(path, packet.data(), dcid.is_none(), now)?; self.stats.borrow_mut().pkt_dropped("Decryption failure"); qlog::packet_dropped(&self.qlog, &packet, now); + break; } } slc = remainder; diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index b3bedd27ea..73b47bccd8 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -786,6 +786,10 @@ impl<'a> PublicPacket<'a> { self.data.len() } + pub fn data(&self) -> &[u8] { + self.data + } + const fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber { let window = 1_u64 << (w * 8); let candidate = (expected & !(window - 1)) | pn; From 1488a4eec6000e0e4cbaf953b587a333149b4139 Mon Sep 17 00:00:00 2001 From: Max Leonard Inden Date: Sun, 26 Jan 2025 18:58:43 +0100 Subject: [PATCH 05/19] fix(transport/packet): don't (mutably) borrow data multiple times --- neqo-transport/src/packet/mod.rs | 48 +++++++++++++++++--------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 73b47bccd8..a476b1a655 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -563,12 +563,12 @@ pub struct PublicPacket<'a> { /// The packet type. packet_type: PacketType, /// The recovered destination connection ID. - dcid: ConnectionIdRef<'a>, + dcid: ConnectionId, /// The source connection ID, if this is a long header packet. - scid: Option>, + scid: Option, /// Any token that is included in the packet (Retry always has a token; Initial sometimes /// does). This is empty when there is no token. - token: &'a [u8], + token: Vec, /// The size of the header, not including the packet number. header_len: usize, /// Protocol version, if present in header. @@ -624,9 +624,9 @@ impl<'a> PublicPacket<'a> { pub fn decode( data: &'a mut [u8], dcid_decoder: &dyn ConnectionIdDecoder, - ) -> Res<(Self, &'a mut [u8])> { + ) -> Res<(PublicPacket<'a>, &'a mut [u8])> { let mut decoder = Decoder::new(data); - let first = Self::opt(decoder.decode_uint::())?; + let first = PublicPacket::opt(decoder.decode_uint::())?; if first & 0x80 == PACKET_BIT_SHORT { // Conveniently, this also guarantees that there is enough space @@ -634,17 +634,18 @@ impl<'a> PublicPacket<'a> { if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { return Err(Error::InvalidPacket); } - let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?; + let dcid = PublicPacket::opt(dcid_decoder.decode_cid(&mut decoder))?.into(); if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { return Err(Error::InvalidPacket); } let header_len = decoder.offset(); + return Ok(( - Self { + PublicPacket { packet_type: PacketType::Short, dcid, scid: None, - token: &[], + token: vec![], header_len, version: None, data, @@ -654,18 +655,18 @@ impl<'a> PublicPacket<'a> { } // Generic long header. - let version = Self::opt(decoder.decode_uint())?; - let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); - let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); + let version = PublicPacket::opt(decoder.decode_uint())?; + let dcid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into(); + let scid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into(); // Version negotiation. if version == 0 { return Ok(( - Self { + PublicPacket { packet_type: PacketType::VersionNegotiation, dcid, scid: Some(scid), - token: &[], + token: vec![], header_len: decoder.offset(), version: None, data, @@ -677,11 +678,11 @@ impl<'a> PublicPacket<'a> { // Check that this is a long header from a supported version. let Ok(version) = Version::try_from(version) else { return Ok(( - Self { + PublicPacket { packet_type: PacketType::OtherVersion, dcid, scid: Some(scid), - token: &[], + token: vec![], header_len: decoder.offset(), version: Some(version), data, @@ -696,11 +697,12 @@ impl<'a> PublicPacket<'a> { let packet_type = PacketType::from_byte((first >> 4) & 3, version); // The type-specific code includes a token. This consumes the remainder of the packet. - let (token, header_len) = Self::decode_long(&mut decoder, packet_type, version)?; + let (token, header_len) = PublicPacket::decode_long(&mut decoder, packet_type, version)?; + let token = token.to_vec(); let end = data.len() - decoder.remaining(); let (data, remainder) = data.split_at_mut(end); Ok(( - Self { + PublicPacket { packet_type, dcid, scid: Some(scid), @@ -751,22 +753,24 @@ impl<'a> PublicPacket<'a> { } #[must_use] - pub const fn dcid(&self) -> ConnectionIdRef<'a> { - self.dcid + pub fn dcid(&self) -> ConnectionIdRef { + self.dcid.as_cid_ref() } /// # Panics /// /// This will panic if called for a short header packet. #[must_use] - pub fn scid(&self) -> ConnectionIdRef<'a> { + pub fn scid(&self) -> ConnectionIdRef { self.scid + .as_ref() .expect("should only be called for long header packets") + .as_cid_ref() } #[must_use] - pub const fn token(&self) -> &'a [u8] { - self.token + pub fn token(&self) -> &[u8] { + &self.token } #[must_use] From b0fcf23de8304c6bc89da9781164456e76caa7ae Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Wed, 29 Jan 2025 09:02:50 +0200 Subject: [PATCH 06/19] Fixes --- neqo-bin/src/client/http09.rs | 2 +- neqo-bin/src/client/http3.rs | 2 +- neqo-bin/src/client/mod.rs | 2 +- neqo-bin/src/server/http09.rs | 2 +- neqo-bin/src/server/http3.rs | 2 +- neqo-bin/src/server/mod.rs | 4 +- neqo-common/src/datagram.rs | 4 +- neqo-http3/src/connection_client.rs | 10 ++-- neqo-http3/src/server.rs | 6 ++- neqo-transport/src/connection/mod.rs | 4 +- neqo-transport/src/crypto.rs | 2 +- neqo-transport/src/packet/mod.rs | 78 +++++++++++++++------------- neqo-udp/src/lib.rs | 12 ++--- 13 files changed, 73 insertions(+), 57 deletions(-) diff --git a/neqo-bin/src/client/http09.rs b/neqo-bin/src/client/http09.rs index 21b0cf2883..e44a949d34 100644 --- a/neqo-bin/src/client/http09.rs +++ b/neqo-bin/src/client/http09.rs @@ -189,7 +189,7 @@ impl super::Client for Connection { fn process_multiple_input<'a>( &mut self, - dgrams: impl IntoIterator>, + dgrams: impl IntoIterator>, now: Instant, ) { self.process_multiple_input(dgrams, now); diff --git a/neqo-bin/src/client/http3.rs b/neqo-bin/src/client/http3.rs index ac4b22dba1..b00c227b66 100644 --- a/neqo-bin/src/client/http3.rs +++ b/neqo-bin/src/client/http3.rs @@ -137,7 +137,7 @@ impl super::Client for Http3Client { fn process_multiple_input<'a>( &mut self, - dgrams: impl IntoIterator>, + dgrams: impl IntoIterator>, now: Instant, ) { self.process_multiple_input(dgrams, now); diff --git a/neqo-bin/src/client/mod.rs b/neqo-bin/src/client/mod.rs index 9da2800d80..648278a13d 100644 --- a/neqo-bin/src/client/mod.rs +++ b/neqo-bin/src/client/mod.rs @@ -383,7 +383,7 @@ trait Client { fn process_output(&mut self, now: Instant) -> Output; fn process_multiple_input<'a>( &mut self, - dgrams: impl IntoIterator>, + dgrams: impl IntoIterator>, now: Instant, ); fn has_events(&self) -> bool; diff --git a/neqo-bin/src/server/http09.rs b/neqo-bin/src/server/http09.rs index 7c4fb792db..b1efc66f7b 100644 --- a/neqo-bin/src/server/http09.rs +++ b/neqo-bin/src/server/http09.rs @@ -185,7 +185,7 @@ impl HttpServer { } impl super::HttpServer for HttpServer { - fn process(&mut self, dgram: Option>, now: Instant) -> Output { + fn process(&mut self, dgram: Option>, now: Instant) -> Output { self.server.process(dgram, now) } diff --git a/neqo-bin/src/server/http3.rs b/neqo-bin/src/server/http3.rs index c22b95c6fd..aed25e9b8c 100644 --- a/neqo-bin/src/server/http3.rs +++ b/neqo-bin/src/server/http3.rs @@ -80,7 +80,7 @@ impl Display for HttpServer { } impl super::HttpServer for HttpServer { - fn process(&mut self, dgram: Option>, now: Instant) -> neqo_http3::Output { + fn process(&mut self, dgram: Option>, now: Instant) -> neqo_http3::Output { self.server.process(dgram, now) } diff --git a/neqo-bin/src/server/mod.rs b/neqo-bin/src/server/mod.rs index 2d8db660eb..7a65f2c547 100644 --- a/neqo-bin/src/server/mod.rs +++ b/neqo-bin/src/server/mod.rs @@ -192,7 +192,7 @@ fn qns_read_response(filename: &str) -> Result, io::Error> { #[allow(clippy::module_name_repetitions)] pub trait HttpServer: Display { - fn process(&mut self, dgram: Option>, now: Instant) -> Output; + fn process(&mut self, dgram: Option>, now: Instant) -> Output; fn process_events(&mut self, now: Instant); fn has_events(&self) -> bool; } @@ -242,7 +242,7 @@ impl ServerRunner { timeout: &mut Option>>, sockets: &mut [(SocketAddr, crate::udp::Socket)], now: &dyn Fn() -> Instant, - mut input_dgram: Option>, + mut input_dgram: Option>, ) -> Result<(), io::Error> { loop { match server.process(input_dgram.take(), now()) { diff --git a/neqo-common/src/datagram.rs b/neqo-common/src/datagram.rs index d10118ec9f..9bea6e0e4e 100644 --- a/neqo-common/src/datagram.rs +++ b/neqo-common/src/datagram.rs @@ -94,9 +94,9 @@ impl> std::fmt::Debug for Datagram { } } -impl<'a> Datagram<&'a [u8]> { +impl<'a> Datagram<&'a mut [u8]> { #[must_use] - pub const fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a [u8]) -> Self { + pub const fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a mut [u8]) -> Self { Self { src, dst, tos, d } } diff --git a/neqo-http3/src/connection_client.rs b/neqo-http3/src/connection_client.rs index 5641dc9401..05bbeaa357 100644 --- a/neqo-http3/src/connection_client.rs +++ b/neqo-http3/src/connection_client.rs @@ -841,7 +841,11 @@ impl Http3Client { } /// This function combines `process_input` and `process_output` function. - pub fn process(&mut self, dgram: Option>>, now: Instant) -> Output { + pub fn process( + &mut self, + dgram: Option + AsMut<[u8]>>>, + now: Instant, + ) -> Output { qtrace!("[{self}] Process"); if let Some(d) = dgram { self.process_input(d, now); @@ -859,13 +863,13 @@ impl Http3Client { /// packets need to be sent or if a timer needs to be updated. /// /// [1]: ../neqo_transport/enum.ConnectionEvent.html - pub fn process_input(&mut self, dgram: Datagram>, now: Instant) { + pub fn process_input(&mut self, dgram: Datagram + AsMut<[u8]>>, now: Instant) { self.process_multiple_input(iter::once(dgram), now); } pub fn process_multiple_input( &mut self, - dgrams: impl IntoIterator>>, + dgrams: impl IntoIterator + AsMut<[u8]>>>, now: Instant, ) { let mut dgrams = dgrams.into_iter().peekable(); diff --git a/neqo-http3/src/server.rs b/neqo-http3/src/server.rs index ad8025f55f..d20fd5455a 100644 --- a/neqo-http3/src/server.rs +++ b/neqo-http3/src/server.rs @@ -118,7 +118,11 @@ impl Http3Server { self.process(None::, now) } - pub fn process(&mut self, dgram: Option>>, now: Instant) -> Output { + pub fn process( + &mut self, + dgram: Option + AsMut<[u8]>>>, + now: Instant, + ) -> Output { qtrace!("[{self}] Process"); let out = self.server.process(dgram, now); self.process_http3(now); diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 3e347054dd..678bc0b581 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1253,7 +1253,7 @@ impl Connection { if d.len() < 16 || !self.state.connected() { return false; } - <&[u8; 16]>::try_from(&d.as_ref()[d.len() - 16..]) + <&[u8; 16]>::try_from(&d[d.len() - 16..]) .is_ok_and(|token| path.borrow().is_stateless_reset(token)) } @@ -1626,7 +1626,7 @@ impl Connection { neqo_common::write_item_to_fuzzing_corpus(target, &payload[..]); } - // qlog::packet_received(&self.qlog, &packet, &payload, now); + // FIXME: add back: qlog::packet_received(&self.qlog, &packet, &payload, now); let space = PacketNumberSpace::from(payload.packet_type()); if let Some(space) = self.acks.get_mut(space) { if space.is_duplicate(payload.pn()) { diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index fd272f0288..e1d0b3b867 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -696,7 +696,7 @@ impl CryptoDxState { ) -> Res { debug_assert_eq!(self.direction, CryptoDxDirection::Read); qtrace!( - "[{self}] decrypt pn={pn} hdr={} body={}", + "[{self}] decrypt_in_place pn={pn} hdr={} body={}", hex(data[hdr.clone()].as_ref()), hex(data[body.clone()].as_ref()) ); diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index a476b1a655..aa7809315f 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -624,7 +624,7 @@ impl<'a> PublicPacket<'a> { pub fn decode( data: &'a mut [u8], dcid_decoder: &dyn ConnectionIdDecoder, - ) -> Res<(PublicPacket<'a>, &'a mut [u8])> { + ) -> Res<(Self, &'a mut [u8])> { let mut decoder = Decoder::new(data); let first = PublicPacket::opt(decoder.decode_uint::())?; @@ -834,20 +834,21 @@ impl<'a> PublicPacket<'a> { }; let first_byte = self.data[0] ^ (mask[0] & bits); - // Make a copy of the header to work on. let mut hdrbytes = 0..self.header_len + 4; self.data[0] = first_byte; // Unmask the PN. let mut pn_encoded: u64 = 0; + let mut pn_bytes = + self.data[self.header_len..self.header_len + MAX_PACKET_NUMBER_LEN].to_vec(); for i in 0..MAX_PACKET_NUMBER_LEN { - self.data[self.header_len + i] ^= mask[1 + i]; + pn_bytes[i] ^= mask[1 + i]; pn_encoded <<= 8; - pn_encoded += u64::from(self.data[self.header_len + i]); + pn_encoded += u64::from(pn_bytes[i]); } - // Now decode the packet number length and apply it, hopefully in constant time. let pn_len = usize::from((first_byte & 0x3) + 1); + self.data[self.header_len..self.header_len + pn_len].copy_from_slice(&pn_bytes[..pn_len]); hdrbytes.end = self.header_len + pn_len; pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len); @@ -889,7 +890,7 @@ impl<'a> PublicPacket<'a> { return Err(Error::DecryptError); }; let version = rx.version(); // Version fixup; see above. - let len = rx.decrypt_in_place(pn, header, body, &mut self.data)?; + let len = rx.decrypt_in_place(pn, header, body, self.data)?; // If this is the first packet ever successfully decrypted // using `rx`, make sure to initiate a key update. if rx.needs_update() { @@ -900,7 +901,7 @@ impl<'a> PublicPacket<'a> { version, pt: self.packet_type, pn, - data: &self.data[..len], + data: &self.data[self.header_len + 1..self.header_len + 1 + len], }) } else if crypto.rx_pending(cspace) { Err(Error::KeysPending(cspace)) @@ -968,7 +969,7 @@ impl Deref for DecryptedPacket<'_> { type Target = [u8]; fn deref(&self) -> &Self::Target { - &self.data[..] + self.data } } @@ -1046,9 +1047,9 @@ mod tests { const EXTRA: &[u8] = &[0xce; 33]; fixture_init(); - let mut padded = SAMPLE_INITIAL; + let mut padded = SAMPLE_INITIAL.to_vec(); padded.extend_from_slice(EXTRA); - let (packet, remainder) = PublicPacket::decode(&mut padded, &cid_mgr()).unwrap(); + let (mut packet, remainder) = PublicPacket::decode(&mut padded, &cid_mgr()).unwrap(); assert_eq!(packet.packet_type(), PacketType::Initial); assert_eq!(&packet.dcid()[..], &[] as &[u8]); assert_eq!(&packet.scid()[..], SERVER_CID); @@ -1129,7 +1130,8 @@ mod tests { #[test] fn decode_short() { fixture_init(); - let (packet, remainder) = PublicPacket::decode(&mut SAMPLE_SHORT, &cid_mgr()).unwrap(); + let mut sample_short = SAMPLE_SHORT.to_vec(); + let (mut packet, remainder) = PublicPacket::decode(&mut sample_short, &cid_mgr()).unwrap(); assert_eq!(packet.packet_type(), PacketType::Short); assert!(remainder.is_empty()); let decrypted = packet @@ -1143,8 +1145,9 @@ mod tests { #[test] fn decode_short_bad_cid() { fixture_init(); - let (packet, remainder) = PublicPacket::decode( - &mut SAMPLE_SHORT, + let mut sample_short = SAMPLE_SHORT.to_vec(); + let (mut packet, remainder) = PublicPacket::decode( + &mut sample_short, &RandomConnectionIdGenerator::new(SERVER_CID.len() - 1), ) .unwrap(); @@ -1158,8 +1161,9 @@ mod tests { /// Saying that the connection ID is longer causes the initial decode to fail. #[test] fn decode_short_long_cid() { + let mut sample_short = SAMPLE_SHORT.to_vec(); assert!(PublicPacket::decode( - &mut SAMPLE_SHORT, + &mut sample_short, &RandomConnectionIdGenerator::new(SERVER_CID.len() + 1) ) .is_err()); @@ -1311,10 +1315,10 @@ mod tests { fn build_retry_single(version: Version, sample_retry: &[u8]) { fixture_init(); - let retry = + let mut retry = PacketBuilder::retry(version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap(); - let (packet, remainder) = PublicPacket::decode(&mut &retry, &cid_mgr()).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut retry, &cid_mgr()).unwrap(); assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); assert!(remainder.is_empty()); @@ -1360,10 +1364,10 @@ mod tests { } } - fn decode_retry(version: Version, sample_retry: &[u8]) { + fn decode_retry(version: Version, sample_retry: &mut [u8]) { fixture_init(); let (packet, remainder) = - PublicPacket::decode(&mut sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap(); + PublicPacket::decode(sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap(); assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); assert_eq!(Some(version), packet.version()); assert!(packet.dcid().is_empty()); @@ -1374,17 +1378,20 @@ mod tests { #[test] fn decode_retry_v2() { - decode_retry(Version::Version2, SAMPLE_RETRY_V2); + let mut sample_retry_v2 = SAMPLE_RETRY_V2.to_vec(); + decode_retry(Version::Version2, &mut sample_retry_v2); } #[test] fn decode_retry_v1() { - decode_retry(Version::Version1, SAMPLE_RETRY_V1); + let mut sample_retry_v1 = SAMPLE_RETRY_V1.to_vec(); + decode_retry(Version::Version1, &mut sample_retry_v1); } #[test] fn decode_retry_29() { - decode_retry(Version::Draft29, SAMPLE_RETRY_29); + let mut sample_retry_29 = SAMPLE_RETRY_29.to_vec(); + decode_retry(Version::Draft29, &mut sample_retry_29); } /// Check some packets that are clearly not valid Retry packets. @@ -1394,30 +1401,31 @@ mod tests { let cid_mgr = RandomConnectionIdGenerator::new(5); let odcid = ConnectionId::from(CLIENT_CID); - assert!(PublicPacket::decode(&[], &cid_mgr).is_err()); + assert!(PublicPacket::decode(&mut [], &cid_mgr).is_err()); - let (packet, remainder) = PublicPacket::decode(&mut SAMPLE_RETRY_V1, &cid_mgr).unwrap(); + let mut sample_retry_v1 = SAMPLE_RETRY_V1.to_vec(); + let (packet, remainder) = PublicPacket::decode(&mut sample_retry_v1, &cid_mgr).unwrap(); assert!(remainder.is_empty()); assert!(packet.is_valid_retry(&odcid)); let mut damaged_retry = SAMPLE_RETRY_V1.to_vec(); let last = damaged_retry.len() - 1; damaged_retry[last] ^= 66; - let (packet, remainder) = PublicPacket::decode(&mut &damaged_retry, &cid_mgr).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut damaged_retry, &cid_mgr).unwrap(); assert!(remainder.is_empty()); assert!(!packet.is_valid_retry(&odcid)); damaged_retry.truncate(last); - let (packet, remainder) = PublicPacket::decode(&mut &damaged_retry, &cid_mgr).unwrap(); + let (packet, remainder) = PublicPacket::decode(&mut damaged_retry, &cid_mgr).unwrap(); assert!(remainder.is_empty()); assert!(!packet.is_valid_retry(&odcid)); // An invalid token should be rejected sooner. damaged_retry.truncate(last - 4); - assert!(PublicPacket::decode(&mut &damaged_retry, &cid_mgr).is_err()); + assert!(PublicPacket::decode(&mut damaged_retry, &cid_mgr).is_err()); damaged_retry.truncate(last - 1); - assert!(PublicPacket::decode(&mut &damaged_retry, &cid_mgr).is_err()); + assert!(PublicPacket::decode(&mut damaged_retry, &cid_mgr).is_err()); } const SAMPLE_VN: &[u8] = &[ @@ -1458,8 +1466,9 @@ mod tests { #[test] fn parse_vn() { + let mut sample_vn = SAMPLE_VN.to_vec(); let (packet, remainder) = - PublicPacket::decode(&mut SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap(); + PublicPacket::decode(&mut sample_vn, &EmptyConnectionIdGenerator::default()).unwrap(); assert!(remainder.is_empty()); assert_eq!(&packet.dcid[..], SERVER_CID); assert!(packet.scid.is_some()); @@ -1515,8 +1524,9 @@ mod tests { 0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb, ]; fixture_init(); - let (packet, slice) = - PublicPacket::decode(&mut PACKET, &EmptyConnectionIdGenerator::default()).unwrap(); + let mut packet = PACKET.to_vec(); + let (mut packet, slice) = + PublicPacket::decode(&mut packet, &EmptyConnectionIdGenerator::default()).unwrap(); assert!(slice.is_empty()); let decrypted = packet .decrypt(&mut CryptoStates::test_chacha(), now()) @@ -1529,17 +1539,15 @@ mod tests { #[test] fn decode_empty() { neqo_crypto::init().unwrap(); - let res = PublicPacket::decode(&[], &EmptyConnectionIdGenerator::default()); + let res = PublicPacket::decode(&mut [], &EmptyConnectionIdGenerator::default()); assert!(res.is_err()); } #[test] fn decode_too_short() { neqo_crypto::init().unwrap(); - let res = PublicPacket::decode( - &[179, 255, 0, 0, 29, 0, 0], - &EmptyConnectionIdGenerator::default(), - ); + let mut data = [179, 255, 0, 0, 29, 0, 0]; + let res = PublicPacket::decode(&mut data, &EmptyConnectionIdGenerator::default()); assert!(res.is_err()); } } diff --git a/neqo-udp/src/lib.rs b/neqo-udp/src/lib.rs index 615f62cdf4..88fe5a61da 100644 --- a/neqo-udp/src/lib.rs +++ b/neqo-udp/src/lib.rs @@ -11,7 +11,7 @@ use std::{ io::{self, IoSliceMut}, iter, net::SocketAddr, - slice::{self, Chunks}, + slice::{self, ChunksMut}, }; use log::{log_enabled, Level}; @@ -120,7 +120,7 @@ pub fn recv_inner<'a>( Ok(DatagramIter { current_buffer: None, - remaining_buffers: metas.into_iter().zip(recv_buf.0.iter()).take(n), + remaining_buffers: metas.into_iter().zip(recv_buf.0.iter_mut()).take(n), local_address, }) } @@ -128,17 +128,17 @@ pub fn recv_inner<'a>( pub struct DatagramIter<'a> { /// The current buffer, containing zero or more datagrams, each sharing the /// same [`RecvMeta`]. - current_buffer: Option<(RecvMeta, Chunks<'a, u8>)>, + current_buffer: Option<(RecvMeta, ChunksMut<'a, u8>)>, /// Remaining buffers, each containing zero or more datagrams, one /// [`RecvMeta`] per buffer. remaining_buffers: - iter::Take, slice::Iter<'a, Vec>>>, + iter::Take, slice::IterMut<'a, Vec>>>, /// The local address of the UDP socket used to receive the datagrams. local_address: SocketAddr, } impl<'a> Iterator for DatagramIter<'a> { - type Item = Datagram<&'a [u8]>; + type Item = Datagram<&'a mut [u8]>; fn next(&mut self) -> Option { loop { @@ -177,7 +177,7 @@ impl<'a> Iterator for DatagramIter<'a> { // Got another buffer. Let's chunk it into datagrams and return the // first datagram in the next loop iteration. - self.current_buffer = Some((meta, buf[0..meta.len].chunks(meta.stride))); + self.current_buffer = Some((meta, buf[0..meta.len].chunks_mut(meta.stride))); } } } From 8f940332c83b5e4713c2fabfa02efc6c75cd8f89 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Wed, 29 Jan 2025 18:20:38 +0200 Subject: [PATCH 07/19] Minimize --- neqo-transport/src/connection/mod.rs | 9 +-------- neqo-transport/src/crypto.rs | 16 +--------------- neqo-transport/src/packet/mod.rs | 2 +- 3 files changed, 3 insertions(+), 24 deletions(-) diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index d61f0f3ad1..9dd3a323e9 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1242,12 +1242,7 @@ impl Connection { } } - fn is_stateless_reset( - &self, - path: &PathRef, - d: &[u8], - // d: &Datagram> - ) -> bool { + fn is_stateless_reset(&self, path: &PathRef, d: &[u8]) -> bool { // If the datagram is too small, don't try. // If the connection is connected, then the reset token will be invalid. if d.len() < 16 || !self.state.connected() { @@ -1261,7 +1256,6 @@ impl Connection { &mut self, path: &PathRef, d: &[u8], - // d: &Datagram>, first: bool, now: Instant, ) -> Res<()> { @@ -1588,7 +1582,6 @@ impl Connection { Ok((packet, remainder)) => (packet, remainder), Err(e) => { qinfo!("[{self}] Garbage packet: {e}"); - // qtrace!("[{self}] Garbage packet contents: {}", hex(slc)); self.stats.borrow_mut().pkt_dropped("Garbage packet"); break; } diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index e1d0b3b867..033084b24f 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -673,21 +673,7 @@ impl CryptoDxState { self.aead.expansion() } - pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { - debug_assert_eq!(self.direction, CryptoDxDirection::Read); - qtrace!( - "[{self}] decrypt pn={pn} hdr={} body={}", - hex(hdr), - hex(body) - ); - self.invoked()?; - let mut out = vec![0; body.len()]; - let res = self.aead.decrypt(pn, hdr, body, &mut out)?; - self.used(pn)?; - Ok(res.to_vec()) - } - - pub fn decrypt_in_place( + pub fn decrypt( &mut self, pn: PacketNumber, hdr: Range, diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 106a137e20..b540a0fcbb 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -890,7 +890,7 @@ impl<'a> PublicPacket<'a> { return Err(Error::DecryptError); }; let version = rx.version(); // Version fixup; see above. - let len = rx.decrypt_in_place(pn, header.clone(), body, self.data)?; + let len = rx.decrypt(pn, header.clone(), body, self.data)?; // If this is the first packet ever successfully decrypted // using `rx`, make sure to initiate a key update. if rx.needs_update() { From 5d8c5dcd45b8f07d3143c02d230a495bf21c93f5 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Wed, 29 Jan 2025 19:05:54 +0200 Subject: [PATCH 08/19] Some suggestions from @martinthomson --- neqo-crypto/src/aead.rs | 38 +++++++++++++++++--------------- neqo-transport/src/crypto.rs | 10 ++++----- neqo-transport/src/packet/mod.rs | 4 ++-- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index c4af631301..389597be59 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -18,6 +18,7 @@ use crate::{ p11::{PK11SymKey, SymKey}, scoped_ptr, ssl::{PRUint16, PRUint64, PRUint8, SSLAeadContext}, + Error, }; experimental_api!(SSL_MakeAead( @@ -132,7 +133,8 @@ impl RealAead { /// /// # Errors /// - /// If the input can't be protected or any input is too large for NSS. + /// If the input can't be protected or any input is too large for NSS, or `aad` or `input` are + /// not valid ranges into `data` pub fn encrypt_in_place( &self, count: u64, @@ -140,8 +142,8 @@ impl RealAead { input: Range, data: &mut [u8], ) -> Res { - let aad = &data[aad]; - let input = &data[input]; + let aad = data.get(aad).ok_or(Error::AeadError)?; + let input = data.get(input).ok_or(Error::AeadError)?; let mut l: c_uint = 0; unsafe { SSL_AeadEncrypt( @@ -194,37 +196,37 @@ impl RealAead { /// Decrypt a ciphertext in place. /// - /// Note that NSS insists upon having extra space available for decryption, so - /// the buffer for `output` should be the same length as `input`, even though - /// the final result will be shorter. - /// /// # Errors /// - /// If the input isn't authenticated or any input is too large for NSS. - pub fn decrypt_in_place( + /// If the input isn't authenticated or any input is too large for NSS, or `aad` or `input` are + /// not valid ranges into `data` + pub fn decrypt_in_place<'a>( &self, count: u64, aad: Range, input: Range, - data: &mut [u8], - ) -> Res { - let aad = &data[aad]; - let input = &data[input]; + data: &'a mut [u8], + ) -> Res<&'a mut [u8]> { + let aad = data.get(aad).ok_or(Error::AeadError)?; + let inp = data.get(input.clone()).ok_or(Error::AeadError)?; let mut l: c_uint = 0; unsafe { + // Note that NSS insists upon having extra space available for decryption, so + // the buffer for `output` should be the same length as `input`, even though + // the final result will be shorter. SSL_AeadDecrypt( *self.ctx, count, aad.as_ptr(), c_uint::try_from(aad.len())?, - input.as_ptr(), - c_uint::try_from(input.len())?, - input.as_ptr(), + inp.as_ptr(), + c_uint::try_from(inp.len())?, + inp.as_ptr(), &mut l, - c_uint::try_from(input.len())?, + c_uint::try_from(inp.len())?, ) }?; - Ok(l.try_into()?) + Ok(&mut data[input.start..input.start + usize::try_from(l)?]) } } diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index 033084b24f..76b5c3b7d2 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -673,13 +673,13 @@ impl CryptoDxState { self.aead.expansion() } - pub fn decrypt( + pub fn decrypt<'a>( &mut self, pn: PacketNumber, hdr: Range, body: Range, - data: &mut [u8], - ) -> Res { + data: &'a mut [u8], + ) -> Res<&'a mut [u8]> { debug_assert_eq!(self.direction, CryptoDxDirection::Read); qtrace!( "[{self}] decrypt_in_place pn={pn} hdr={} body={}", @@ -687,9 +687,9 @@ impl CryptoDxState { hex(data[body.clone()].as_ref()) ); self.invoked()?; - let len = self.aead.decrypt_in_place(pn, hdr, body, data)?; + let data = self.aead.decrypt_in_place(pn, hdr, body, data)?; self.used(pn)?; - Ok(len) + Ok(data) } #[cfg(all(test, not(feature = "disable-encryption")))] diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index b540a0fcbb..d2548304f4 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -890,7 +890,7 @@ impl<'a> PublicPacket<'a> { return Err(Error::DecryptError); }; let version = rx.version(); // Version fixup; see above. - let len = rx.decrypt(pn, header.clone(), body, self.data)?; + let d = rx.decrypt(pn, header, body, self.data)?; // If this is the first packet ever successfully decrypted // using `rx`, make sure to initiate a key update. if rx.needs_update() { @@ -901,7 +901,7 @@ impl<'a> PublicPacket<'a> { version, pt: self.packet_type, pn, - data: &self.data[header.end..header.end + len], + data: d, }) } else if crypto.rx_pending(cspace) { Err(Error::KeysPending(cspace)) From 3285712acdc2d890fb375cd534b28f0c37723548 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Wed, 29 Jan 2025 19:12:35 +0200 Subject: [PATCH 09/19] More --- neqo-common/src/datagram.rs | 2 +- neqo-crypto/src/aead.rs | 2 +- neqo-transport/src/packet/mod.rs | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/neqo-common/src/datagram.rs b/neqo-common/src/datagram.rs index 9bea6e0e4e..e671d70382 100644 --- a/neqo-common/src/datagram.rs +++ b/neqo-common/src/datagram.rs @@ -96,7 +96,7 @@ impl> std::fmt::Debug for Datagram { impl<'a> Datagram<&'a mut [u8]> { #[must_use] - pub const fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a mut [u8]) -> Self { + pub fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a mut [u8]) -> Self { Self { src, dst, tos, d } } diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index 389597be59..59496ce8de 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -140,7 +140,7 @@ impl RealAead { count: u64, aad: Range, input: Range, - data: &mut [u8], + data: &[u8], ) -> Res { let aad = data.get(aad).ok_or(Error::AeadError)?; let input = data.get(input).ok_or(Error::AeadError)?; diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index d2548304f4..45955bd181 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -421,9 +421,7 @@ impl PacketBuilder { // Add space for crypto expansion. let data_end = self.encoder.len(); - for _i in 0..crypto.expansion() { - self.encode_byte(123); - } + self.pad_to(data_end + crypto.expansion(), 0); let ciphertext_len = crypto.encrypt( self.pn, From 2745a34c8291c8e7bb83fe83a730ec95859845a8 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Wed, 29 Jan 2025 19:51:53 +0200 Subject: [PATCH 10/19] More --- fuzz/fuzz_targets/packet.rs | 3 ++- neqo-transport/src/qlog.rs | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/fuzz/fuzz_targets/packet.rs b/fuzz/fuzz_targets/packet.rs index 91a9af4223..40f59ba4f0 100644 --- a/fuzz/fuzz_targets/packet.rs +++ b/fuzz/fuzz_targets/packet.rs @@ -14,7 +14,8 @@ fuzz_target!(|data: &[u8]| { neqo_crypto::init().unwrap(); // Run the fuzzer - _ = PublicPacket::decode(data, decoder); + let mut d = data.to_vec(); + _ = PublicPacket::decode(&mut d, decoder); }); #[cfg(any(not(fuzzing), windows))] diff --git a/neqo-transport/src/qlog.rs b/neqo-transport/src/qlog.rs index 81060e890b..e181655cfc 100644 --- a/neqo-transport/src/qlog.rs +++ b/neqo-transport/src/qlog.rs @@ -294,6 +294,7 @@ pub fn packets_lost(qlog: &NeqoQlog, pkts: &[SentPacket], now: Instant) { }); } +#[allow(dead_code)] // FIXME pub fn packet_received( qlog: &NeqoQlog, public_packet: &PublicPacket, From 006eb03b18f0d55900cdf8d4c1699d359b9321e1 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Thu, 30 Jan 2025 13:02:46 +0200 Subject: [PATCH 11/19] Fix `AeadNull` --- neqo-crypto/src/aead_null.rs | 57 +++++++++++++++++++++++--------- neqo-transport/src/packet/mod.rs | 1 + 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/neqo-crypto/src/aead_null.rs b/neqo-crypto/src/aead_null.rs index bc0174492b..28dfc204b7 100644 --- a/neqo-crypto/src/aead_null.rs +++ b/neqo-crypto/src/aead_null.rs @@ -42,8 +42,8 @@ impl AeadNull { ) -> Res<&'a [u8]> { let l = input.len(); output[..l].copy_from_slice(input); - output[l..l + 16].copy_from_slice(AEAD_NULL_TAG); - Ok(&output[..l + 16]) + output[l..l + self.expansion()].copy_from_slice(AEAD_NULL_TAG); + Ok(&output[..l + self.expansion()]) } #[allow(clippy::missing_errors_doc)] @@ -52,24 +52,18 @@ impl AeadNull { _count: u64, _aad: Range, input: Range, - _data: &mut [u8], + data: &mut [u8], ) -> Res { - Ok(input.len() + 16) + data[input.end..input.end + self.expansion()].copy_from_slice(AEAD_NULL_TAG); + Ok(input.len() + self.expansion()) } - #[allow(clippy::missing_errors_doc)] - pub fn decrypt<'a>( - &self, - _count: u64, - _aad: &[u8], - input: &[u8], - output: &'a mut [u8], - ) -> Res<&'a [u8]> { - if input.len() < AEAD_NULL_TAG.len() { + fn decrypt_check(&self, _count: u64, _aad: &[u8], input: &[u8]) -> Res { + if input.len() < self.expansion() { return Err(Error::from(SEC_ERROR_BAD_DATA)); } - let len_encrypted = input.len() - AEAD_NULL_TAG.len(); + let len_encrypted = input.len() - self.expansion(); // Check that: // 1) expansion is all zeros and // 2) if the encrypted data is also supplied that at least some values are no zero @@ -77,12 +71,43 @@ impl AeadNull { if &input[len_encrypted..] == AEAD_NULL_TAG && (len_encrypted == 0 || input[..len_encrypted].iter().any(|x| *x != 0x0)) { - output[..len_encrypted].copy_from_slice(&input[..len_encrypted]); - Ok(&output[..len_encrypted]) + Ok(len_encrypted) } else { Err(Error::from(SEC_ERROR_BAD_DATA)) } } + + #[allow(clippy::missing_errors_doc)] + pub fn decrypt<'a>( + &self, + count: u64, + aad: &[u8], + input: &[u8], + output: &'a mut [u8], + ) -> Res<&'a [u8]> { + self.decrypt_check(count, aad, input).map(|len| { + output[..len].copy_from_slice(&input[..len]); + &output[..len] + }) + } + + #[allow(clippy::missing_errors_doc)] + pub fn decrypt_in_place<'a>( + &self, + count: u64, + aad: Range, + input: Range, + data: &'a mut [u8], + ) -> Res<&'a mut [u8]> { + let aad = data + .get(aad) + .ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?; + let inp = data + .get(input.clone()) + .ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?; + self.decrypt_check(count, aad, inp) + .map(move |len| &mut data[input.start..input.start + len]) + } } impl fmt::Debug for AeadNull { diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 45955bd181..24a1b2d285 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -788,6 +788,7 @@ impl<'a> PublicPacket<'a> { self.data.len() } + #[must_use] pub fn data(&self) -> &[u8] { self.data } From 16d598e4486bb24a8d077582eead4643f448a49b Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Thu, 30 Jan 2025 14:56:24 +0200 Subject: [PATCH 12/19] More suggestions from @martinthomson --- fuzz/fuzz_targets/packet.rs | 3 +- neqo-common/src/datagram.rs | 2 +- neqo-crypto/src/aead.rs | 55 ++++++++++++-------------------- neqo-crypto/src/aead_null.rs | 28 +++++----------- neqo-transport/src/crypto.rs | 30 ++++++++--------- neqo-transport/src/packet/mod.rs | 22 ++++--------- 6 files changed, 50 insertions(+), 90 deletions(-) diff --git a/fuzz/fuzz_targets/packet.rs b/fuzz/fuzz_targets/packet.rs index 40f59ba4f0..f887f7b6f0 100644 --- a/fuzz/fuzz_targets/packet.rs +++ b/fuzz/fuzz_targets/packet.rs @@ -14,8 +14,7 @@ fuzz_target!(|data: &[u8]| { neqo_crypto::init().unwrap(); // Run the fuzzer - let mut d = data.to_vec(); - _ = PublicPacket::decode(&mut d, decoder); + _ = PublicPacket::decode(&mut data.to_vec(), decoder); }); #[cfg(any(not(fuzzing), windows))] diff --git a/neqo-common/src/datagram.rs b/neqo-common/src/datagram.rs index e671d70382..9bea6e0e4e 100644 --- a/neqo-common/src/datagram.rs +++ b/neqo-common/src/datagram.rs @@ -96,7 +96,7 @@ impl> std::fmt::Debug for Datagram { impl<'a> Datagram<&'a mut [u8]> { #[must_use] - pub fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a mut [u8]) -> Self { + pub const fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a mut [u8]) -> Self { Self { src, dst, tos, d } } diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index 59496ce8de..f1868480d7 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -6,7 +6,7 @@ use std::{ fmt, - ops::{Deref, DerefMut, Range}, + ops::{Deref, DerefMut}, os::raw::{c_char, c_uint}, ptr::null_mut, }; @@ -18,7 +18,6 @@ use crate::{ p11::{PK11SymKey, SymKey}, scoped_ptr, ssl::{PRUint16, PRUint64, PRUint8, SSLAeadContext}, - Error, }; experimental_api!(SSL_MakeAead( @@ -124,26 +123,17 @@ impl RealAead { c_uint::try_from(output.len())?, ) }?; - Ok(&output[0..(l.try_into()?)]) + Ok(&output[..l.try_into()?]) } - /// Encrypt `data` consisting of `aad` and plaintext `input` in place. + /// Encrypt `data` consisting of `aad` and plaintext `data` in place. /// /// The space provided in `data` needs to allow `Aead::expansion` more bytes to be appended. /// /// # Errors /// - /// If the input can't be protected or any input is too large for NSS, or `aad` or `input` are - /// not valid ranges into `data` - pub fn encrypt_in_place( - &self, - count: u64, - aad: Range, - input: Range, - data: &[u8], - ) -> Res { - let aad = data.get(aad).ok_or(Error::AeadError)?; - let input = data.get(input).ok_or(Error::AeadError)?; + /// If the input can't be protected or any input is too large for NSS. + pub fn encrypt_in_place(&self, count: u64, aad: &[u8], data: &[u8]) -> Res { let mut l: c_uint = 0; unsafe { SSL_AeadEncrypt( @@ -151,11 +141,11 @@ impl RealAead { count, aad.as_ptr(), c_uint::try_from(aad.len())?, - input.as_ptr(), - c_uint::try_from(input.len())?, - input.as_ptr(), + data.as_ptr(), + c_uint::try_from(data.len() - self.expansion())?, + data.as_ptr(), &mut l, - c_uint::try_from(input.len() + self.expansion())?, + c_uint::try_from(data.len())?, ) }?; Ok(l.try_into()?) @@ -163,10 +153,6 @@ impl RealAead { /// Decrypt a ciphertext. /// - /// Note that NSS insists upon having extra space available for decryption, so - /// the buffer for `output` should be the same length as `input`, even though - /// the final result will be shorter. - /// /// # Errors /// /// If the input isn't authenticated or any input is too large for NSS. @@ -179,6 +165,9 @@ impl RealAead { ) -> Res<&'a [u8]> { let mut l: c_uint = 0; unsafe { + // Note that NSS insists upon having extra space available for decryption, so + // the buffer for `output` should be the same length as `input`, even though + // the final result will be shorter. SSL_AeadDecrypt( *self.ctx, count, @@ -191,24 +180,20 @@ impl RealAead { c_uint::try_from(output.len())?, ) }?; - Ok(&output[0..(l.try_into()?)]) + Ok(&output[..l.try_into()?]) } /// Decrypt a ciphertext in place. /// /// # Errors /// - /// If the input isn't authenticated or any input is too large for NSS, or `aad` or `input` are - /// not valid ranges into `data` + /// If the input isn't authenticated or any input is too large for NSS. pub fn decrypt_in_place<'a>( &self, count: u64, - aad: Range, - input: Range, + aad: &[u8], data: &'a mut [u8], ) -> Res<&'a mut [u8]> { - let aad = data.get(aad).ok_or(Error::AeadError)?; - let inp = data.get(input.clone()).ok_or(Error::AeadError)?; let mut l: c_uint = 0; unsafe { // Note that NSS insists upon having extra space available for decryption, so @@ -219,14 +204,14 @@ impl RealAead { count, aad.as_ptr(), c_uint::try_from(aad.len())?, - inp.as_ptr(), - c_uint::try_from(inp.len())?, - inp.as_ptr(), + data.as_ptr(), + c_uint::try_from(data.len())?, + data.as_ptr(), &mut l, - c_uint::try_from(inp.len())?, + c_uint::try_from(data.len())?, ) }?; - Ok(&mut data[input.start..input.start + usize::try_from(l)?]) + Ok(&mut data[..l.try_into()?]) } } diff --git a/neqo-crypto/src/aead_null.rs b/neqo-crypto/src/aead_null.rs index 28dfc204b7..9e02da0e7b 100644 --- a/neqo-crypto/src/aead_null.rs +++ b/neqo-crypto/src/aead_null.rs @@ -4,7 +4,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::{fmt, ops::Range}; +use std::fmt; use crate::{ constants::{Cipher, Version}, @@ -47,15 +47,10 @@ impl AeadNull { } #[allow(clippy::missing_errors_doc)] - pub fn encrypt_in_place( - &self, - _count: u64, - _aad: Range, - input: Range, - data: &mut [u8], - ) -> Res { - data[input.end..input.end + self.expansion()].copy_from_slice(AEAD_NULL_TAG); - Ok(input.len() + self.expansion()) + pub fn encrypt_in_place(&self, _count: u64, _aad: &[u8], data: &mut [u8]) -> Res { + let pos = data.len() - self.expansion(); + data[pos..].copy_from_slice(AEAD_NULL_TAG); + Ok(data.len()) } fn decrypt_check(&self, _count: u64, _aad: &[u8], input: &[u8]) -> Res { @@ -95,18 +90,11 @@ impl AeadNull { pub fn decrypt_in_place<'a>( &self, count: u64, - aad: Range, - input: Range, + aad: &[u8], data: &'a mut [u8], ) -> Res<&'a mut [u8]> { - let aad = data - .get(aad) - .ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?; - let inp = data - .get(input.clone()) - .ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?; - self.decrypt_check(count, aad, inp) - .map(move |len| &mut data[input.start..input.start + len]) + self.decrypt_check(count, aad, data) + .map(move |len| &mut data[..len]) } } diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index 76b5c3b7d2..8f091ce7dd 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -635,32 +635,30 @@ impl CryptoDxState { self.used_pn.end } - pub fn encrypt( - &mut self, - pn: PacketNumber, - hdr: Range, - body: Range, - data: &mut [u8], - ) -> Res { + pub fn encrypt(&mut self, pn: PacketNumber, hdr: Range, data: &mut [u8]) -> Res { debug_assert_eq!(self.direction, CryptoDxDirection::Write); qtrace!( "[{self}] encrypt_in_place pn={pn} hdr={} body={}", hex(data[hdr.clone()].as_ref()), - hex(data[body.clone()].as_ref()) + hex(data[hdr.end..].as_ref()) ); // The numbers in `Self::limit` assume a maximum packet size of `LIMIT`. // Adjust them as we encounter larger packets. - debug_assert!(body.len() < 65536); - if body.len() > self.largest_packet_len { + let body_len = data.len() - hdr.len() - self.aead.expansion(); + debug_assert!(body_len <= u16::MAX.into()); + if body_len > self.largest_packet_len { let new_bits = usize::leading_zeros(self.largest_packet_len - 1) - - usize::leading_zeros(body.len() - 1); + - usize::leading_zeros(body_len - 1); self.invocations >>= new_bits; - self.largest_packet_len = body.len(); + self.largest_packet_len = body_len; } self.invoked()?; - let len = self.aead.encrypt_in_place(pn, hdr, body, data)?; + let (prev, data) = data.split_at_mut(hdr.end); + // `prev` may have already-encrypted packets this one is being coalesced with. + // Use only the actual current header for AAD. + let len = self.aead.encrypt_in_place(pn, &prev[hdr], data)?; qtrace!("[{self}] encrypt ct={}", hex(data)); debug_assert_eq!(pn, self.next_pn()); @@ -677,17 +675,17 @@ impl CryptoDxState { &mut self, pn: PacketNumber, hdr: Range, - body: Range, data: &'a mut [u8], ) -> Res<&'a mut [u8]> { debug_assert_eq!(self.direction, CryptoDxDirection::Read); qtrace!( "[{self}] decrypt_in_place pn={pn} hdr={} body={}", hex(data[hdr.clone()].as_ref()), - hex(data[body.clone()].as_ref()) + hex(data[hdr.end..].as_ref()) ); self.invoked()?; - let data = self.aead.decrypt_in_place(pn, hdr, body, data)?; + let (hdr, data) = data.split_at_mut(hdr.end); + let data = self.aead.decrypt_in_place(pn, hdr, data)?; self.used(pn)?; Ok(data) } diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 24a1b2d285..fed1fa474a 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -423,12 +423,7 @@ impl PacketBuilder { let data_end = self.encoder.len(); self.pad_to(data_end + crypto.expansion(), 0); - let ciphertext_len = crypto.encrypt( - self.pn, - self.header.clone(), - self.header.end..data_end, - self.encoder.as_mut(), - )?; + let ciphertext_len = crypto.encrypt(self.pn, self.header.clone(), self.encoder.as_mut())?; // Calculate the mask. let ciphertext = &self.encoder.as_mut()[self.header.end..self.header.end + ciphertext_len]; @@ -812,7 +807,7 @@ impl<'a> PublicPacket<'a> { fn decrypt_header( &mut self, crypto: &CryptoDxState, - ) -> Res<(bool, PacketNumber, Range, Range)> { + ) -> Res<(bool, PacketNumber, Range)> { assert_ne!(self.packet_type, PacketType::Retry); assert_ne!(self.packet_type, PacketType::VersionNegotiation); @@ -856,12 +851,7 @@ impl<'a> PublicPacket<'a> { let key_phase = self.packet_type == PacketType::Short && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE; let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len); - Ok(( - key_phase, - pn, - hdrbytes, - self.header_len + pn_len..self.data.len(), - )) + Ok((key_phase, pn, hdrbytes)) } /// # Errors @@ -883,13 +873,13 @@ impl<'a> PublicPacket<'a> { // This is OK in this case because we the only reason this can // fail is if the cryptographic module is bad or the packet is // too small (which is public information). - let (key_phase, pn, header, body) = self.decrypt_header(rx)?; + let (key_phase, pn, header) = self.decrypt_header(rx)?; qtrace!("[{rx}] decoded header: {header:?}"); let Some(rx) = crypto.rx(version, cspace, key_phase) else { return Err(Error::DecryptError); }; let version = rx.version(); // Version fixup; see above. - let d = rx.decrypt(pn, header, body, self.data)?; + let d = rx.decrypt(pn, header, self.data)?; // If this is the first packet ever successfully decrypted // using `rx`, make sure to initiate a key update. if rx.needs_update() { @@ -1024,7 +1014,7 @@ mod tests { // The spec uses PN=1, but our crypto refuses to skip packet numbers. // So burn an encryption: let mut burn = [0; 16]; - prot.encrypt(0, 0..0, 0..0, &mut burn).expect("burn OK"); + prot.encrypt(0, 0..0, &mut burn).expect("burn OK"); assert_eq!(burn.len(), prot.expansion()); let mut builder = PacketBuilder::long( From 3b81fe50275317236a1a3416f9ec5b170d55f3a9 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Thu, 30 Jan 2025 17:22:16 +0200 Subject: [PATCH 13/19] clippy --- neqo-common/src/datagram.rs | 2 +- neqo-transport/src/connection/mod.rs | 2 +- neqo-transport/src/fc.rs | 8 ++------ 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/neqo-common/src/datagram.rs b/neqo-common/src/datagram.rs index 9bea6e0e4e..e671d70382 100644 --- a/neqo-common/src/datagram.rs +++ b/neqo-common/src/datagram.rs @@ -96,7 +96,7 @@ impl> std::fmt::Debug for Datagram { impl<'a> Datagram<&'a mut [u8]> { #[must_use] - pub const fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a mut [u8]) -> Self { + pub fn from_slice(src: SocketAddr, dst: SocketAddr, tos: IpTos, d: &'a mut [u8]) -> Self { Self { src, dst, tos, d } } diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 9dd3a323e9..e570efcd09 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1610,7 +1610,7 @@ impl Connection { ); #[cfg(feature = "build-fuzzing-corpus")] - if packet.packet_type() == PacketType::Initial { + if payload.packet_type() == PacketType::Initial { let target = if self.role == Role::Client { "server_initial" } else { diff --git a/neqo-transport/src/fc.rs b/neqo-transport/src/fc.rs index b670350682..1e6c46105c 100644 --- a/neqo-transport/src/fc.rs +++ b/neqo-transport/src/fc.rs @@ -105,12 +105,8 @@ where /// This is `Some` with the active limit if `blocked` has been called, /// if a blocking frame has not been sent (or it has been lost), and /// if the blocking condition remains. - const fn blocked_needed(&self) -> Option { - if self.blocked_frame && self.limit < self.blocked_at { - Some(self.blocked_at - 1) - } else { - None - } + fn blocked_needed(&self) -> Option { + (self.blocked_frame && self.limit < self.blocked_at).then(|| self.blocked_at - 1) } /// Clear the need to send a blocked frame. From 5498522b32d528f649d065fd0e71e2bc4239cfe8 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Thu, 30 Jan 2025 17:39:07 +0200 Subject: [PATCH 14/19] fixme --- neqo-transport/src/connection/mod.rs | 4 +++- neqo-transport/src/qlog.rs | 12 +++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index e570efcd09..7b91e32103 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1559,6 +1559,7 @@ impl Connection { _ = self.capture_error(Some(path), now, 0, res); } + #[allow(clippy::too_many_lines)] // Will be addressed as part of https://github.com/mozilla/neqo/pull/2396 fn input_path( &mut self, path: &PathRef, @@ -1594,6 +1595,7 @@ impl Connection { qtrace!("[{self}] Received unverified packet {packet:?}"); + let packet_len = packet.len(); match packet.decrypt(&mut self.crypto.states, now + pto) { Ok(payload) => { // OK, we have a valid packet. @@ -1619,7 +1621,7 @@ impl Connection { neqo_common::write_item_to_fuzzing_corpus(target, &payload[..]); } - // FIXME: add back: qlog::packet_received(&self.qlog, &packet, &payload, now); + qlog::packet_received(&self.qlog, &payload, packet_len, now); let space = PacketNumberSpace::from(payload.packet_type()); if let Some(space) = self.acks.get_mut(space) { if space.is_duplicate(payload.pn()) { diff --git a/neqo-transport/src/qlog.rs b/neqo-transport/src/qlog.rs index e181655cfc..f5ebcc814d 100644 --- a/neqo-transport/src/qlog.rs +++ b/neqo-transport/src/qlog.rs @@ -294,26 +294,20 @@ pub fn packets_lost(qlog: &NeqoQlog, pkts: &[SentPacket], now: Instant) { }); } -#[allow(dead_code)] // FIXME -pub fn packet_received( - qlog: &NeqoQlog, - public_packet: &PublicPacket, - payload: &DecryptedPacket, - now: Instant, -) { +pub fn packet_received(qlog: &NeqoQlog, payload: &DecryptedPacket, len: usize, now: Instant) { qlog.add_event_data_with_instant( || { let mut d = Decoder::from(&payload[..]); let header = PacketHeader::with_type( - public_packet.packet_type().into(), + payload.packet_type().into(), Some(payload.pn()), None, None, None, ); let raw = RawInfo { - length: Some(public_packet.len() as u64), + length: Some(len as u64), payload_length: None, data: None, }; From 12e9dde50d835d2f25cc7fe5578dcf0c9e85fe3d Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Fri, 31 Jan 2025 11:48:17 +0200 Subject: [PATCH 15/19] Update neqo-crypto/src/aead.rs Co-authored-by: Martin Thomson Signed-off-by: Lars Eggert --- neqo-crypto/src/aead.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index f1868480d7..9b7341df7a 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -128,7 +128,12 @@ impl RealAead { /// Encrypt `data` consisting of `aad` and plaintext `data` in place. /// - /// The space provided in `data` needs to allow `Aead::expansion` more bytes to be appended. + /// The last `Aead::expansion` of `data` is overwritten by the AEAD tag by this function. + /// Therefore, a buffer should be provided that is that much larger than the plaintext. + /// + /// # Panics + /// + /// If `data` is shorter than `::expansion()`. /// /// # Errors /// From e28fc82199ef918bbb76cd7b43d749e205d8d1cb Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Fri, 31 Jan 2025 11:48:30 +0200 Subject: [PATCH 16/19] Update neqo-crypto/src/aead.rs Co-authored-by: Martin Thomson Signed-off-by: Lars Eggert --- neqo-crypto/src/aead.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index 9b7341df7a..02079a355a 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -189,6 +189,8 @@ impl RealAead { } /// Decrypt a ciphertext in place. + /// Returns a subslice of `data` (without the last `::expansion()` bytes), + /// that has been decrypted in place. /// /// # Errors /// From a244f52d88198efd687dabfa70c1b4a25e4b5b53 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Fri, 31 Jan 2025 11:48:40 +0200 Subject: [PATCH 17/19] Update neqo-crypto/src/aead.rs Co-authored-by: Martin Thomson Signed-off-by: Lars Eggert --- neqo-crypto/src/aead.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index 02079a355a..896d589a69 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -218,6 +218,7 @@ impl RealAead { c_uint::try_from(data.len())?, ) }?; + debug_assert_eq!(usize::try_from(l), data.len() - self.expansion()); Ok(&mut data[..l.try_into()?]) } } From 5690eeba021fdffa3b02fbf9be9948316868f1dc Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Fri, 31 Jan 2025 11:49:12 +0200 Subject: [PATCH 18/19] Update neqo-crypto/src/aead_null.rs Co-authored-by: Martin Thomson Signed-off-by: Lars Eggert --- neqo-crypto/src/aead_null.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/neqo-crypto/src/aead_null.rs b/neqo-crypto/src/aead_null.rs index 9e02da0e7b..41ec24f77c 100644 --- a/neqo-crypto/src/aead_null.rs +++ b/neqo-crypto/src/aead_null.rs @@ -58,7 +58,8 @@ impl AeadNull { return Err(Error::from(SEC_ERROR_BAD_DATA)); } - let len_encrypted = input.len() - self.expansion(); + let len_encrypted = input.len().checked_sub(self.expansion()) + .ok_or(Error::from(SEC_ERROR_BAD_DATA))?; // Check that: // 1) expansion is all zeros and // 2) if the encrypted data is also supplied that at least some values are no zero From f2c35584bf0e5b1c1e528cb2c623defe2dd2e3ad Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Fri, 31 Jan 2025 22:37:21 +0100 Subject: [PATCH 19/19] Minimize diff --- neqo-crypto/src/aead.rs | 11 ++++++++--- neqo-crypto/src/aead_null.rs | 15 +++++++++++---- neqo-transport/src/connection/mod.rs | 6 ++++-- neqo-transport/src/crypto.rs | 13 +++++++++---- neqo-transport/src/packet/mod.rs | 22 ++++++++++------------ 5 files changed, 42 insertions(+), 25 deletions(-) diff --git a/neqo-crypto/src/aead.rs b/neqo-crypto/src/aead.rs index 896d589a69..7fed060f75 100644 --- a/neqo-crypto/src/aead.rs +++ b/neqo-crypto/src/aead.rs @@ -138,7 +138,12 @@ impl RealAead { /// # Errors /// /// If the input can't be protected or any input is too large for NSS. - pub fn encrypt_in_place(&self, count: u64, aad: &[u8], data: &[u8]) -> Res { + pub fn encrypt_in_place<'a>( + &self, + count: u64, + aad: &[u8], + data: &'a mut [u8], + ) -> Res<&'a mut [u8]> { let mut l: c_uint = 0; unsafe { SSL_AeadEncrypt( @@ -153,7 +158,7 @@ impl RealAead { c_uint::try_from(data.len())?, ) }?; - Ok(l.try_into()?) + Ok(&mut data[..l.try_into()?]) } /// Decrypt a ciphertext. @@ -218,7 +223,7 @@ impl RealAead { c_uint::try_from(data.len())?, ) }?; - debug_assert_eq!(usize::try_from(l), data.len() - self.expansion()); + debug_assert_eq!(usize::try_from(l)?, data.len() - self.expansion()); Ok(&mut data[..l.try_into()?]) } } diff --git a/neqo-crypto/src/aead_null.rs b/neqo-crypto/src/aead_null.rs index 41ec24f77c..cbe4834d10 100644 --- a/neqo-crypto/src/aead_null.rs +++ b/neqo-crypto/src/aead_null.rs @@ -47,10 +47,15 @@ impl AeadNull { } #[allow(clippy::missing_errors_doc)] - pub fn encrypt_in_place(&self, _count: u64, _aad: &[u8], data: &mut [u8]) -> Res { + pub fn encrypt_in_place<'a>( + &self, + _count: u64, + _aad: &[u8], + data: &'a mut [u8], + ) -> Res<&'a mut [u8]> { let pos = data.len() - self.expansion(); data[pos..].copy_from_slice(AEAD_NULL_TAG); - Ok(data.len()) + Ok(data) } fn decrypt_check(&self, _count: u64, _aad: &[u8], input: &[u8]) -> Res { @@ -58,8 +63,10 @@ impl AeadNull { return Err(Error::from(SEC_ERROR_BAD_DATA)); } - let len_encrypted = input.len().checked_sub(self.expansion()) - .ok_or(Error::from(SEC_ERROR_BAD_DATA))?; + let len_encrypted = input + .len() + .checked_sub(self.expansion()) + .ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?; // Check that: // 1) expansion is all zeros and // 2) if the encrypted data is also supplied that at least some values are no zero diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 64d5b5a70c..7f72932b0c 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1567,7 +1567,6 @@ impl Connection { qtrace!("[{self}] {} input {}", path.borrow(), hex(&d)); let tos = d.tos(); let remote = d.source(); - let len = d.len(); let mut slc = d.as_mut(); let mut dcid = None; let pto = path.borrow().rtt().pto(self.confirmed()); @@ -1598,7 +1597,10 @@ impl Connection { Ok(payload) => { // OK, we have a valid packet. self.idle_timeout.on_packet_received(now); - self.log_packet(packet::MetaData::new_in(path, tos, packet_len, &payload), now); + self.log_packet( + packet::MetaData::new_in(path, tos, packet_len, &payload), + now, + ); #[cfg(feature = "build-fuzzing-corpus")] if payload.packet_type() == PacketType::Initial { diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index 8f091ce7dd..d2c81f0c58 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -635,7 +635,12 @@ impl CryptoDxState { self.used_pn.end } - pub fn encrypt(&mut self, pn: PacketNumber, hdr: Range, data: &mut [u8]) -> Res { + pub fn encrypt<'a>( + &mut self, + pn: PacketNumber, + hdr: Range, + data: &'a mut [u8], + ) -> Res<&'a mut [u8]> { debug_assert_eq!(self.direction, CryptoDxDirection::Write); qtrace!( "[{self}] encrypt_in_place pn={pn} hdr={} body={}", @@ -658,12 +663,12 @@ impl CryptoDxState { let (prev, data) = data.split_at_mut(hdr.end); // `prev` may have already-encrypted packets this one is being coalesced with. // Use only the actual current header for AAD. - let len = self.aead.encrypt_in_place(pn, &prev[hdr], data)?; + let data = self.aead.encrypt_in_place(pn, &prev[hdr], data)?; - qtrace!("[{self}] encrypt ct={}", hex(data)); + qtrace!("[{self}] encrypt ct={}", hex(&data)); debug_assert_eq!(pn, self.next_pn()); self.used(pn)?; - Ok(len) + Ok(data) } #[must_use] diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index ab4692ba9a..a6687574d5 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -426,10 +426,8 @@ impl PacketBuilder { let data_end = self.encoder.len(); self.pad_to(data_end + crypto.expansion(), 0); - let ciphertext_len = crypto.encrypt(self.pn, self.header.clone(), self.encoder.as_mut())?; - // Calculate the mask. - let ciphertext = &self.encoder.as_mut()[self.header.end..self.header.end + ciphertext_len]; + let ciphertext = crypto.encrypt(self.pn, self.header.clone(), self.encoder.as_mut())?; let offset = SAMPLE_OFFSET - self.offsets.pn.len(); if offset + SAMPLE_SIZE > ciphertext.len() { return Err(Error::InternalError); @@ -622,7 +620,7 @@ impl<'a> PublicPacket<'a> { dcid_decoder: &dyn ConnectionIdDecoder, ) -> Res<(Self, &'a mut [u8])> { let mut decoder = Decoder::new(data); - let first = PublicPacket::opt(decoder.decode_uint::())?; + let first = Self::opt(decoder.decode_uint::())?; if first & 0x80 == PACKET_BIT_SHORT { // Conveniently, this also guarantees that there is enough space @@ -630,14 +628,14 @@ impl<'a> PublicPacket<'a> { if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { return Err(Error::InvalidPacket); } - let dcid = PublicPacket::opt(dcid_decoder.decode_cid(&mut decoder))?.into(); + let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?.into(); if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { return Err(Error::InvalidPacket); } let header_len = decoder.offset(); return Ok(( - PublicPacket { + Self { packet_type: PacketType::Short, dcid, scid: None, @@ -651,14 +649,14 @@ impl<'a> PublicPacket<'a> { } // Generic long header. - let version = PublicPacket::opt(decoder.decode_uint())?; - let dcid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into(); - let scid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into(); + let version = Self::opt(decoder.decode_uint())?; + let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?).into(); + let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?).into(); // Version negotiation. if version == 0 { return Ok(( - PublicPacket { + Self { packet_type: PacketType::VersionNegotiation, dcid, scid: Some(scid), @@ -674,7 +672,7 @@ impl<'a> PublicPacket<'a> { // Check that this is a long header from a supported version. let Ok(version) = Version::try_from(version) else { return Ok(( - PublicPacket { + Self { packet_type: PacketType::OtherVersion, dcid, scid: Some(scid), @@ -698,7 +696,7 @@ impl<'a> PublicPacket<'a> { let end = data.len() - decoder.remaining(); let (data, remainder) = data.split_at_mut(end); Ok(( - PublicPacket { + Self { packet_type, dcid, scid: Some(scid),