From 536c5bf7a9a91c8e19221fec22edac219facd6c7 Mon Sep 17 00:00:00 2001
From: behzad nouri
Date: Thu, 30 Jan 2025 18:30:38 +0000
Subject: [PATCH] generates erasure codes in-place using mutable references
 into shreds' payload (#4609)

When making Merkle shreds from data, parity shards are first generated
externally in a vector of vectors:
https://github.com/anza-xyz/agave/blob/6ff4dee5f/ledger/src/shred/merkle.rs#L1325
and then copied into the coding shreds' payload:
https://github.com/anza-xyz/agave/blob/6ff4dee5f/ledger/src/shred/merkle.rs#L1346
There are also many intermediate vector allocations in the process.

This commit avoids the copy and minimizes allocations by first
initializing all data and coding shreds across all erasure batches in a
single vector. The erasure codes are then generated and populated
in-place using mutable references into the coding shreds' payload.
---
 ledger/src/shred.rs        |  10 +-
 ledger/src/shred/merkle.rs | 429 ++++++++++++++++++++-----------------
 2 files changed, 238 insertions(+), 201 deletions(-)

diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs
index bd8c4b242383e4..dbac96794c9ff4 100644
--- a/ledger/src/shred.rs
+++ b/ledger/src/shred.rs
@@ -1216,7 +1216,7 @@ pub(crate) fn make_merkle_shreds_from_entries(
         reed_solomon_cache,
         stats,
     )?;
-    Ok(shreds.into_iter().flatten().map(Shred::from).collect())
+    Ok(shreds.into_iter().map(Shred::from).collect())
 }
 
 // Accepts shreds in the slot range [root + 1, max_slot].
@@ -1385,6 +1385,7 @@ mod tests {
         super::*,
         assert_matches::assert_matches,
         bincode::serialized_size,
+        itertools::Itertools,
         rand::Rng,
         rand_chacha::{rand_core::SeedableRng, ChaChaRng},
         rayon::ThreadPoolBuilder,
@@ -1410,7 +1411,7 @@ mod tests {
         data_size: usize,
         chained: bool,
         is_last_in_slot: bool,
-    ) -> Result<Vec<Vec<merkle::Shred>>, Error> {
+    ) -> Result<Vec<merkle::Shred>, Error> {
         let thread_pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();
         let chained_merkle_root = chained.then(|| Hash::new_from_array(rng.gen()));
         let parent_offset = rng.gen_range(1..=u16::try_from(slot).unwrap_or(u16::MAX));
@@ -1567,8 +1568,8 @@ mod tests {
             is_last_in_slot,
         )
         .unwrap();
-        assert_eq!(shreds.len(), 1);
-        let shreds: Vec<_> = shreds.into_iter().flatten().map(Shred::from).collect();
+        let shreds: Vec<_> = shreds.into_iter().map(Shred::from).collect();
+        assert_eq!(shreds.iter().map(Shred::fec_set_index).dedup().count(), 1);
         assert_matches!(shreds[0].shred_type(), ShredType::Data);
 
         let parent_slot = shreds[0].parent().unwrap();
@@ -2279,7 +2280,6 @@ mod tests {
         )
         .unwrap()
         .into_iter()
-        .flatten()
         .map(Shred::from)
         .map(|shred| fill_retransmitter_signature(&mut rng, shred, chained, is_last_in_slot))
         .collect();
diff --git a/ledger/src/shred/merkle.rs b/ledger/src/shred/merkle.rs
index e10733b6781899..96a3581df69f4f 100644
--- a/ledger/src/shred/merkle.rs
+++ b/ledger/src/shred/merkle.rs
@@ -20,7 +20,7 @@ use {
     assert_matches::debug_assert_matches,
     itertools::{Either, Itertools},
     rayon::{prelude::*, ThreadPool},
-    reed_solomon_erasure::Error::{InvalidIndex, TooFewParityShards, TooFewShards},
+    reed_solomon_erasure::Error::{InvalidIndex, TooFewParityShards},
     solana_perf::packet::deserialize_from_with_limit,
     solana_sdk::{
         clock::Slot,
@@ -99,6 +99,11 @@ impl Shred {
     dispatch!(fn set_signature(&mut self, signature: Signature));
     dispatch!(fn signed_data(&self) -> Result<SignedData, Error>);
 
+    #[inline]
+    fn fec_set_index(&self) -> u32 {
+        self.common_header().fec_set_index
+    }
+
     #[inline]
     fn merkle_proof(&self) -> Result<impl Iterator<Item = &MerkleProofEntry>, Error> {
         match self {
@@ -150,10 +155,6 @@ impl Shred {
     dispatch!(fn merkle_root(&self) -> Result<Hash, Error>);
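// A minimal, self-contained sketch of the in-place parity generation that
// the commit message describes, assuming only the reed-solomon-erasure
// crate's public API; the shard count and shard size below are
// illustrative, not agave's actual values.
use reed_solomon_erasure::galois_8::ReedSolomon;

fn encode_parity_in_place() -> Result<(), reed_solomon_erasure::Error> {
    let (num_data, num_parity) = (4, 2);
    let rs = ReedSolomon::new(num_data, num_parity)?;
    // All shards live in one up-front vector; the parity buffers start
    // zeroed, like the zero-initialized coding shred payloads in the patch.
    let mut shards: Vec<Vec<u8>> = (0..num_data as u8)
        .map(|i| vec![i; 64])
        .chain(std::iter::repeat(vec![0u8; 64]).take(num_parity))
        .collect();
    // encode() reads the data shards and writes the parity shards in place
    // through mutable slices, with no intermediate vector of parity shards.
    let shards: Vec<&mut [u8]> = shards.iter_mut().map(Vec::as_mut_slice).collect();
    rs.encode(shards)
}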
     dispatch!(fn proof_size(&self) -> Result<u8, Error>);
 
-    fn fec_set_index(&self) -> u32 {
-        self.common_header().fec_set_index
-    }
-
     fn index(&self) -> u32 {
         self.common_header().index
     }
@@ -1085,21 +1086,59 @@ pub(super) fn make_shreds_from_data(
     next_code_index: u32,
     reed_solomon_cache: &ReedSolomonCache,
     stats: &mut ProcessShredsStats,
-) -> Result<Vec<Vec<Shred>>, Error> {
-    fn new_shred_data(
-        common_header: ShredCommonHeader,
+) -> Result<Vec<Shred>, Error> {
+    // Generates data shreds for the current erasure batch.
+    // Updates ShredCommonHeader.index for data shreds of the next batch.
+    fn make_shreds_data<'a>(
+        common_header: &'a mut ShredCommonHeader,
         mut data_header: DataShredHeader,
-        data: &[u8],
-    ) -> ShredData {
-        let size = ShredData::SIZE_OF_HEADERS + data.len();
-        let mut payload = vec![0u8; ShredData::SIZE_OF_PAYLOAD];
-        payload[ShredData::SIZE_OF_HEADERS..size].copy_from_slice(data);
-        data_header.size = size as u16;
-        ShredData {
-            common_header,
-            data_header,
-            payload: Payload::from(payload),
-        }
+        chunks: impl IntoIterator<Item = &'a [u8]> + 'a,
+    ) -> impl Iterator<Item = ShredData> + 'a {
+        debug_assert_matches!(common_header.shred_variant, ShredVariant::MerkleData { .. });
+        chunks.into_iter().map(move |chunk| {
+            debug_assert_matches!(common_header.shred_variant,
+                ShredVariant::MerkleData { proof_size, chained, resigned }
+                    if chunk.len() <= ShredData::capacity(proof_size, chained, resigned).unwrap()
+            );
+            let size = ShredData::SIZE_OF_HEADERS + chunk.len();
+            let mut payload = vec![0u8; ShredData::SIZE_OF_PAYLOAD];
+            payload[ShredData::SIZE_OF_HEADERS..size].copy_from_slice(chunk);
+            data_header.size = size as u16;
+            let shred = ShredData {
+                common_header: *common_header,
+                data_header,
+                payload: Payload::from(payload),
+            };
+            common_header.index += 1;
+            shred
+        })
+    }
+    // Generates coding shreds for the current erasure batch.
+    // Updates ShredCommonHeader.index for coding shreds of the next batch.
+    fn make_shreds_code(
+        common_header: &mut ShredCommonHeader,
+        num_data_shreds: usize,
+        is_last_in_slot: bool,
+    ) -> impl Iterator<Item = ShredCode> + '_ {
+        debug_assert_matches!(common_header.shred_variant, ShredVariant::MerkleCode { .. });
+        let erasure_batch_size = shredder::get_erasure_batch_size(num_data_shreds, is_last_in_slot);
+        let num_coding_shreds = erasure_batch_size - num_data_shreds;
+        let mut coding_header = CodingShredHeader {
+            num_data_shreds: num_data_shreds as u16,
+            num_coding_shreds: num_coding_shreds as u16,
+            position: 0,
+        };
+        std::iter::repeat_with(move || {
+            let shred = ShredCode {
+                common_header: *common_header,
+                coding_header,
+                payload: Payload::from(vec![0u8; ShredCode::SIZE_OF_PAYLOAD]),
+            };
+            common_header.index += 1;
+            coding_header.position += 1;
+            shred
+        })
+        .take(num_coding_shreds)
     }
     let now = Instant::now();
     let chained = chained_merkle_root.is_some();
@@ -1109,7 +1148,8 @@ pub(super) fn make_shreds_from_data(
     let proof_size = get_proof_size(erasure_batch_size);
     let data_buffer_size = ShredData::capacity(proof_size, chained, resigned)?;
     let chunk_size = DATA_SHREDS_PER_FEC_BLOCK * data_buffer_size;
-    let mut common_header = ShredCommonHeader {
+    // Common header for the data shreds.
+    let mut common_header_data = ShredCommonHeader {
         signature: Signature::default(),
         shred_variant: ShredVariant::MerkleData {
             proof_size,
@@ -1121,6 +1161,16 @@ pub(super) fn make_shreds_from_data(
         version: shred_version,
         fec_set_index: next_shred_index,
     };
+    // Common header for the coding shreds.
+    let mut common_header_code = ShredCommonHeader {
+        shred_variant: ShredVariant::MerkleCode {
+            proof_size,
+            chained,
+            resigned,
+        },
+        index: next_code_index,
+        ..common_header_data
+    };
     let data_header = {
         let parent_offset = slot
             .checked_sub(parent_slot)
@@ -1133,17 +1183,33 @@ pub(super) fn make_shreds_from_data(
             size: 0u16,
         }
     };
+    let mut shreds = {
+        let capacity = 2 * DATA_SHREDS_PER_FEC_BLOCK * data.len().div_ceil(chunk_size);
+        Vec::<Shred>::with_capacity(capacity)
+    };
     // Split the data into erasure batches and initialize
-    // data shreds from chunks of each batch.
-    let mut shreds = Vec::<ShredData>::new();
+    // data and coding shreds for each batch.
     while data.len() >= 2 * chunk_size || data.len() == chunk_size {
         let (chunk, rest) = data.split_at(chunk_size);
-        common_header.fec_set_index = common_header.index;
-        for shred in chunk.chunks(data_buffer_size) {
-            let shred = new_shred_data(common_header, data_header, shred);
-            shreds.push(shred);
-            common_header.index += 1;
-        }
+        debug_assert_eq!(chunk.len(), DATA_SHREDS_PER_FEC_BLOCK * data_buffer_size);
+        common_header_data.fec_set_index = common_header_data.index;
+        common_header_code.fec_set_index = common_header_data.fec_set_index;
+        shreds.extend(
+            make_shreds_data(
+                &mut common_header_data,
+                data_header,
+                chunk.chunks(data_buffer_size),
+            )
+            .map(Shred::ShredData),
+        );
+        shreds.extend(
+            make_shreds_code(
+                &mut common_header_code,
+                DATA_SHREDS_PER_FEC_BLOCK,          // num_data_shreds
+                is_last_in_slot && rest.is_empty(), // is_last_in_slot
+            )
+            .map(Shred::ShredCode),
+        );
         data = rest;
     }
     // If shreds.is_empty() then the data argument was empty. In that case we
@@ -1174,29 +1240,41 @@ pub(super) fn make_shreds_from_data(
             ))
         })
         .ok_or(Error::UnknownProofSize)?;
-    common_header.shred_variant = ShredVariant::MerkleData {
+    common_header_data.shred_variant = ShredVariant::MerkleData {
         proof_size,
         chained,
         resigned,
     };
-    common_header.fec_set_index = common_header.index;
-    for shred in data
-        .chunks(data_buffer_size)
-        .chain(std::iter::repeat(&[][..]))
-        .take(num_data_shreds)
-    {
-        let shred = new_shred_data(common_header, data_header, shred);
-        shreds.push(shred);
-        common_header.index += 1;
-    }
-    if let Some(shred) = shreds.last() {
+    common_header_code.shred_variant = ShredVariant::MerkleCode {
+        proof_size,
+        chained,
+        resigned,
+    };
+    common_header_data.fec_set_index = common_header_data.index;
+    common_header_code.fec_set_index = common_header_data.fec_set_index;
+    shreds.extend({
+        let chunks = data
+            .chunks(data_buffer_size)
+            .chain(std::iter::repeat(&[][..])) // possible padding
+            .take(num_data_shreds);
+        make_shreds_data(&mut common_header_data, data_header, chunks).map(Shred::ShredData)
+    });
+    if let Some(Shred::ShredData(shred)) = shreds.last() {
         stats.data_buffer_residual += data_buffer_size - shred.data()?.len();
     }
+    shreds.extend(
+        make_shreds_code(&mut common_header_code, num_data_shreds, is_last_in_slot)
+            .map(Shred::ShredCode),
+    );
     }
     // Only the trailing data shreds may have residual data buffer.
     debug_assert!(shreds
         .iter()
         .rev()
+        .filter_map(|shred| match shred {
+            Shred::ShredCode(_) => None,
+            Shred::ShredData(shred) => Some(shred),
+        })
         .skip_while(|shred| is_last_in_slot && shred.data().unwrap().is_empty())
         .skip(1)
         .all(|shred| {
@@ -1204,189 +1282,149 @@ pub(super) fn make_shreds_from_data(
             let capacity = ShredData::capacity(proof_size, chained, resigned).unwrap();
             shred.data().unwrap().len() == capacity
         }));
-    // Adjust flags for the very last shred.
-    if let Some(shred) = shreds.last_mut() {
+    // Adjust flags for the very last data shred.
+    if let Some(Shred::ShredData(shred)) = shreds
+        .iter_mut()
+        .rev()
+        .find(|shred| matches!(shred, Shred::ShredData(_)))
+    {
         shred.data_header.flags |= if is_last_in_slot {
             ShredFlags::LAST_SHRED_IN_SLOT // also implies DATA_COMPLETE_SHRED
         } else {
             ShredFlags::DATA_COMPLETE_SHRED
         };
+        let num_data_shreds = shred.common_header.index + 1 - next_shred_index;
+        stats.record_num_data_shreds(num_data_shreds as usize);
     }
-    // Write common and data headers into data shreds' payload buffer.
-    thread_pool.install(|| {
-        shreds.par_iter_mut().try_for_each(|shred| {
-            bincode::serialize_into(
-                &mut shred.payload[..],
-                &(&shred.common_header, &shred.data_header),
-            )
-        })
-    })?;
     stats.gen_data_elapsed += now.elapsed().as_micros() as u64;
-    stats.record_num_data_shreds(shreds.len());
     let now = Instant::now();
     // Group shreds by their respective erasure-batch.
-    let shreds: Vec<Vec<ShredData>> = shreds
-        .into_iter()
-        .group_by(|shred| shred.common_header.fec_set_index)
-        .into_iter()
-        .map(|(_, shreds)| shreds.collect())
+    let batches: Vec<&mut [Shred]> = shreds
+        .chunk_by_mut(|a, b| a.fec_set_index() == b.fec_set_index())
         .collect();
-    // Obtain the shred index for the first coding shred of each batch.
-    let next_code_index: Vec<_> = shreds
-        .iter()
-        .scan(next_code_index, |next_code_index, chunk| {
-            let out = Some(*next_code_index);
-            let num_data_shreds = chunk.len();
-            let erasure_batch_size =
-                shredder::get_erasure_batch_size(num_data_shreds, is_last_in_slot);
-            let num_coding_shreds = erasure_batch_size - num_data_shreds;
-            *next_code_index += num_coding_shreds as u32;
-            out
-        })
-        .collect();
-    // Generate coding shreds, populate merkle proof
-    // for all shreds and attach signature.
-    let shreds: Result<Vec<_>, Error> = if let Some(chained_merkle_root) = chained_merkle_root {
-        shreds
-            .into_iter()
-            .zip(next_code_index)
-            .scan(
-                chained_merkle_root,
-                |chained_merkle_root, (shreds, next_code_index)| {
-                    Some(
-                        make_erasure_batch(
-                            keypair,
-                            shreds,
-                            Some(*chained_merkle_root),
-                            next_code_index,
-                            is_last_in_slot,
-                            reed_solomon_cache,
-                        )
-                        .map(|(merkle_root, shreds)| {
-                            *chained_merkle_root = merkle_root;
-                            shreds
-                        }),
-                    )
-                },
-            )
-            .collect()
-    } else if shreds.len() <= 1 {
-        shreds
-            .into_iter()
-            .zip(next_code_index)
-            .map(|(shreds, next_code_index)| {
-                make_erasure_batch(
-                    keypair,
-                    shreds,
-                    None, // chained_merkle_root
-                    next_code_index,
-                    is_last_in_slot,
-                    reed_solomon_cache,
-                )
-                .map(|(_merkle_root, shreds)| shreds)
-            })
-            .collect()
+    if let Some(chained_merkle_root) = chained_merkle_root {
+        // We have to process erasure batches serially because the Merkle tree
+        // (and so the signature) cannot be computed without the Merkle root of
+        // the previous erasure batch.
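// A schematic of this serial constraint, using hypothetical stand-in types
// rather than the agave API: each batch's signed Merkle root is an input to
// the next batch, so the fold threads the root through one batch at a time
// and cannot be parallelized.
fn chain_batches<Batch, E>(
    batches: Vec<Batch>,
    first_root: [u8; 32],
    mut finish: impl FnMut(Batch, [u8; 32]) -> Result<[u8; 32], E>,
) -> Result<[u8; 32], E> {
    batches
        .into_iter()
        .try_fold(first_root, |root, batch| finish(batch, root))
}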
+        batches
+            .into_iter()
+            .try_fold(chained_merkle_root, |chained_merkle_root, batch| {
+                finish_erasure_batch(
+                    Some(thread_pool),
+                    keypair,
+                    batch,
+                    Some(chained_merkle_root),
+                    reed_solomon_cache,
+                )
+            })?;
+    } else if batches.len() <= 1 {
+        for batch in batches {
+            finish_erasure_batch(
+                Some(thread_pool),
+                keypair,
+                batch,
+                None, // chained_merkle_root
+                reed_solomon_cache,
+            )?;
+        }
     } else {
         thread_pool.install(|| {
-            shreds
-                .into_par_iter()
-                .zip(next_code_index)
-                .map(|(shreds, next_code_index)| {
-                    make_erasure_batch(
-                        keypair,
-                        shreds,
-                        None, // chained_merkle_root
-                        next_code_index,
-                        is_last_in_slot,
-                        reed_solomon_cache,
-                    )
-                    .map(|(_merkle_root, shreds)| shreds)
-                })
-                .collect()
-        })
-    };
+            batches.into_par_iter().try_for_each(|batch| {
+                finish_erasure_batch(
+                    None, // thread_pool
+                    keypair,
+                    batch,
+                    None, // chained_merkle_root
+                    reed_solomon_cache,
+                )
+                .map(|_| ())
+            })
+        })?;
+    }
     stats.gen_coding_elapsed += now.elapsed().as_micros() as u64;
-    shreds
+    Ok(shreds)
 }
 
-// Generates coding shreds from data shreds, populates merke proof for all
-// shreds and attaches signature.
-fn make_erasure_batch(
+// Given shreds of the same erasure batch:
+//   - Writes common and {data,coding} headers into shreds' payload.
+//   - Fills in erasure code buffers in the coding shreds.
+//   - Sets the chained_merkle_root for each shred.
+//   - Computes the Merkle tree for the erasure batch.
+//   - Signs the root of the Merkle tree.
+//   - Populates Merkle proof for each shred and attaches the signature.
+// Returns the root of the Merkle tree (for chaining Merkle roots).
+fn finish_erasure_batch(
+    thread_pool: Option<&ThreadPool>,
     keypair: &Keypair,
-    mut shreds: Vec<ShredData>,
+    shreds: &mut [Shred],
     // The Merkle root of the previous erasure batch if chained.
     chained_merkle_root: Option<Hash>,
-    next_code_index: u32,
-    is_last_in_slot: bool,
     reed_solomon_cache: &ReedSolomonCache,
-) -> Result<(/*merkle root:*/ Hash, Vec<Shred>), Error> {
-    let num_data_shreds = shreds.len();
-    let chained = chained_merkle_root.is_some();
-    let resigned = chained && is_last_in_slot;
-    let erasure_batch_size = shredder::get_erasure_batch_size(num_data_shreds, is_last_in_slot);
-    let num_coding_shreds = erasure_batch_size - num_data_shreds;
-    let proof_size = get_proof_size(erasure_batch_size);
-    debug_assert!(shreds.iter().all(|shred| shred.common_header.shred_variant
-        == ShredVariant::MerkleData {
-            proof_size,
-            chained,
-            resigned
-        }));
-    let mut common_header = match shreds.first() {
-        None => return Err(Error::from(TooFewShards)),
-        Some(shred) => shred.common_header,
-    };
-    if let Some(hash) = chained_merkle_root {
-        for shred in &mut shreds {
-            shred.set_chained_merkle_root(&hash)?;
+) -> Result<Hash, Error> {
+    debug_assert_eq!(shreds.iter().map(Shred::fec_set_index).dedup().count(), 1);
+    // Write common and {data,coding} headers into shreds' payload.
+    fn write_headers(shred: &mut Shred) -> Result<(), bincode::Error> {
+        match shred {
+            Shred::ShredCode(shred) => bincode::serialize_into(
+                &mut shred.payload[..],
+                &(&shred.common_header, &shred.coding_header),
+            ),
+            Shred::ShredData(shred) => bincode::serialize_into(
+                &mut shred.payload[..],
+                &(&shred.common_header, &shred.data_header),
+            ),
         }
     }
-    // Generate erasure codings for encoded shard of data shreds.
-    let data: Vec<_> = shreds
-        .iter()
-        .map(ShredData::erasure_shard_as_slice)
-        .collect::<Result<_, Error>>()?;
-    // Shreds should have erasure encoded shard of the same length.
-    debug_assert_eq!(data.iter().map(|shard| shard.len()).dedup().count(), 1);
-    let mut parity = vec![vec![0u8; data[0].len()]; num_coding_shreds];
+    match thread_pool {
+        None => shreds.iter_mut().try_for_each(write_headers),
+        Some(thread_pool) => {
+            thread_pool.install(|| shreds.par_iter_mut().try_for_each(write_headers))
+        }
+    }?;
+    // Fill in erasure code buffers in the coding shreds.
+    let CodingShredHeader {
+        num_data_shreds,
+        num_coding_shreds,
+        ..
+    } = {
+        // Last shred in the erasure batch should be a coding shred.
+        let Some(Shred::ShredCode(shred)) = shreds.last() else {
+            return Err(Error::from(TooFewParityShards));
+        };
+        shred.coding_header
+    };
+    let num_data_shreds = usize::from(num_data_shreds);
+    let num_coding_shreds = usize::from(num_coding_shreds);
+    let erasure_batch_size = num_data_shreds + num_coding_shreds;
     reed_solomon_cache
         .get(num_data_shreds, num_coding_shreds)?
-        .encode_sep(&data, &mut parity[..])?;
-    let mut shreds: Vec<_> = shreds.into_iter().map(Shred::ShredData).collect();
-    // Initialize coding shreds from erasure coding shards.
-    common_header.index = next_code_index;
-    common_header.shred_variant = ShredVariant::MerkleCode {
-        proof_size,
-        chained,
-        resigned,
-    };
-    let mut coding_header = CodingShredHeader {
-        num_data_shreds: num_data_shreds as u16,
-        num_coding_shreds: num_coding_shreds as u16,
-        position: 0,
-    };
-    for code in parity {
-        let mut payload = vec![0u8; ShredCode::SIZE_OF_PAYLOAD];
-        let mut cursor = Cursor::new(&mut payload[..]);
-        bincode::serialize_into(&mut cursor, &(&common_header, &coding_header))?;
-        cursor.write_all(&code)?;
-        if let Some(chained_merkle_root) = chained_merkle_root {
-            cursor.write_all(chained_merkle_root.as_ref())?;
+        .encode(
+            shreds
+                .iter_mut()
+                .map(Shred::erasure_shard_as_slice_mut)
+                .collect::<Result<Vec<_>, _>>()?,
+        )?;
+    // Set the chained_merkle_root for each shred.
+    if let Some(chained_merkle_root) = chained_merkle_root {
+        for shred in shreds.iter_mut() {
+            shred.set_chained_merkle_root(&chained_merkle_root)?;
         }
-        let shred = ShredCode {
-            common_header,
-            coding_header,
-            payload: Payload::from(payload),
-        };
-        shreds.push(Shred::ShredCode(shred));
-        common_header.index += 1;
-        coding_header.position += 1;
     }
-    // Compute Merkle tree for the erasure batch.
-    let nodes = shreds.iter().map(Shred::merkle_node);
-    let tree = make_merkle_tree(nodes)?;
-    // Sign root of Merkle tree.
-    let root = tree.last().ok_or(Error::InvalidMerkleProof)?;
+    // Compute the Merkle tree for the erasure batch.
+    let tree = match thread_pool {
+        None => {
+            let nodes = shreds.iter().map(Shred::merkle_node);
+            make_merkle_tree(nodes)
+        }
+        Some(thread_pool) => make_merkle_tree(thread_pool.install(|| {
+            shreds
+                .par_iter()
+                .map(Shred::merkle_node)
+                .collect::<Vec<_>>()
+        })),
+    }?;
+    // Sign the root of the Merkle tree.
+    let root = tree.last().copied().ok_or(Error::InvalidMerkleProof)?;
     let signature = keypair.sign_message(root.as_ref());
     // Populate merkle proof for all shreds and attach signature.
     for (index, shred) in shreds.iter_mut().enumerate() {
@@ -1401,7 +1439,7 @@ fn make_erasure_batch(
             &Shred::from_payload(shred).unwrap()
         });
     }
-    Ok((*root, shreds))
+    Ok(root)
 }
 
 #[cfg(test)]
@@ -1874,7 +1912,6 @@ mod test {
             &mut ProcessShredsStats::default(),
         )
         .unwrap();
-        let shreds: Vec<_> = shreds.into_iter().flatten().collect();
         let data_shreds: Vec<_> = shreds
             .iter()
             .filter_map(|shred| match shred {
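A note on the grouping step above: slice::chunk_by_mut, stable since Rust
1.77, yields disjoint mutable sub-slices over runs of equal-keyed elements,
which is why grouping shreds by fec_set_index no longer collects each
erasure batch into its own vector. A minimal sketch of the pattern, using
illustrative stand-in data rather than real shreds:

fn main() {
    // (fec_set_index, shred_index) pairs; runs of equal first elements play
    // the role of erasure batches.
    let mut shreds = [(0u32, 0u32), (0, 1), (32, 32), (32, 33), (32, 34)];
    let batches: Vec<&mut [(u32, u32)]> = shreds
        .chunk_by_mut(|a, b| a.0 == b.0)
        .collect();
    assert_eq!(batches.len(), 2);
    // Each mutable sub-slice can then be finished independently (or handed
    // to rayon), with no per-batch Vec allocation.
}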