diff --git a/ledger/src/shred.rs b/ledger/src/shred.rs
index 69c45d6cccc217..93d1de65e96955 100644
--- a/ledger/src/shred.rs
+++ b/ledger/src/shred.rs
@@ -1167,7 +1167,7 @@ pub(crate) fn make_merkle_shreds_from_entries(
         reed_solomon_cache,
         stats,
     )?;
-    Ok(shreds.into_iter().flatten().map(Shred::from).collect())
+    Ok(shreds.into_iter().map(Shred::from).collect())
 }
 
 // Accepts shreds in the slot range [root + 1, max_slot].
@@ -1336,6 +1336,7 @@ mod tests {
         super::*,
         assert_matches::assert_matches,
         bincode::serialized_size,
+        itertools::Itertools,
         rand::Rng,
         rand_chacha::{rand_core::SeedableRng, ChaChaRng},
         rayon::ThreadPoolBuilder,
@@ -1356,7 +1357,7 @@ mod tests {
         data_size: usize,
         chained: bool,
         is_last_in_slot: bool,
-    ) -> Result<Vec<Vec<merkle::Shred>>, Error> {
+    ) -> Result<Vec<merkle::Shred>, Error> {
         let thread_pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();
         let chained_merkle_root = chained.then(|| Hash::new_from_array(rng.gen()));
         let parent_offset = rng.gen_range(1..=u16::try_from(slot).unwrap_or(u16::MAX));
@@ -1504,8 +1505,8 @@ mod tests {
             is_last_in_slot,
         )
         .unwrap();
-        assert_eq!(shreds.len(), 1);
-        let shreds: Vec<_> = shreds.into_iter().flatten().map(Shred::from).collect();
+        let shreds: Vec<_> = shreds.into_iter().map(Shred::from).collect();
+        assert_eq!(shreds.iter().map(Shred::fec_set_index).dedup().count(), 1);
         assert_matches!(shreds[0].shred_type(), ShredType::Data);
         let parent_slot = shreds[0].parent().unwrap();
@@ -2216,7 +2217,6 @@ mod tests {
         )
         .unwrap()
         .into_iter()
-        .flatten()
         .map(Shred::from)
         .map(|shred| fill_retransmitter_signature(&mut rng, shred, chained, is_last_in_slot))
         .collect();
diff --git a/ledger/src/shred/merkle.rs b/ledger/src/shred/merkle.rs
index f9b35e17f62c0b..c8277a9e5e81b8 100644
--- a/ledger/src/shred/merkle.rs
+++ b/ledger/src/shred/merkle.rs
@@ -18,7 +18,7 @@ use {
     assert_matches::debug_assert_matches,
     itertools::{Either, Itertools},
     rayon::{prelude::*, ThreadPool},
-    reed_solomon_erasure::Error::{InvalidIndex, TooFewParityShards, TooFewShards},
+    reed_solomon_erasure::Error::{InvalidIndex, TooFewParityShards},
     solana_perf::packet::deserialize_from_with_limit,
     solana_sdk::{
         clock::Slot,
@@ -97,6 +97,11 @@ impl Shred {
     dispatch!(fn set_signature(&mut self, signature: Signature));
     dispatch!(fn signed_data(&self) -> Result<Hash, Error>);
 
+    #[inline]
+    fn fec_set_index(&self) -> u32 {
+        self.common_header().fec_set_index
+    }
+
     #[inline]
     fn merkle_proof(&self) -> Result<impl Iterator<Item = &MerkleProofEntry>, Error> {
         match self {
@@ -145,10 +150,6 @@ impl Shred {
     dispatch!(fn merkle_root(&self) -> Result<Hash, Error>);
     dispatch!(fn proof_size(&self) -> Result<u8, Error>);
 
-    fn fec_set_index(&self) -> u32 {
-        self.common_header().fec_set_index
-    }
-
     fn index(&self) -> u32 {
         self.common_header().index
     }
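Note for callers: `make_merkle_shreds_from_entries` now returns one flat, batch-ordered `Vec<Shred>` instead of `Vec<Vec<Shred>>`. Shreds of one erasure batch stay adjacent and share a `fec_set_index`, so batch structure is recoverable on demand. A minimal self-contained sketch of that invariant, using stand-in `u32` keys instead of real shreds and the same `Itertools::dedup` idiom as the updated test:

```rust
use itertools::Itertools;

fn main() {
    // Flat, batch-ordered sequence of fec_set_index values: three batches.
    let fec_set_indices: Vec<u32> = vec![0, 0, 0, 32, 32, 64, 64, 64];
    // Counting distinct *consecutive* values counts erasure batches,
    // mirroring the dedup().count() assertion in the updated test.
    let num_batches = fec_set_indices.iter().dedup().count();
    assert_eq!(num_batches, 3);
}
```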
@@ -1055,21 +1056,59 @@ pub(super) fn make_shreds_from_data(
     next_code_index: u32,
     reed_solomon_cache: &ReedSolomonCache,
     stats: &mut ProcessShredsStats,
-) -> Result<Vec<Vec<Shred>>, Error> {
-    fn new_shred_data(
-        common_header: ShredCommonHeader,
+) -> Result<Vec<Shred>, Error> {
+    // Generates data shreds for the current erasure batch.
+    // Updates ShredCommonHeader.index for data shreds of the next batch.
+    fn make_shreds_data<'a>(
+        common_header: &'a mut ShredCommonHeader,
         mut data_header: DataShredHeader,
-        data: &[u8],
-    ) -> ShredData {
-        let size = ShredData::SIZE_OF_HEADERS + data.len();
-        let mut payload = vec![0u8; ShredData::SIZE_OF_PAYLOAD];
-        payload[ShredData::SIZE_OF_HEADERS..size].copy_from_slice(data);
-        data_header.size = size as u16;
-        ShredData {
-            common_header,
-            data_header,
-            payload,
-        }
+        chunks: impl IntoIterator<Item = &'a [u8]> + 'a,
+    ) -> impl Iterator<Item = ShredData> + 'a {
+        debug_assert_matches!(common_header.shred_variant, ShredVariant::MerkleData { .. });
+        chunks.into_iter().map(move |chunk| {
+            debug_assert_matches!(common_header.shred_variant,
+                ShredVariant::MerkleData { proof_size, chained, resigned }
+                if chunk.len() <= ShredData::capacity(proof_size, chained, resigned).unwrap()
+            );
+            let size = ShredData::SIZE_OF_HEADERS + chunk.len();
+            let mut payload = vec![0u8; ShredData::SIZE_OF_PAYLOAD];
+            payload[ShredData::SIZE_OF_HEADERS..size].copy_from_slice(chunk);
+            data_header.size = size as u16;
+            let shred = ShredData {
+                common_header: *common_header,
+                data_header,
+                payload,
+            };
+            common_header.index += 1;
+            shred
+        })
+    }
+    // Generates coding shreds for the current erasure batch.
+    // Updates ShredCommonHeader.index for coding shreds of the next batch.
+    fn make_shreds_code(
+        common_header: &mut ShredCommonHeader,
+        num_data_shreds: usize,
+        is_last_in_slot: bool,
+    ) -> impl Iterator<Item = ShredCode> + '_ {
+        debug_assert_matches!(common_header.shred_variant, ShredVariant::MerkleCode { .. });
+        let erasure_batch_size = shredder::get_erasure_batch_size(num_data_shreds, is_last_in_slot);
+        let num_coding_shreds = erasure_batch_size - num_data_shreds;
+        let mut coding_header = CodingShredHeader {
+            num_data_shreds: num_data_shreds as u16,
+            num_coding_shreds: num_coding_shreds as u16,
+            position: 0,
+        };
+        std::iter::repeat_with(move || {
+            let shred = ShredCode {
+                common_header: *common_header,
+                coding_header,
+                payload: vec![0u8; ShredCode::SIZE_OF_PAYLOAD],
+            };
+            common_header.index += 1;
+            coding_header.position += 1;
+            shred
+        })
+        .take(num_coding_shreds)
     }
     let now = Instant::now();
     let chained = chained_merkle_root.is_some();
@@ -1079,7 +1118,8 @@ pub(super) fn make_shreds_from_data(
     let proof_size = get_proof_size(erasure_batch_size);
     let data_buffer_size = ShredData::capacity(proof_size, chained, resigned)?;
     let chunk_size = DATA_SHREDS_PER_FEC_BLOCK * data_buffer_size;
-    let mut common_header = ShredCommonHeader {
+    // Common header for the data shreds.
+    let mut common_header_data = ShredCommonHeader {
         signature: Signature::default(),
         shred_variant: ShredVariant::MerkleData {
             proof_size,
@@ -1091,6 +1131,16 @@
         version: shred_version,
         fec_set_index: next_shred_index,
     };
+    // Common header for the coding shreds.
+    let mut common_header_code = ShredCommonHeader {
+        shred_variant: ShredVariant::MerkleCode {
+            proof_size,
+            chained,
+            resigned,
+        },
+        index: next_code_index,
+        ..common_header_data
+    };
     let data_header = {
         let parent_offset = slot
             .checked_sub(parent_slot)
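Both helpers above hand out consecutive shred indices by mutating the captured header between yields, so the next batch continues numbering where the previous one stopped. A self-contained sketch of that `repeat_with` pattern, with a plain integer standing in for `common_header.index`:

```rust
fn main() {
    let mut index: u32 = 7; // stand-in for common_header.index
    let indices: Vec<u32> = std::iter::repeat_with(|| {
        let out = index;
        index += 1; // advance the captured state for the next yield
        out
    })
    .take(3)
    .collect();
    assert_eq!(indices, [7, 8, 9]);
    assert_eq!(index, 10); // the next batch starts at 10
}
```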
@@ -1107,17 +1157,33 @@
             size: 0u16,
         }
     };
+    let mut shreds = {
+        let capacity = 2 * DATA_SHREDS_PER_FEC_BLOCK * data.len().div_ceil(chunk_size);
+        Vec::<Shred>::with_capacity(capacity)
+    };
     // Split the data into erasure batches and initialize
-    // data shreds from chunks of each batch.
-    let mut shreds = Vec::<ShredData>::new();
+    // data and coding shreds for each batch.
     while data.len() >= 2 * chunk_size || data.len() == chunk_size {
         let (chunk, rest) = data.split_at(chunk_size);
-        common_header.fec_set_index = common_header.index;
-        for shred in chunk.chunks(data_buffer_size) {
-            let shred = new_shred_data(common_header, data_header, shred);
-            shreds.push(shred);
-            common_header.index += 1;
-        }
+        debug_assert_eq!(chunk.len(), DATA_SHREDS_PER_FEC_BLOCK * data_buffer_size);
+        common_header_data.fec_set_index = common_header_data.index;
+        common_header_code.fec_set_index = common_header_data.fec_set_index;
+        shreds.extend(
+            make_shreds_data(
+                &mut common_header_data,
+                data_header,
+                chunk.chunks(data_buffer_size),
+            )
+            .map(Shred::ShredData),
+        );
+        shreds.extend(
+            make_shreds_code(
+                &mut common_header_code,
+                DATA_SHREDS_PER_FEC_BLOCK,          // num_data_shreds
+                is_last_in_slot && rest.is_empty(), // is_last_in_slot
+            )
+            .map(Shred::ShredCode),
+        );
         data = rest;
     }
     // If shreds.is_empty() then the data argument was empty. In that case we
@@ -1148,29 +1214,41 @@
             ))
         })
         .ok_or(Error::UnknownProofSize)?;
-        common_header.shred_variant = ShredVariant::MerkleData {
+        common_header_data.shred_variant = ShredVariant::MerkleData {
             proof_size,
             chained,
             resigned,
         };
-        common_header.fec_set_index = common_header.index;
-        for shred in data
-            .chunks(data_buffer_size)
-            .chain(std::iter::repeat(&[][..]))
-            .take(num_data_shreds)
-        {
-            let shred = new_shred_data(common_header, data_header, shred);
-            shreds.push(shred);
-            common_header.index += 1;
-        }
-        if let Some(shred) = shreds.last() {
+        common_header_code.shred_variant = ShredVariant::MerkleCode {
+            proof_size,
+            chained,
+            resigned,
+        };
+        common_header_data.fec_set_index = common_header_data.index;
+        common_header_code.fec_set_index = common_header_data.fec_set_index;
+        shreds.extend({
+            let chunks = data
+                .chunks(data_buffer_size)
+                .chain(std::iter::repeat(&[][..])) // possible padding
+                .take(num_data_shreds);
+            make_shreds_data(&mut common_header_data, data_header, chunks).map(Shred::ShredData)
+        });
+        if let Some(Shred::ShredData(shred)) = shreds.last() {
             stats.data_buffer_residual += data_buffer_size - shred.data()?.len();
         }
+        shreds.extend(
+            make_shreds_code(&mut common_header_code, num_data_shreds, is_last_in_slot)
+                .map(Shred::ShredCode),
+        );
     }
     // Only the trailing data shreds may have residual data buffer.
     debug_assert!(shreds
         .iter()
         .rev()
+        .filter_map(|shred| match shred {
+            Shred::ShredCode(_) => None,
+            Shred::ShredData(shred) => Some(shred),
+        })
         .skip_while(|shred| is_last_in_slot && shred.data().unwrap().is_empty())
         .skip(1)
         .all(|shred| {
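The loop condition above peels off a full batch only while doing so cannot strand an awkward remainder: after the loop exits, the tail is either empty or strictly between zero and two full chunks, and never exactly one chunk (the loop itself consumes that case), so the padded tail-batch path handles everything else. A small self-contained sketch of that splitting rule, with a toy chunk size:

```rust
fn split_full_batches(mut data: &[u8], chunk_size: usize) -> (Vec<&[u8]>, &[u8]) {
    let mut batches = Vec::new();
    // Same condition as the while-loop above: take a full chunk only if
    // at least two chunks remain, or exactly one chunk remains.
    while data.len() >= 2 * chunk_size || data.len() == chunk_size {
        let (chunk, rest) = data.split_at(chunk_size);
        batches.push(chunk);
        data = rest;
    }
    (batches, data)
}

fn main() {
    let bytes = vec![0u8; 25];
    let (batches, tail) = split_full_batches(&bytes, 10);
    assert_eq!(batches.len(), 1); // 25 >= 20 once; 15 remains
    assert_eq!(tail.len(), 15); // 0 < 15 < 20 and != 10: goes to the padded tail batch
}
```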
@@ -1178,189 +1256,149 @@
             let capacity = ShredData::capacity(proof_size, chained, resigned).unwrap();
             shred.data().unwrap().len() == capacity
         }));
-    // Adjust flags for the very last shred.
-    if let Some(shred) = shreds.last_mut() {
+    // Adjust flags for the very last data shred.
+    if let Some(Shred::ShredData(shred)) = shreds
+        .iter_mut()
+        .rev()
+        .find(|shred| matches!(shred, Shred::ShredData(_)))
+    {
         shred.data_header.flags |= if is_last_in_slot {
             ShredFlags::LAST_SHRED_IN_SLOT // also implies DATA_COMPLETE_SHRED
         } else {
             ShredFlags::DATA_COMPLETE_SHRED
         };
+        let num_data_shreds = shred.common_header.index - next_shred_index;
+        stats.record_num_data_shreds(num_data_shreds as usize);
     }
-    // Write common and data headers into data shreds' payload buffer.
-    thread_pool.install(|| {
-        shreds.par_iter_mut().try_for_each(|shred| {
-            bincode::serialize_into(
-                &mut shred.payload[..],
-                &(&shred.common_header, &shred.data_header),
-            )
-        })
-    })?;
     stats.gen_data_elapsed += now.elapsed().as_micros() as u64;
-    stats.record_num_data_shreds(shreds.len());
     let now = Instant::now();
     // Group shreds by their respective erasure-batch.
-    let shreds: Vec<Vec<ShredData>> = shreds
-        .into_iter()
-        .group_by(|shred| shred.common_header.fec_set_index)
-        .into_iter()
-        .map(|(_, shreds)| shreds.collect())
+    let batches: Vec<&mut [Shred]> = shreds
+        .chunk_by_mut(|a, b| a.fec_set_index() == b.fec_set_index())
         .collect();
-    // Obtain the shred index for the first coding shred of each batch.
-    let next_code_index: Vec<_> = shreds
-        .iter()
-        .scan(next_code_index, |next_code_index, chunk| {
-            let out = Some(*next_code_index);
-            let num_data_shreds = chunk.len();
-            let erasure_batch_size =
-                shredder::get_erasure_batch_size(num_data_shreds, is_last_in_slot);
-            let num_coding_shreds = erasure_batch_size - num_data_shreds;
-            *next_code_index += num_coding_shreds as u32;
-            out
-        })
-        .collect();
-    // Generate coding shreds, populate merkle proof
-    // for all shreds and attach signature.
-    let shreds: Result<Vec<_>, Error> = if let Some(chained_merkle_root) = chained_merkle_root {
-        shreds
-            .into_iter()
-            .zip(next_code_index)
-            .scan(
-                chained_merkle_root,
-                |chained_merkle_root, (shreds, next_code_index)| {
-                    Some(
-                        make_erasure_batch(
-                            keypair,
-                            shreds,
-                            Some(*chained_merkle_root),
-                            next_code_index,
-                            is_last_in_slot,
-                            reed_solomon_cache,
-                        )
-                        .map(|(merkle_root, shreds)| {
-                            *chained_merkle_root = merkle_root;
-                            shreds
-                        }),
-                    )
-                },
-            )
-            .collect()
-    } else if shreds.len() <= 1 {
-        shreds
+    if let Some(chained_merkle_root) = chained_merkle_root {
+        // We have to process erasure batches serially because the Merkle tree
+        // (and so the signature) cannot be computed without the Merkle root of
+        // the previous erasure batch.
+        batches
             .into_iter()
-            .zip(next_code_index)
-            .map(|(shreds, next_code_index)| {
-                make_erasure_batch(
+            .try_fold(chained_merkle_root, |chained_merkle_root, batch| {
+                finish_erasure_batch(
+                    Some(thread_pool),
                     keypair,
-                    shreds,
-                    None, // chained_merkle_root
-                    next_code_index,
-                    is_last_in_slot,
+                    batch,
+                    Some(chained_merkle_root),
                     reed_solomon_cache,
                 )
-                .map(|(_merkle_root, shreds)| shreds)
-            })
-            .collect()
+            })?;
+    } else if batches.len() <= 1 {
+        for batch in batches {
+            finish_erasure_batch(
+                Some(thread_pool),
+                keypair,
+                batch,
+                None, // chained_merkle_root
+                reed_solomon_cache,
+            )?;
+        }
     } else {
         thread_pool.install(|| {
-            shreds
-                .into_par_iter()
-                .zip(next_code_index)
-                .map(|(shreds, next_code_index)| {
-                    make_erasure_batch(
-                        keypair,
-                        shreds,
-                        None, // chained_merkle_root
-                        next_code_index,
-                        is_last_in_slot,
-                        reed_solomon_cache,
-                    )
-                    .map(|(_merkle_root, shreds)| shreds)
-                })
-                .collect()
-        })
-    };
+            batches.into_par_iter().try_for_each(|batch| {
+                finish_erasure_batch(
+                    None, // thread_pool
+                    keypair,
+                    batch,
+                    None, // chained_merkle_root
+                    reed_solomon_cache,
+                )
+                .map(|_| ())
+            })
+        })?;
+    }
     stats.gen_coding_elapsed += now.elapsed().as_micros() as u64;
-    shreds
+    Ok(shreds)
 }
 
-// Generates coding shreds from data shreds, populates merke proof for all
-// shreds and attaches signature.
-fn make_erasure_batch(
+// Given shreds of the same erasure batch:
+// - Writes common and {data,coding} headers into shreds' payload.
+// - Fills in erasure code buffers in the coding shreds.
+// - Sets the chained_merkle_root for each shred.
+// - Computes the Merkle tree for the erasure batch.
+// - Signs the root of the Merkle tree.
+// - Populates Merkle proof for each shred and attaches the signature.
+// Returns the root of the Merkle tree (for chaining Merkle roots).
+fn finish_erasure_batch(
+    thread_pool: Option<&ThreadPool>,
     keypair: &Keypair,
-    mut shreds: Vec<ShredData>,
+    shreds: &mut [Shred],
     // The Merkle root of the previous erasure batch if chained.
     chained_merkle_root: Option<Hash>,
-    next_code_index: u32,
-    is_last_in_slot: bool,
     reed_solomon_cache: &ReedSolomonCache,
-) -> Result<(/*merkle root:*/ Hash, Vec<Shred>), Error> {
-    let num_data_shreds = shreds.len();
-    let chained = chained_merkle_root.is_some();
-    let resigned = chained && is_last_in_slot;
-    let erasure_batch_size = shredder::get_erasure_batch_size(num_data_shreds, is_last_in_slot);
-    let num_coding_shreds = erasure_batch_size - num_data_shreds;
-    let proof_size = get_proof_size(erasure_batch_size);
-    debug_assert!(shreds.iter().all(|shred| shred.common_header.shred_variant
-        == ShredVariant::MerkleData {
-            proof_size,
-            chained,
-            resigned
-        }));
-    let mut common_header = match shreds.first() {
-        None => return Err(Error::from(TooFewShards)),
-        Some(shred) => shred.common_header,
-    };
-    if let Some(hash) = chained_merkle_root {
-        for shred in &mut shreds {
-            shred.set_chained_merkle_root(&hash)?;
+) -> Result<Hash, Error> {
+    debug_assert_eq!(shreds.iter().map(Shred::fec_set_index).dedup().count(), 1);
+    // Write common and {data,coding} headers into shreds' payload.
+    fn write_headers(shred: &mut Shred) -> Result<(), bincode::Error> {
+        match shred {
+            Shred::ShredCode(shred) => bincode::serialize_into(
+                &mut shred.payload[..],
+                &(&shred.common_header, &shred.coding_header),
+            ),
+            Shred::ShredData(shred) => bincode::serialize_into(
+                &mut shred.payload[..],
+                &(&shred.common_header, &shred.data_header),
+            ),
         }
     }
-    // Generate erasure codings for encoded shard of data shreds.
-    let data: Vec<_> = shreds
-        .iter()
-        .map(ShredData::erasure_shard_as_slice)
-        .collect::<Result<_, _>>()?;
-    // Shreds should have erasure encoded shard of the same length.
-    debug_assert_eq!(data.iter().map(|shard| shard.len()).dedup().count(), 1);
-    let mut parity = vec![vec![0u8; data[0].len()]; num_coding_shreds];
+    match thread_pool {
+        None => shreds.iter_mut().try_for_each(write_headers),
+        Some(thread_pool) => {
+            thread_pool.install(|| shreds.par_iter_mut().try_for_each(write_headers))
+        }
+    }?;
+    // Fill in erasure code buffers in the coding shreds.
+    let CodingShredHeader {
+        num_data_shreds,
+        num_coding_shreds,
+        ..
+    } = {
+        // Last shred in the erasure batch should be a coding shred.
+        let Some(Shred::ShredCode(shred)) = shreds.last() else {
+            return Err(Error::from(TooFewParityShards));
+        };
+        shred.coding_header
+    };
+    let num_data_shreds = usize::from(num_data_shreds);
+    let num_coding_shreds = usize::from(num_coding_shreds);
+    let erasure_batch_size = num_data_shreds + num_coding_shreds;
     reed_solomon_cache
         .get(num_data_shreds, num_coding_shreds)?
-        .encode_sep(&data, &mut parity[..])?;
-    let mut shreds: Vec<_> = shreds.into_iter().map(Shred::ShredData).collect();
-    // Initialize coding shreds from erasure coding shards.
-    common_header.index = next_code_index;
-    common_header.shred_variant = ShredVariant::MerkleCode {
-        proof_size,
-        chained,
-        resigned,
-    };
-    let mut coding_header = CodingShredHeader {
-        num_data_shreds: num_data_shreds as u16,
-        num_coding_shreds: num_coding_shreds as u16,
-        position: 0,
-    };
-    for code in parity {
-        let mut payload = vec![0u8; ShredCode::SIZE_OF_PAYLOAD];
-        let mut cursor = Cursor::new(&mut payload[..]);
-        bincode::serialize_into(&mut cursor, &(&common_header, &coding_header))?;
-        cursor.write_all(&code)?;
-        if let Some(chained_merkle_root) = chained_merkle_root {
-            cursor.write_all(chained_merkle_root.as_ref())?;
+        .encode(
+            shreds
+                .iter_mut()
+                .map(Shred::erasure_shard_as_slice_mut)
+                .collect::<Result<Vec<_>, _>>()?,
+        )?;
+    // Set the chained_merkle_root for each shred.
+    if let Some(chained_merkle_root) = chained_merkle_root {
+        for shred in shreds.iter_mut() {
+            shred.set_chained_merkle_root(&chained_merkle_root)?;
         }
-        let shred = ShredCode {
-            common_header,
-            coding_header,
-            payload,
-        };
-        shreds.push(Shred::ShredCode(shred));
-        common_header.index += 1;
-        coding_header.position += 1;
     }
-    // Compute Merkle tree for the erasure batch.
-    let nodes = shreds.iter().map(Shred::merkle_node);
-    let tree = make_merkle_tree(nodes)?;
-    // Sign root of Merkle tree.
-    let root = tree.last().ok_or(Error::InvalidMerkleProof)?;
+    // Compute the Merkle tree for the erasure batch.
+    let tree = match thread_pool {
+        None => {
+            let nodes = shreds.iter().map(Shred::merkle_node);
+            make_merkle_tree(nodes)
+        }
+        Some(thread_pool) => make_merkle_tree(thread_pool.install(|| {
+            shreds
                .par_iter()
+                .map(Shred::merkle_node)
+                .collect::<Vec<_>>()
+        })),
+    }?;
+    // Sign the root of the Merkle tree.
+    let root = tree.last().copied().ok_or(Error::InvalidMerkleProof)?;
     let signature = keypair.sign_message(root.as_ref());
     // Populate merkle proof for all shreds and attach signature.
     for (index, shred) in shreds.iter_mut().enumerate() {
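`chunk_by_mut` (std, stable since Rust 1.77) yields maximal runs of adjacent elements whose keys match, which is why the flat vector must stay batch-ordered; the chained branch then threads each batch's Merkle root into the next with `try_fold`. A self-contained sketch of both idioms, using toy `(fec_set_index, payload)` pairs and a stand-in "root" instead of hashing and signing:

```rust
fn main() {
    let mut shreds: Vec<(u32, u64)> = vec![(0, 1), (0, 2), (32, 3), (32, 4)];
    // Group adjacent shreds of the same erasure batch, in place.
    let batches: Vec<&mut [(u32, u64)]> = shreds
        .chunk_by_mut(|a, b| a.0 == b.0)
        .collect();
    // Thread a running root through the batches, as the chained case does;
    // the real code hashes, signs, and returns the new Merkle root.
    let root = batches.into_iter().try_fold(0u64, |root, batch| {
        for (_, payload) in batch.iter_mut() {
            *payload ^= root; // stand-in for set_chained_merkle_root
        }
        // Stand-in for finish_erasure_batch returning the next root.
        batch.iter().map(|(_, payload)| *payload).sum::<u64>().checked_mul(31)
    });
    assert!(root.is_some());
}
```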
@@ -1375,7 +1413,7 @@ fn make_erasure_batch(
             &Shred::from_payload(shred).unwrap()
         });
     }
-    Ok((*root, shreds))
+    Ok(root)
 }
 
 #[cfg(test)]
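`finish_erasure_batch` now hands the Reed-Solomon encoder one mutable erasure-shard slice per shred, data and coding alike, so parity is written directly into the coding shreds' payloads instead of a temporary `parity` vector. A minimal sketch of that in-place `encode` call against the `reed_solomon_erasure` crate, with toy shard counts and sizes:

```rust
use reed_solomon_erasure::galois_8::ReedSolomon;

fn main() -> Result<(), reed_solomon_erasure::Error> {
    // 2 data shards followed by 1 parity shard, parity zero-initialized,
    // mirroring how the coding shreds' erasure buffers start out.
    let mut shards: Vec<Vec<u8>> = vec![vec![1, 2, 3], vec![4, 5, 6], vec![0; 3]];
    let rs = ReedSolomon::new(2, 1)?; // num_data_shreds, num_coding_shreds
    // encode reads the data shards and overwrites the parity shard in place.
    rs.encode(&mut shards)?;
    assert!(rs.verify(&shards)?);
    Ok(())
}
```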
@@ -1848,7 +1886,6 @@ mod test {
             &mut ProcessShredsStats::default(),
         )
         .unwrap();
-        let shreds: Vec<_> = shreds.into_iter().flatten().collect();
         let data_shreds: Vec<_> = shreds
             .iter()
             .filter_map(|shred| match shred {
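With the nested vectors gone, test code like the hunk above separates data and coding shreds by matching on the enum directly; `itertools::partition_map` with `Either` (both already imported in merkle.rs) expresses the same split in one pass. A self-contained sketch with a stand-in enum in place of the real `Shred`:

```rust
use itertools::{Either, Itertools};

#[derive(Debug)]
enum Shred {
    ShredCode(u32),
    ShredData(u32),
}

fn main() {
    let shreds = vec![Shred::ShredData(0), Shred::ShredData(1), Shred::ShredCode(2)];
    // One pass over the flat vector; no per-batch Vecs to flatten first.
    let (data, code): (Vec<u32>, Vec<u32>) =
        shreds.into_iter().partition_map(|shred| match shred {
            Shred::ShredData(index) => Either::Left(index),
            Shred::ShredCode(index) => Either::Right(index),
        });
    assert_eq!(data, [0, 1]);
    assert_eq!(code, [2]);
}
```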