Skip to content

Commit

Permalink
Minimize diff
Browse files Browse the repository at this point in the history
  • Loading branch information
larseggert committed Jan 31, 2025
1 parent 337a894 commit f2c3558
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 25 deletions.
11 changes: 8 additions & 3 deletions neqo-crypto/src/aead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,12 @@ impl RealAead {
/// # Errors
///
/// If the input can't be protected or any input is too large for NSS.
pub fn encrypt_in_place(&self, count: u64, aad: &[u8], data: &[u8]) -> Res<usize> {
pub fn encrypt_in_place<'a>(
&self,
count: u64,
aad: &[u8],
data: &'a mut [u8],
) -> Res<&'a mut [u8]> {
let mut l: c_uint = 0;
unsafe {
SSL_AeadEncrypt(
Expand All @@ -153,7 +158,7 @@ impl RealAead {
c_uint::try_from(data.len())?,
)
}?;
Ok(l.try_into()?)
Ok(&mut data[..l.try_into()?])
}

/// Decrypt a ciphertext.
Expand Down Expand Up @@ -218,7 +223,7 @@ impl RealAead {
c_uint::try_from(data.len())?,
)
}?;
debug_assert_eq!(usize::try_from(l), data.len() - self.expansion());
debug_assert_eq!(usize::try_from(l)?, data.len() - self.expansion());
Ok(&mut data[..l.try_into()?])
}
}
Expand Down
15 changes: 11 additions & 4 deletions neqo-crypto/src/aead_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,26 @@ impl AeadNull {
}

#[allow(clippy::missing_errors_doc)]
pub fn encrypt_in_place(&self, _count: u64, _aad: &[u8], data: &mut [u8]) -> Res<usize> {
pub fn encrypt_in_place<'a>(
&self,
_count: u64,
_aad: &[u8],
data: &'a mut [u8],
) -> Res<&'a mut [u8]> {
let pos = data.len() - self.expansion();
data[pos..].copy_from_slice(AEAD_NULL_TAG);
Ok(data.len())
Ok(data)
}

fn decrypt_check(&self, _count: u64, _aad: &[u8], input: &[u8]) -> Res<usize> {
if input.len() < self.expansion() {
return Err(Error::from(SEC_ERROR_BAD_DATA));
}

let len_encrypted = input.len().checked_sub(self.expansion())
.ok_or(Error::from(SEC_ERROR_BAD_DATA))?;
let len_encrypted = input
.len()
.checked_sub(self.expansion())
.ok_or_else(|| Error::from(SEC_ERROR_BAD_DATA))?;
// Check that:
// 1) expansion is all zeros and
// 2) if the encrypted data is also supplied that at least some values are no zero
Expand Down
6 changes: 4 additions & 2 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1567,7 +1567,6 @@ impl Connection {
qtrace!("[{self}] {} input {}", path.borrow(), hex(&d));
let tos = d.tos();
let remote = d.source();
let len = d.len();
let mut slc = d.as_mut();
let mut dcid = None;
let pto = path.borrow().rtt().pto(self.confirmed());
Expand Down Expand Up @@ -1598,7 +1597,10 @@ impl Connection {
Ok(payload) => {
// OK, we have a valid packet.
self.idle_timeout.on_packet_received(now);
self.log_packet(packet::MetaData::new_in(path, tos, packet_len, &payload), now);
self.log_packet(
packet::MetaData::new_in(path, tos, packet_len, &payload),
now,
);

#[cfg(feature = "build-fuzzing-corpus")]
if payload.packet_type() == PacketType::Initial {
Expand Down
13 changes: 9 additions & 4 deletions neqo-transport/src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,12 @@ impl CryptoDxState {
self.used_pn.end
}

pub fn encrypt(&mut self, pn: PacketNumber, hdr: Range<usize>, data: &mut [u8]) -> Res<usize> {
pub fn encrypt<'a>(
&mut self,
pn: PacketNumber,
hdr: Range<usize>,
data: &'a mut [u8],
) -> Res<&'a mut [u8]> {
debug_assert_eq!(self.direction, CryptoDxDirection::Write);
qtrace!(
"[{self}] encrypt_in_place pn={pn} hdr={} body={}",
Expand All @@ -658,12 +663,12 @@ impl CryptoDxState {
let (prev, data) = data.split_at_mut(hdr.end);
// `prev` may have already-encrypted packets this one is being coalesced with.
// Use only the actual current header for AAD.
let len = self.aead.encrypt_in_place(pn, &prev[hdr], data)?;
let data = self.aead.encrypt_in_place(pn, &prev[hdr], data)?;

qtrace!("[{self}] encrypt ct={}", hex(data));
qtrace!("[{self}] encrypt ct={}", hex(&data));
debug_assert_eq!(pn, self.next_pn());
self.used(pn)?;
Ok(len)
Ok(data)
}

#[must_use]
Expand Down
22 changes: 10 additions & 12 deletions neqo-transport/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,8 @@ impl PacketBuilder {
let data_end = self.encoder.len();
self.pad_to(data_end + crypto.expansion(), 0);

let ciphertext_len = crypto.encrypt(self.pn, self.header.clone(), self.encoder.as_mut())?;

// Calculate the mask.
let ciphertext = &self.encoder.as_mut()[self.header.end..self.header.end + ciphertext_len];
let ciphertext = crypto.encrypt(self.pn, self.header.clone(), self.encoder.as_mut())?;
let offset = SAMPLE_OFFSET - self.offsets.pn.len();
if offset + SAMPLE_SIZE > ciphertext.len() {
return Err(Error::InternalError);
Expand Down Expand Up @@ -622,22 +620,22 @@ impl<'a> PublicPacket<'a> {
dcid_decoder: &dyn ConnectionIdDecoder,
) -> Res<(Self, &'a mut [u8])> {
let mut decoder = Decoder::new(data);
let first = PublicPacket::opt(decoder.decode_uint::<u8>())?;
let first = Self::opt(decoder.decode_uint::<u8>())?;

if first & 0x80 == PACKET_BIT_SHORT {
// Conveniently, this also guarantees that there is enough space
// for a connection ID of any size.
if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
return Err(Error::InvalidPacket);
}
let dcid = PublicPacket::opt(dcid_decoder.decode_cid(&mut decoder))?.into();
let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?.into();
if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
return Err(Error::InvalidPacket);
}
let header_len = decoder.offset();

return Ok((
PublicPacket {
Self {
packet_type: PacketType::Short,
dcid,
scid: None,
Expand All @@ -651,14 +649,14 @@ impl<'a> PublicPacket<'a> {
}

// Generic long header.
let version = PublicPacket::opt(decoder.decode_uint())?;
let dcid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into();
let scid = ConnectionIdRef::from(PublicPacket::opt(decoder.decode_vec(1))?).into();
let version = Self::opt(decoder.decode_uint())?;
let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?).into();
let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?).into();

// Version negotiation.
if version == 0 {
return Ok((
PublicPacket {
Self {
packet_type: PacketType::VersionNegotiation,
dcid,
scid: Some(scid),
Expand All @@ -674,7 +672,7 @@ impl<'a> PublicPacket<'a> {
// Check that this is a long header from a supported version.
let Ok(version) = Version::try_from(version) else {
return Ok((
PublicPacket {
Self {
packet_type: PacketType::OtherVersion,
dcid,
scid: Some(scid),
Expand All @@ -698,7 +696,7 @@ impl<'a> PublicPacket<'a> {
let end = data.len() - decoder.remaining();
let (data, remainder) = data.split_at_mut(end);
Ok((
PublicPacket {
Self {
packet_type,
dcid,
scid: Some(scid),
Expand Down

0 comments on commit f2c3558

Please sign in to comment.