diff --git a/Cargo.toml b/Cargo.toml index ab748c8..e27ab87 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ rocksdb = "0.22.0" byte-unit = "5.1.6" getset = "0.1.3" chrono = "0.4" +tempfile = "3.15.0" [dev-dependencies] rstest = "0.23.0" diff --git a/src/primary/header_elector.rs b/src/primary/header_elector.rs index 812a4f0..66e0c5c 100644 --- a/src/primary/header_elector.rs +++ b/src/primary/header_elector.rs @@ -143,7 +143,7 @@ mod test { identity::ed25519::{self, Keypair}, PeerId, }; - use rstest::rstest; + use tempfile::tempdir; use tokio::{ sync::{broadcast, mpsc, watch}, time::timeout, @@ -152,9 +152,8 @@ mod test { use crate::{ db::{Column, Db}, - primary::test_utils::fixtures::{ - load_committee, random_digests, CHANNEL_CAPACITY, COMMITTEE_PATH, GENESIS_SEED, - }, + primary::test_utils::fixtures::{random_digests, CHANNEL_CAPACITY}, + settings::parser::Committee, types::{ block_header::BlockHeader, certificate::{Certificate, CertificateId}, @@ -178,18 +177,28 @@ mod test { mpsc::Receiver<NetworkRequest>, Arc<Db>, CancellationToken, + tempfile::TempDir, ); - fn launch_header_elector(committee_path: String, db_path: &str) -> HeaderElectorFixutre { + async fn launch_header_elector() -> HeaderElectorFixutre { let (headers_tx, headers_rx) = broadcast::channel(CHANNEL_CAPACITY); let (round_tx, round_rx) = watch::channel((0, HashSet::new())); let (incomplete_headers_tx, incomplete_headers_rx) = mpsc::channel(CHANNEL_CAPACITY); let (network_tx, network_rx) = mpsc::channel(CHANNEL_CAPACITY); - let db = Arc::new(Db::new(db_path.into()).unwrap()); + + // Create a temporary directory for the test database + let temp_dir = tempdir().unwrap(); + let db_path = temp_dir.path().join("test.db"); + + let db = Arc::new(Db::new(db_path).unwrap()); + let validator_keypair = ed25519::Keypair::generate(); let token = CancellationToken::new(); let db_clone = db.clone(); let token_clone = token.clone(); + + let committee = Committee::new_test(); + tokio::spawn(async move { HeaderElector::spawn( token_clone, @@ -198,12 +207,13 @@ mod test { round_rx, validator_keypair, db_clone, - load_committee(&committee_path), + committee, incomplete_headers_tx, ) .await .unwrap() }); + ( headers_tx, round_tx, @@ -211,6 +221,7 @@ mod test { network_rx, db, token, + temp_dir, ) } @@ -259,21 +270,15 @@ mod test { let vote_status = vote.verify(&header_hash); assert!(vote_status.is_ok()); } - _ => { - assert!(false); - } + _ => panic!("expected a vote payload"), } } #[tokio::test] - #[rstest] async fn test_first_round_valid_header_digests_stored() { - let (headers_tx, round_state_tx, _incomplete_headers_rx, network_rx, db, _) = - launch_header_elector( - COMMITTEE_PATH.into(), - "/tmp/test_first_round_valid_header_digests_stored_db", - ); - let genesis = Certificate::genesis(GENESIS_SEED); + let (headers_tx, round_state_tx, _incomplete_headers_rx, network_rx, db, _, _temp_dir) = + launch_header_elector().await; + let genesis = Certificate::genesis([0; 32]); let header = random_header(&[genesis.id()], 1); set_header_storage_in_db(&header, &db); set_certificates_in_db(&[genesis.clone()], &db); diff --git a/src/primary/mod.rs b/src/primary/mod.rs index b2f18fd..4d69870 100644 --- a/src/primary/mod.rs +++ b/src/primary/mod.rs @@ -259,3 +259,8 @@ impl BaseAgent for Primary { } } } + +#[cfg(test)] +mod tests { + mod header_tests; +} diff --git a/src/primary/tests/header_tests.rs b/src/primary/tests/header_tests.rs new file mode 100644 index 0000000..72b467a --- /dev/null +++ b/src/primary/tests/header_tests.rs @@ -0,0 +1,422 @@ +use crate::{ + db::Db, +
primary::header_builder::{wait_for_quorum, HeaderBuilder}, + settings::parser::Committee, + types::{ + batch::BatchId, + block_header::{BlockHeader, HeaderId}, + certificate::Certificate, + network::{NetworkRequest, ReceivedObject, RequestPayload}, + signing::Signable, + sync::SyncStatus, + traits::{AsBytes, Hash}, + vote::Vote, + Round, + }, + utils::CircularBuffer, +}; +use libp2p::{identity::ed25519::Keypair, PeerId}; +use std::{collections::HashSet, sync::Arc, time::Duration}; +use tokio::{ + sync::{broadcast, mpsc, watch}, + time::sleep, +}; +use tokio_util::sync::CancellationToken; + +// test helper functions +impl BlockHeader { + #[cfg(test)] + pub fn new_test() -> Self { + let peer_id = PeerId::random(); + let mut author = [0u8; 32]; + let peer_bytes = peer_id.to_bytes(); + let hash = blake3::hash(&peer_bytes); + author.copy_from_slice(hash.as_bytes()); + + Self { + round: 1, + author, + timestamp_ms: chrono::Utc::now().timestamp_millis() as u128, + certificates_ids: vec![], + digests: vec![], + } + } +} + +impl Vote { + #[cfg(test)] + pub fn new_test() -> Self { + let keypair = Keypair::generate(); + let authority = keypair.public().to_bytes(); + let header = BlockHeader::new_test(); + let signature = header.sign(&keypair).unwrap(); + + Self { + authority, + signature, + } + } + + #[cfg(test)] + pub fn new_test_invalid() -> Self { + let keypair = Keypair::generate(); + let authority = keypair.public().to_bytes(); + + Self { + authority, + signature: vec![0; 32], // deliberately malformed: not a valid ed25519 signature + } + } +} + +impl Db { + #[cfg(test)] + pub async fn new_in_memory() -> anyhow::Result<Self> { + let temp_dir = tempfile::tempdir()?; + // Persist the directory for the rest of the test run; dropping the + // `TempDir` guard here would delete it while the database is still open. + Self::new(temp_dir.into_path()) + } +} + +impl BatchId { + #[cfg(test)] + pub fn test_digest(value: u8) -> Self { + let mut digest = [0u8; 32]; + digest.fill(value); + Self(digest) + } +} + +#[tokio::test] +async fn test_header_builder_initialization() { + let (network_tx, _) = mpsc::channel(100); + let (certificate_tx, _) = mpsc::channel(100); + let keypair = Keypair::generate(); + let db = Arc::new(Db::new_in_memory().await.unwrap()); + let (header_trigger_tx, header_trigger_rx) = watch::channel((0, HashSet::new())); + let (votes_tx, votes_rx) = broadcast::channel(100); + let digests_buffer = Arc::new(tokio::sync::Mutex::new(CircularBuffer::new(100))); + let committee = Committee::new_test(); + let (sync_status_tx, sync_status_rx) = watch::channel(SyncStatus::Complete); + + let handle = HeaderBuilder::spawn( + CancellationToken::new(), + network_tx, + certificate_tx, + keypair, + db, + header_trigger_rx, + votes_rx, + digests_buffer, + committee, + sync_status_rx, + ); + + handle.abort(); +} + +#[tokio::test] +async fn test_wait_for_quorum() { + let keypair = Keypair::generate(); + let (votes_tx, mut votes_rx) = broadcast::channel(100); + let header = BlockHeader::new_test(); + let threshold = 2; + + // Create votes before spawning the task + let votes: Vec<_> = (0..threshold) + .map(|_| Vote::from_header(header.clone(), &keypair).unwrap()) + .collect(); + + // send votes + tokio::spawn(async move { + for vote in votes { + let received = ReceivedObject { + object: vote, + sender: PeerId::random(), + }; + votes_tx.send(received).unwrap(); + sleep(Duration::from_millis(100)).await; + } + }); + + // Wait for quorum + let result = wait_for_quorum(&header, threshold, &mut votes_rx, &keypair).await; + assert!(result.is_ok()); + let votes = result.unwrap(); + assert_eq!(votes.len(), threshold); +} + +#[tokio::test] +async fn test_header_builder_sync_status() { + let (network_tx, _) =
mpsc::channel(100); + let (certificate_tx, _) = mpsc::channel(100); + let keypair = Keypair::generate(); + let db = Arc::new(Db::new_in_memory().await.unwrap()); + let (header_trigger_tx, header_trigger_rx) = watch::channel((0, HashSet::new())); + let (votes_tx, votes_rx) = broadcast::channel(100); + let digests_buffer = Arc::new(tokio::sync::Mutex::new(CircularBuffer::new(100))); + let committee = Committee::new_test(); + let (sync_status_tx, sync_status_rx) = watch::channel(SyncStatus::Incomplete); + + // Spawn header builder returning a JoinHandle + let handle = HeaderBuilder::spawn( + CancellationToken::new(), + network_tx, + certificate_tx, + keypair, + db, + header_trigger_rx, + votes_rx, + digests_buffer, + committee, + sync_status_rx, + ); + + // Test that header builder waits for sync to complete + sleep(Duration::from_millis(100)).await; + sync_status_tx.send(SyncStatus::Complete).unwrap(); + + handle.abort(); +} + +#[tokio::test] +async fn test_header_builder_with_empty_digests() { + let (network_tx, _) = mpsc::channel(100); + let (certificate_tx, _) = mpsc::channel(100); + let keypair = Keypair::generate(); + let db = Arc::new(Db::new_in_memory().await.unwrap()); + let (header_trigger_tx, header_trigger_rx) = watch::channel((0, HashSet::new())); + let (votes_tx, votes_rx) = broadcast::channel(100); + let digests_buffer = Arc::new(tokio::sync::Mutex::new(CircularBuffer::new(100))); + let committee = Committee::new_test(); + let (sync_status_tx, sync_status_rx) = watch::channel(SyncStatus::Complete); + + let handle = HeaderBuilder::spawn( + CancellationToken::new(), + network_tx, + certificate_tx, + keypair, + db, + header_trigger_rx, + votes_rx, + digests_buffer, + committee, + sync_status_rx, + ); + + // trigger header building with empty digests + header_trigger_tx.send((1, HashSet::new())).unwrap(); + sleep(Duration::from_millis(100)).await; + + handle.abort(); +} + +#[tokio::test] +async fn test_header_builder_multiple_rounds() { + let (network_tx, mut network_rx) = mpsc::channel(100); + let (certificate_tx, mut cert_rx) = mpsc::channel(100); + let keypair = Keypair::generate(); + let db = Arc::new(Db::new_in_memory().await.unwrap()); + let (header_trigger_tx, header_trigger_rx) = watch::channel((0, HashSet::new())); + let (votes_tx, votes_rx) = broadcast::channel(100); + let digests_buffer = Arc::new(tokio::sync::Mutex::new(CircularBuffer::new(100))); + let committee = Committee::new_test(); + let (sync_status_tx, sync_status_rx) = watch::channel(SyncStatus::Complete); + + let handle = HeaderBuilder::spawn( + CancellationToken::new(), + network_tx, + certificate_tx, + keypair.clone(), + db, + header_trigger_rx, + votes_rx, + digests_buffer, + committee.clone(), + sync_status_rx, + ); + + // trigger multiple rounds and verify header building + for round in 1..=3 { + // Create certificates for this round + let mut certs = HashSet::new(); + let cert = Certificate::genesis([round as u8; 32]); + certs.insert(cert.clone()); + + // Trigger header building + header_trigger_tx.send((round, certs)).unwrap(); + + // Wait for header broadcast + let mut header_received = false; + while !header_received { + let network_request = tokio::time::timeout(Duration::from_secs(5), network_rx.recv()) + .await + .expect("Timed out waiting for network request") + .expect("Network channel closed unexpectedly"); + + match network_request { + NetworkRequest::BroadcastCounterparts(RequestPayload::Header(header)) => { + // Verify header round + assert_eq!(header.round, round); + + // Create and send 
enough votes to reach quorum (3 votes needed) + for _ in 0..3 { + let voting_keypair = Keypair::generate(); // Different keypair for each vote + let vote = Vote::from_header(header.clone(), &voting_keypair).unwrap(); + votes_tx + .send(ReceivedObject { + object: vote, + sender: PeerId::random(), + }) + .unwrap(); + } + + // Wait for certificate with timeout + let certificate = tokio::time::timeout(Duration::from_secs(5), cert_rx.recv()) + .await + .expect("Timed out waiting for certificate") + .expect("Certificate channel closed unexpectedly"); + + // Verify certificate + assert_eq!(certificate.round(), round); + assert!( + certificate.header().is_some(), + "Certificate should have a header" + ); + let header_id: HeaderId = header.id().into(); + assert_eq!(certificate.header().unwrap(), header_id); + + header_received = true; + } + NetworkRequest::BroadcastCounterparts(RequestPayload::Certificate(_)) => { + // Ignore certificate broadcasts, we verify them through the cert_rx channel + continue; + } + _ => panic!("Unexpected network request: {:?}", network_request), + } + } + + // Give some time for cleanup between rounds + sleep(Duration::from_millis(50)).await; + } + + // Clean shutdown + handle.abort(); + sleep(Duration::from_millis(50)).await; +} + +#[tokio::test] +async fn test_header_builder_quorum_timeout() { + let (network_tx, _) = mpsc::channel(100); + let (certificate_tx, _) = mpsc::channel(100); + let keypair = Keypair::generate(); + let db = Arc::new(Db::new_in_memory().await.unwrap()); + let (header_trigger_tx, header_trigger_rx) = watch::channel((0, HashSet::new())); + let (votes_tx, votes_rx) = broadcast::channel(100); + let digests_buffer = Arc::new(tokio::sync::Mutex::new(CircularBuffer::new(100))); + let committee = Committee::new_test(); + let (sync_status_tx, sync_status_rx) = watch::channel(SyncStatus::Complete); + + let handle = HeaderBuilder::spawn( + CancellationToken::new(), + network_tx, + certificate_tx, + keypair, + db, + header_trigger_rx, + votes_rx, + digests_buffer, + committee, + sync_status_rx, + ); + + // trigger header building but don't send any votes + let mut certs = HashSet::new(); + certs.insert(Certificate::genesis([0; 32])); + header_trigger_tx.send((1, certs)).unwrap(); + + sleep(Duration::from_millis(110)).await; + handle.abort(); +} + +#[tokio::test] +async fn test_header_builder_invalid_votes() { + let (network_tx, _) = mpsc::channel(100); + let (certificate_tx, _) = mpsc::channel(100); + let keypair = Keypair::generate(); + let db = Arc::new(Db::new_in_memory().await.unwrap()); + let (header_trigger_tx, header_trigger_rx) = watch::channel((0, HashSet::new())); + let (votes_tx, votes_rx) = broadcast::channel(100); + let digests_buffer = Arc::new(tokio::sync::Mutex::new(CircularBuffer::new(100))); + let committee = Committee::new_test(); + let (sync_status_tx, sync_status_rx) = watch::channel(SyncStatus::Complete); + + let handle = HeaderBuilder::spawn( + CancellationToken::new(), + network_tx, + certificate_tx, + keypair, + db, + header_trigger_rx, + votes_rx, + digests_buffer, + committee, + sync_status_rx, + ); + + // send invalid votes + let invalid_vote = Vote::new_test_invalid(); + votes_tx + .send(ReceivedObject { + object: invalid_vote, + sender: PeerId::random(), + }) + .unwrap(); + + sleep(Duration::from_millis(10)).await; + handle.abort(); +} + +#[tokio::test] +async fn test_header_builder_digest_buffer() { + let (network_tx, _) = mpsc::channel(100); + let (certificate_tx, _) = mpsc::channel(100); + let keypair = Keypair::generate(); 
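+    // The buffer below is created with capacity 2, so pushing a third digest evicts the oldest entry.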
+ let db = Arc::new(Db::new_in_memory().await.unwrap()); + let (header_trigger_tx, header_trigger_rx) = watch::channel((0, HashSet::new())); + let (votes_tx, votes_rx) = broadcast::channel(100); + let digests_buffer = Arc::new(tokio::sync::Mutex::new(CircularBuffer::new(2))); // Small buffer + let committee = Committee::new_test(); + let (sync_status_tx, sync_status_rx) = watch::channel(SyncStatus::Complete); + + let handle = HeaderBuilder::spawn( + CancellationToken::new(), + network_tx, + certificate_tx, + keypair, + db, + header_trigger_rx, + votes_rx, + digests_buffer.clone(), + committee, + sync_status_rx, + ); + + // Add digests to buffer + { + let mut buffer = digests_buffer.lock().await; + buffer.push(BatchId::test_digest(1)); + buffer.push(BatchId::test_digest(2)); + buffer.push(BatchId::test_digest(3)); + } + + // Verify buffer state + { + let mut buffer = digests_buffer.lock().await; + let contents = buffer.drain(); + assert_eq!(contents.len(), 2); + assert_eq!(contents[0], BatchId::test_digest(2)); + assert_eq!(contents[1], BatchId::test_digest(3)); + } + + handle.abort(); +} diff --git a/src/settings/parser.rs b/src/settings/parser.rs index 4885573..ad4d540 100644 --- a/src/settings/parser.rs +++ b/src/settings/parser.rs @@ -1,4 +1,7 @@ -use libp2p::PeerId; +use libp2p::{ + identity::{self, ed25519}, + PeerId, +}; use serde::{Deserialize, Serialize}; use std::path::Path; @@ -37,9 +40,33 @@ impl Committee { } ((self.authorities.len() / 3) * 2 + 1) as u32 } + pub fn has_authority_id(&self, peer_id: &PeerId) -> bool { self.authorities.iter().any(|a| &a.authority_id == peer_id) } + + #[cfg(test)] + pub fn new_test() -> Self { + let mut authorities = Vec::new(); + + // Add three test authorities + for i in 0..3 { + let keypair = ed25519::Keypair::generate(); + let public_key = identity::PublicKey::from(keypair.public()); + let peer_id = PeerId::from_public_key(&public_key); + + let authority = AuthorityInfo { + authority_id: peer_id, + authority_pubkey: hex::encode(keypair.public().to_bytes()), + primary_address: ("127.0.0.1".to_string(), format!("800{}", i)), + stake: 1, + workers_addresses: vec![("127.0.0.1".to_string(), format!("900{}", i))], + }; + authorities.push(authority); + } + + Committee { authorities } + } } #[derive(Clone, Debug, Deserialize, Serialize)] diff --git a/src/synchronizer/mod.rs b/src/synchronizer/mod.rs index 2fdbb73..6adcaa0 100644 --- a/src/synchronizer/mod.rs +++ b/src/synchronizer/mod.rs @@ -210,3 +210,9 @@ pub enum FetchError { #[error("id error")] IdError, } + +#[cfg(test)] +mod tests { + mod fetcher_tests; + mod synchronizer_tests; +} diff --git a/src/synchronizer/tests/fetcher_tests.rs b/src/synchronizer/tests/fetcher_tests.rs new file mode 100644 index 0000000..068d81d --- /dev/null +++ b/src/synchronizer/tests/fetcher_tests.rs @@ -0,0 +1,424 @@ +use std::sync::Arc; +use tokio::sync::{broadcast, mpsc}; +use tokio_util::sync::CancellationToken; + +use crate::{ + network::Connect, + synchronizer::{ + fetcher::Fetcher, + traits::{DataProvider, IntoSyncRequest}, + RequestedObject, + }, + types::{ + batch::{Batch, BatchId}, + network::{ + NetworkRequest, ReceivedObject, RequestPayload, SyncData, SyncRequest, SyncResponse, + }, + traits::{AsBytes, Hash, Random}, + transaction::Transaction, + }, +}; + +use async_trait::async_trait; +use libp2p::PeerId; + +// Mock connector that allows testing network-related logic without real network dependencies +#[derive(Clone)] +struct MockConnector; + +#[async_trait] +impl Connect for MockConnector { + async fn 
dispatch(&self, _request: &RequestPayload, _peer_id: PeerId) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +impl Connect for Arc<MockConnector> { + async fn dispatch(&self, request: &RequestPayload, peer_id: PeerId) -> anyhow::Result<()> { + self.as_ref().dispatch(request, peer_id).await + } +} + +// Mock connector that sleeps to simulate network delay +#[derive(Clone)] +struct SlowMockConnector; + +#[async_trait] +impl Connect for SlowMockConnector { + async fn dispatch(&self, _request: &RequestPayload, _peer_id: PeerId) -> anyhow::Result<()> { + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + Ok(()) + } +} + +#[async_trait] +impl Connect for Arc<SlowMockConnector> { + async fn dispatch(&self, request: &RequestPayload, peer_id: PeerId) -> anyhow::Result<()> { + self.as_ref().dispatch(request, peer_id).await + } +} + +#[derive(Debug, Clone)] +struct TestFetchable { + data: Vec<u8>, +} + +impl AsBytes for TestFetchable { + fn bytes(&self) -> Vec<u8> { + self.data.clone() + } +} + +impl Random for TestFetchable { + fn random(size: usize) -> Self { + let data = (0..size).map(|_| rand::random()).collect(); + Self { data } + } +} + +impl IntoSyncRequest for TestFetchable { + fn into_sync_request(&self) -> SyncRequest { + let digest = blake3::hash(&self.bytes()).into(); + SyncRequest::Batches(vec![digest]) + } +} + +#[derive(Clone)] +struct TestDataProvider { + peers: Vec<PeerId>, +} + +#[async_trait] +impl DataProvider for TestDataProvider { + async fn sources(&self) -> Box<dyn Iterator<Item = PeerId> + Send> { + Box::new(self.peers.clone().into_iter()) + } +} + +#[tokio::test] +async fn test_fetcher_basic() { + let (network_tx, mut network_rx) = mpsc::channel(100); + let (sync_tx, sync_rx) = broadcast::channel(100); + let (commands_tx, commands_rx) = mpsc::channel(100); + let connector = Arc::new(MockConnector); + let token = CancellationToken::new(); + + let _handle = Fetcher::spawn( + token.clone(), + network_tx, + commands_rx, + sync_rx, + connector, + 10, + ); + + let test_data = TestFetchable { + data: vec![1, 2, 3], + }; + + let peer_id = PeerId::random(); + let request = Box::new(RequestedObject { + object: test_data.clone(), + source: Box::new(peer_id), + }); + commands_tx.send(request).await.unwrap(); + + // Verify request is sent + let request = network_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(sync_req)) => { + assert_eq!(pid, peer_id); + let expected_digest = blake3::hash(&test_data.bytes()).into(); + assert_eq!(sync_req, SyncRequest::Batches(vec![expected_digest])); + } + _ => panic!("Expected SendTo request with SyncRequest payload"), + } + + // Send successful response + let tx = Transaction::random(32); + let batch = Batch::new(vec![tx]); + let sync_data = SyncData::Batches(vec![batch]); + let request_id = test_data.into_sync_request().digest(); + let response = SyncResponse::Success(request_id, sync_data); + let received = ReceivedObject { + object: response, + sender: peer_id, + }; + sync_tx.send(received).unwrap(); + + // Drop the sender to signal no more commands + drop(commands_tx); +} + +#[tokio::test] +async fn test_fetcher_empty() { + let (network_tx, _) = mpsc::channel(100); + let (_, sync_rx) = broadcast::channel(100); + let (commands_tx, commands_rx) = mpsc::channel(100); + + let _handle = Fetcher::spawn( + CancellationToken::new(), + network_tx, + commands_rx, + sync_rx, + Arc::new(MockConnector), + 10, + ); + + // Drop commands_tx to signal no more commands + drop(commands_tx); +} + +#[tokio::test] +async fn test_fetcher_single_request() { + let
(network_tx, mut network_rx) = mpsc::channel(100); + let (sync_tx, sync_rx) = broadcast::channel(100); + let (commands_tx, commands_rx) = mpsc::channel(100); + + let test_data = TestFetchable { + data: vec![1, 2, 3], + }; + + let peer_id = PeerId::random(); + let request = Box::new(RequestedObject { + object: test_data.clone(), + source: Box::new(peer_id), + }); + + let _handle = Fetcher::spawn( + CancellationToken::new(), + network_tx.clone(), + commands_rx, + sync_rx, + Arc::new(MockConnector), + 10, + ); + + // Send the request through the commands channel + commands_tx.send(request).await.unwrap(); + + // Drop commands_tx to signal no more commands + drop(commands_tx); + + // verify request is sent + let request = network_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo(pid, _) => { + // The request must target the peer supplied as the data source + assert_eq!(pid, peer_id); + } + _ => panic!("Expected SendTo request"), + } + + // Send successful response + let tx = Transaction::random(32); + let batch = Batch::new(vec![tx]); + let sync_data = SyncData::Batches(vec![batch]); + let request_id = test_data.into_sync_request().digest(); + let response = SyncResponse::Success(request_id, sync_data); + let received = ReceivedObject { + object: response, + sender: peer_id, + }; + sync_tx.send(received).unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; +} + +#[tokio::test] +async fn test_fetcher_timeout() { + let (requests_tx, _) = mpsc::channel(100); + let (responses_tx, responses_rx) = broadcast::channel(100); + let (commands_tx, commands_rx) = mpsc::channel(100); + + let test_data = TestFetchable { + data: vec![1, 2, 3], + }; + + let provider = Box::new(TestDataProvider { + peers: vec![PeerId::random()], + }); + + let requested_object = RequestedObject { + object: test_data, + source: provider, + }; + + let handle = Fetcher::spawn( + CancellationToken::new(), + requests_tx, + commands_rx, + responses_rx, + Arc::new(SlowMockConnector), + // Set a very short timeout for individual fetch operations + 1, + ); + + // Send the request through the commands channel + commands_tx.send(Box::new(requested_object)).await.unwrap(); + + // Drop commands_tx to signal no more commands + drop(commands_tx); + + // Run fetcher with a longer timeout to ensure it has time to process + let result = tokio::time::timeout(tokio::time::Duration::from_millis(100), handle).await; + + // The fetcher should complete successfully, but the fetch operation should have timed out + assert!(result.is_ok(), "Fetcher should complete"); +} + +#[tokio::test] +async fn test_fetcher_error_response() { + let (requests_tx, _) = mpsc::channel(100); + let (responses_tx, responses_rx) = broadcast::channel(100); + let (commands_tx, commands_rx) = mpsc::channel(100); + + let test_data = TestFetchable { + data: vec![1, 2, 3], + }; + + let provider = Box::new(TestDataProvider { + peers: vec![PeerId::random()], + }); + + let requested_object = RequestedObject { + object: test_data.clone(), + source: provider, + }; + + let _handle = Fetcher::spawn( + CancellationToken::new(), + requests_tx, + commands_rx, + responses_rx, + Arc::new(MockConnector), + 10, + ); + + // Send the request through the commands channel + commands_tx.send(Box::new(requested_object)).await.unwrap(); + + // Drop commands_tx to signal no more commands + drop(commands_tx); + + // Send failure response + let request_id = test_data.into_sync_request().digest(); + let response = SyncResponse::Failure(request_id);
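+    // A Failure response tells the fetcher this peer could not serve the request; with a single-peer provider there is nobody left to retry.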
+ let received = ReceivedObject { + object: response, + sender: PeerId::random(), + }; + responses_tx.send(received).unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; +} + +#[tokio::test] +async fn test_fetcher_multiple_peers() { + let (requests_tx, mut requests_rx) = mpsc::channel(100); + let (responses_tx, responses_rx) = broadcast::channel(100); + let (commands_tx, commands_rx) = mpsc::channel(100); + + let test_data = TestFetchable { + data: vec![1, 2, 3], + }; + + let peer1 = PeerId::random(); + let peer2 = PeerId::random(); + let peer3 = PeerId::random(); + let provider = Box::new(TestDataProvider { + peers: vec![peer1, peer2, peer3], + }); + + let requested_object = RequestedObject { + object: test_data.clone(), + source: provider.clone(), + }; + + let _handle = Fetcher::spawn( + CancellationToken::new(), + requests_tx, + commands_rx, + responses_rx, + Arc::new(MockConnector), + 10, + ); + + // Send the request through the commands channel + commands_tx.send(Box::new(requested_object)).await.unwrap(); + + // Drop commands_tx to signal no more commands + drop(commands_tx); + + // Verify first request is sent to peer1 + let request = requests_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(sync_req)) => { + assert_eq!(pid, peer1, "First request should be sent to peer1"); + let expected_digest = blake3::hash(&test_data.bytes()).into(); + assert_eq!(sync_req, SyncRequest::Batches(vec![expected_digest])); + + // Send failure response from peer1 + let request_id = test_data.into_sync_request().digest(); + let response = SyncResponse::Failure(request_id); + let received = ReceivedObject { + object: response, + sender: peer1, + }; + responses_tx.send(received).unwrap(); + } + _ => panic!("Expected SendTo request with SyncRequest payload"), + }; + + // Verify second request is sent to peer2 + let request = requests_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(sync_req)) => { + assert_eq!(pid, peer2, "Second request should be sent to peer2"); + let expected_digest = blake3::hash(&test_data.bytes()).into(); + assert_eq!(sync_req, SyncRequest::Batches(vec![expected_digest])); + + // Send failure response from peer2 + let request_id = test_data.into_sync_request().digest(); + let response = SyncResponse::Failure(request_id); + let received = ReceivedObject { + object: response, + sender: peer2, + }; + responses_tx.send(received).unwrap(); + } + _ => panic!("Expected SendTo request with SyncRequest payload"), + }; + + // Verify third request is sent to peer3 + let request = requests_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(sync_req)) => { + assert_eq!(pid, peer3, "Third request should be sent to peer3"); + let expected_digest = blake3::hash(&test_data.bytes()).into(); + assert_eq!(sync_req, SyncRequest::Batches(vec![expected_digest])); + + // Send successful response from peer3 + let tx = Transaction::random(32); + let batch = Batch::new(vec![tx]); + let sync_data = SyncData::Batches(vec![batch]); + let request_id = test_data.into_sync_request().digest(); + let response = SyncResponse::Success(request_id, sync_data); + let received = ReceivedObject { + object: response, + sender: peer3, + }; + responses_tx.send(received).unwrap(); + } + _ => panic!("Expected SendTo request with SyncRequest payload"), + }; + + // Verify no more requests are sent after successful response + 
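+    // Brief pause so any (unexpected) extra request would have time to arrive before we assert the channel is empty.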
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + assert!( + requests_rx.try_recv().is_err(), + "No more requests should be sent after successful response" + ); +} diff --git a/src/synchronizer/tests/synchronizer_tests.rs b/src/synchronizer/tests/synchronizer_tests.rs new file mode 100644 index 0000000..af5b6b8 --- /dev/null +++ b/src/synchronizer/tests/synchronizer_tests.rs @@ -0,0 +1,553 @@ +use crate::{ + network::Connect, + synchronizer::{ + fetcher::Fetcher, + traits::{DataProvider, Fetch, IntoSyncRequest}, + RequestedObject, + }, + types::{ + batch::{Batch, BatchId}, + network::{ + NetworkRequest, ReceivedObject, RequestPayload, SyncData, SyncRequest, SyncResponse, + }, + traits::{AsBytes, Hash, Random}, + transaction::Transaction, + Digest, + }, +}; +use async_trait::async_trait; +use libp2p::PeerId; +use rstest::*; +use std::{collections::HashSet, sync::Arc}; +use tokio::sync::{broadcast, mpsc}; +use tokio_util::sync::CancellationToken; + +struct MockConnector; + +#[async_trait] +impl Connect for MockConnector { + async fn dispatch(&self, _request: &RequestPayload, _peer_id: PeerId) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +impl Connect for Arc<MockConnector> { + async fn dispatch(&self, request: &RequestPayload, peer_id: PeerId) -> anyhow::Result<()> { + self.as_ref().dispatch(request, peer_id).await + } +} + +#[derive(Clone)] +struct MockData { + data: Vec<u8>, +} + +impl AsBytes for MockData { + fn bytes(&self) -> Vec<u8> { + self.data.clone() + } +} + +impl IntoSyncRequest for MockData { + fn into_sync_request(&self) -> SyncRequest { + let digest = blake3::hash(&self.bytes()).into(); + SyncRequest::Batches(vec![digest]) + } +} + +#[derive(Clone)] +struct MockDataProvider { + peers: Vec<PeerId>, +} + +#[async_trait] +impl DataProvider for MockDataProvider { + async fn sources(&self) -> Box<dyn Iterator<Item = PeerId> + Send> { + Box::new(self.peers.clone().into_iter()) + } +} + +type TestReceivedObject = ReceivedObject<SyncResponse>; +type TestRequestedObject = RequestedObject<MockData>; +type BoxedFetch = Box<dyn Fetch + Send + Sync>; + +#[fixture] +fn test_data() -> MockData { + MockData { + data: vec![1, 2, 3], + } +} + +#[fixture] +fn test_data_set() -> Vec<MockData> { + vec![ + MockData { + data: vec![1, 2, 3], + }, + MockData { + data: vec![4, 5, 6], + }, + MockData { + data: vec![7, 8, 9], + }, + ] +} + +#[fixture] +fn channels() -> ( + mpsc::Sender<NetworkRequest>, + mpsc::Receiver<NetworkRequest>, + broadcast::Sender<TestReceivedObject>, + broadcast::Receiver<TestReceivedObject>, + mpsc::Sender<BoxedFetch>, + mpsc::Receiver<BoxedFetch>, +) { + let (network_tx, network_rx) = mpsc::channel(100); + let (sync_tx, sync_rx) = broadcast::channel(100); + let (commands_tx, commands_rx) = mpsc::channel(100); + ( + network_tx, + network_rx, + sync_tx, + sync_rx, + commands_tx, + commands_rx, + ) +} + +#[fixture] +fn peers() -> (PeerId, PeerId, PeerId) { + (PeerId::random(), PeerId::random(), PeerId::random()) +} + +#[fixture] +fn fetcher_handle( + channels: ( + mpsc::Sender<NetworkRequest>, + mpsc::Receiver<NetworkRequest>, + broadcast::Sender<TestReceivedObject>, + broadcast::Receiver<TestReceivedObject>, + mpsc::Sender<BoxedFetch>, + mpsc::Receiver<BoxedFetch>, + ), +) -> (CancellationToken, tokio::task::JoinHandle<()>) { + let (network_tx, _, _, sync_rx, _, commands_rx) = channels; + let token = CancellationToken::new(); + let handle = Fetcher::spawn( + token.clone(), + network_tx, + commands_rx, + sync_rx, + Arc::new(MockConnector), + 10, + ); + (token, handle) +} + +// Helper function to create a valid batch response +fn create_valid_response(request_id: Digest) -> SyncResponse { + let tx = Transaction::random(32); + let batch = Batch::new(vec![tx]); + let sync_data = SyncData::Batches(vec![batch]); + SyncResponse::Success(request_id,
sync_data) +} + +// Helper function to create an invalid batch response +fn create_invalid_response() -> SyncResponse { + let different_data = MockData { + data: vec![4, 5, 6], + }; + let tx = Transaction::random(32); + let batch = Batch::new(vec![tx]); + let sync_data = SyncData::Batches(vec![batch]); + let request_id = different_data.into_sync_request().digest(); + SyncResponse::Success(request_id, sync_data) +} + +#[rstest] +#[tokio::test] +async fn test_synchronizer_invalid_response_data( + test_data: MockData, + channels: ( + mpsc::Sender<NetworkRequest>, + mpsc::Receiver<NetworkRequest>, + broadcast::Sender<TestReceivedObject>, + broadcast::Receiver<TestReceivedObject>, + mpsc::Sender<BoxedFetch>, + mpsc::Receiver<BoxedFetch>, + ), + peers: (PeerId, PeerId, PeerId), +) { + let (network_tx, mut network_rx, sync_tx, sync_rx, commands_tx, commands_rx) = channels; + let (peer_id, peer_id2, _) = peers; + + let provider = Box::new(MockDataProvider { + peers: vec![peer_id, peer_id2], + }); + + let requested_object = TestRequestedObject { + object: test_data.clone(), + source: provider, + }; + + let token = CancellationToken::new(); + let handle = Fetcher::spawn( + token.clone(), + network_tx, + commands_rx, + sync_rx, + Arc::new(MockConnector), + 10, + ); + + // Send the request through the commands channel + commands_tx + .send(Box::new(requested_object) as BoxedFetch) + .await + .unwrap(); + + // Verify initial request is sent with timeout + let initial_request = + tokio::time::timeout(tokio::time::Duration::from_secs(5), network_rx.recv()) + .await + .expect("Timed out waiting for initial request") + .expect("Channel closed unexpectedly"); + + match initial_request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(sync_req)) => { + assert_eq!(pid, peer_id); + let expected_digest = blake3::hash(&test_data.bytes()).into(); + assert_eq!(sync_req, SyncRequest::Batches(vec![expected_digest])); + + // Send invalid response + let response = create_invalid_response(); + sync_tx + .send(TestReceivedObject { + object: response, + sender: pid, + }) + .unwrap(); + } + _ => panic!("Expected initial SendTo request with SyncRequest payload"), + } + + // Give some time for the invalid response to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + // Verify retry request is sent with timeout + let retry_request = + tokio::time::timeout(tokio::time::Duration::from_secs(5), network_rx.recv()) + .await + .expect("Timed out waiting for retry request") + .expect("Channel closed unexpectedly"); + + match retry_request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(_)) => { + assert_eq!(pid, peer_id2, "Retry should use the second peer"); + + // Send valid response from second peer to complete the test + let request_id = test_data.into_sync_request().digest(); + let response = create_valid_response(request_id); + sync_tx + .send(TestReceivedObject { + object: response, + sender: peer_id2, + }) + .unwrap(); + } + _ => panic!("Expected retry request with SyncRequest payload"), + } + + // Verify no more requests are sent + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + assert!( + network_rx.try_recv().is_err(), + "No more requests should be sent after successful response" + ); + + // Clean shutdown + drop(commands_tx); + token.cancel(); + let _ = handle.await; +} + +#[rstest] +#[tokio::test] +async fn test_synchronizer_multiple_peers( + test_data: MockData, + channels: ( + mpsc::Sender<NetworkRequest>, + mpsc::Receiver<NetworkRequest>, + broadcast::Sender<TestReceivedObject>, + broadcast::Receiver<TestReceivedObject>, + mpsc::Sender<BoxedFetch>, + mpsc::Receiver<BoxedFetch>, + ), + peers: (PeerId, PeerId, PeerId), +) { + let
(network_tx, mut network_rx, sync_tx, sync_rx, commands_tx, commands_rx) = channels; + let (peer1, peer2, peer3) = peers; + + let provider = Box::new(MockDataProvider { + peers: vec![peer1, peer2, peer3], + }); + + let requested_object = TestRequestedObject { + object: test_data.clone(), + source: provider.clone(), + }; + + let token = CancellationToken::new(); + let handle = Fetcher::spawn( + token.clone(), + network_tx, + commands_rx, + sync_rx, + Arc::new(MockConnector), + 10, + ); + + // Send the request through the commands channel + commands_tx + .send(Box::new(requested_object) as BoxedFetch) + .await + .unwrap(); + + // Verify first request is sent + let request = network_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(sync_req)) => { + assert_eq!(pid, peer1); + let expected_digest = blake3::hash(&test_data.bytes()).into(); + assert_eq!(sync_req, SyncRequest::Batches(vec![expected_digest])); + + // Send successful response + let request_id = test_data.into_sync_request().digest(); + let response = create_valid_response(request_id); + sync_tx + .send(TestReceivedObject { + object: response, + sender: peer1, + }) + .unwrap(); + } + _ => panic!("Expected SendTo request with SyncRequest payload"), + } + + // Verify no more requests are sent + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + assert!( + network_rx.try_recv().is_err(), + "No more requests should be sent after successful response" + ); + + // Clean shutdown + drop(commands_tx); + token.cancel(); + let _ = handle.await; +} + +#[rstest] +#[tokio::test] +async fn test_synchronizer_concurrent_requests( + test_data_set: Vec<MockData>, + channels: ( + mpsc::Sender<NetworkRequest>, + mpsc::Receiver<NetworkRequest>, + broadcast::Sender<TestReceivedObject>, + broadcast::Receiver<TestReceivedObject>, + mpsc::Sender<BoxedFetch>, + mpsc::Receiver<BoxedFetch>, + ), + peers: (PeerId, PeerId, PeerId), +) { + let (network_tx, mut network_rx, sync_tx, sync_rx, commands_tx, commands_rx) = channels; + let (peer1, peer2, peer3) = peers; + + let token = CancellationToken::new(); + let handle = Fetcher::spawn( + token.clone(), + network_tx, + commands_rx, + sync_rx, + Arc::new(MockConnector), + 10, + ); + + // Send multiple requests concurrently and track their digests + let mut expected_digests = HashSet::new(); + for test_data in test_data_set.iter() { + let provider = Box::new(MockDataProvider { + peers: vec![peer1, peer2, peer3], + }); + + let requested_object = TestRequestedObject { + object: test_data.clone(), + source: provider, + }; + + // Store the expected digest + let sync_req = test_data.into_sync_request(); + if let SyncRequest::Batches(digests) = sync_req { + expected_digests.insert(digests[0]); + } + + commands_tx + .send(Box::new(requested_object) as BoxedFetch) + .await + .unwrap(); + } + + // Verify all requests are processed + let mut received_digests = HashSet::new(); + for _ in 0..test_data_set.len() { + let request = network_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo( + pid, + RequestPayload::SyncRequest(SyncRequest::Batches(digests)), + ) => { + let digest = digests[0]; + assert!( + expected_digests.contains(&digest), + "Received unexpected request digest: {:?}, expected one of: {:?}", + digest, + expected_digests + ); + received_digests.insert(digest); + + // Send successful response + let response = create_valid_response(digest); + sync_tx + .send(TestReceivedObject { + object: response, + sender: pid, + }) + .unwrap(); + } + _ => panic!("Expected SendTo request with SyncRequest payload"), + } + } + + // Verify we received
requests for all expected digests + assert_eq!( + received_digests.len(), + expected_digests.len(), + "Should receive requests for all test data" + ); + assert!( + received_digests.is_subset(&expected_digests), + "All received digests should be from our requests" + ); + assert!( + expected_digests.is_subset(&received_digests), + "All expected digests should be requested" + ); + + // Verify no more requests are sent + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + assert!( + network_rx.try_recv().is_err(), + "No more requests should be sent after all responses" + ); + + // Clean shutdown + drop(commands_tx); + token.cancel(); + let _ = handle.await; +} + +#[rstest] +#[tokio::test] +async fn test_synchronizer_retry_on_failure( + test_data: MockData, + channels: ( + mpsc::Sender<NetworkRequest>, + mpsc::Receiver<NetworkRequest>, + broadcast::Sender<TestReceivedObject>, + broadcast::Receiver<TestReceivedObject>, + mpsc::Sender<BoxedFetch>, + mpsc::Receiver<BoxedFetch>, + ), + peers: (PeerId, PeerId, PeerId), +) { + let (network_tx, mut network_rx, sync_tx, sync_rx, commands_tx, commands_rx) = channels; + let (peer1, peer2, _) = peers; + + let provider = Box::new(MockDataProvider { + peers: vec![peer1, peer2], + }); + + let requested_object = TestRequestedObject { + object: test_data.clone(), + source: provider, + }; + + let token = CancellationToken::new(); + let handle = Fetcher::spawn( + token.clone(), + network_tx, + commands_rx, + sync_rx, + Arc::new(MockConnector), + 10, + ); + + // Send the request through the commands channel + commands_tx + .send(Box::new(requested_object) as BoxedFetch) + .await + .unwrap(); + + // Verify first request is sent + let request = network_rx.recv().await.unwrap(); + match request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(_)) => { + assert_eq!(pid, peer1, "First request should be sent to peer1"); + + // Send failure response + let request_id = test_data.into_sync_request().digest(); + let response = SyncResponse::Failure(request_id); + sync_tx + .send(TestReceivedObject { + object: response, + sender: peer1, + }) + .unwrap(); + } + _ => panic!("Expected SendTo request with SyncRequest payload"), + } + + // Verify retry request is sent to second peer + let retry_request = network_rx.recv().await.unwrap(); + match retry_request { + NetworkRequest::SendTo(pid, RequestPayload::SyncRequest(_)) => { + assert_eq!(pid, peer2, "Retry should be sent to peer2"); + assert_ne!(pid, peer1, "Retry should use a different peer"); + + // Send successful response from second peer + let request_id = test_data.into_sync_request().digest(); + let response = create_valid_response(request_id); + sync_tx + .send(TestReceivedObject { + object: response, + sender: peer2, + }) + .unwrap(); + } + _ => panic!("Expected retry request with SyncRequest payload"), + } + + // Verify no more requests are sent + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + assert!( + network_rx.try_recv().is_err(), + "No more requests should be sent after successful response" + ); + + // Clean shutdown + drop(commands_tx); + token.cancel(); + let _ = handle.await; +} diff --git a/src/types/dag.rs b/src/types/dag.rs index 659c32c..c14db86 100644 --- a/src/types/dag.rs +++ b/src/types/dag.rs @@ -57,25 +57,29 @@ where } /// Check if a vertex has all its parents in the DAG, returning an error containing the missing parents if not.
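+    /// Only base-layer vertices may be parentless; any other vertex must reference parents already present in lower layers.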
pub fn check_parents(&self, vertex: &Vertex<T>) -> Result<(), DagError> { - if vertex.layer == 0 { + // Only base layer vertices can be parentless + if vertex.layer == self.base_layer { + if vertex.parents.is_empty() { + return Ok(()); + } + } + + // All other vertices must have valid parents in previous layers + let found_parents: HashSet<_> = (self.base_layer..vertex.layer) + .rev() + .flat_map(|layer| self.vertices_by_layers.get(&layer).into_iter().flatten()) + .collect(); + + if vertex.parents.iter().all(|p| found_parents.contains(p)) { Ok(()) } else { - self.vertices_by_layers - .get(&(vertex.layer - 1)) - .map(|potential_parents| { - if vertex.parents.is_subset(potential_parents) { - Ok(()) - } else { - Err(DagError::MissingParents( - vertex - .parents - .difference(potential_parents) - .cloned() - .collect(), - )) - } - }) - .unwrap_or(Err(DagError::MissingParents(vertex.parents.clone()))) + let missing: HashSet<_> = vertex + .parents + .iter() + .filter(|p| !found_parents.contains(*p)) + .cloned() + .collect(); + Err(DagError::MissingParents(missing)) } } /// Insert a vertex in the DAG, returning an error if its parents are missing but inserting it anyway. @@ -116,6 +120,10 @@ where .map(|vertex| vertex.data.clone()) .collect() } + /// Get a vertex by its ID + pub fn get(&self, id: &str) -> Option<&Vertex<T>> { + self.vertices.get(id) + } } impl<T> Vertex<T> where diff --git a/src/types/mod.rs b/src/types/mod.rs index 5d0c17f..bbe43d8 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -43,3 +43,8 @@ impl AsBytes for Acknowledgment { self.0.to_vec() } } + +#[cfg(test)] +mod tests { + mod dag_tests; +} diff --git a/src/types/tests/dag_tests.rs b/src/types/tests/dag_tests.rs new file mode 100644 index 0000000..4ba8676 --- /dev/null +++ b/src/types/tests/dag_tests.rs @@ -0,0 +1,186 @@ +use std::collections::HashSet; + +use crate::types::{ + dag::{Dag, DagError, Vertex}, + traits::AsBytes, +}; + +#[derive(Clone, Debug)] +struct TestData { + value: u64, +} + +impl AsBytes for TestData { + fn bytes(&self) -> Vec<u8> { + self.value.to_be_bytes().to_vec() + } +} + +#[tokio::test] +async fn test_dag_creation_and_basic_ops() { + let base_layer: u64 = 0; + let mut dag: Dag<TestData> = Dag::new(base_layer); + + // Test initial state + assert_eq!(dag.height(), base_layer); + assert_eq!(dag.base_layer(), base_layer); + + // Test vertex insertion + let data = TestData { value: 1 }; + let parents = HashSet::new(); + let vertex = Vertex::from_data(data, 1, parents); + let vertex_id = vertex.id().clone(); + + dag.insert(vertex).unwrap(); + assert_eq!(dag.height(), 1); + assert_eq!(dag.layer_size(1), 1); + + // Test vertex retrieval + let retrieved = dag.get(&vertex_id).unwrap(); + assert_eq!(retrieved.data().value, 1); + assert_eq!(*retrieved.layer(), 1); +} + +#[tokio::test] +async fn test_dag_parent_child_relationships() { + let mut dag: Dag<TestData> = Dag::new(0); + + // Create parent vertex + let parent_data = TestData { value: 1 }; + let parent_vertex = Vertex::from_data(parent_data, 1, HashSet::new()); + let parent_id = parent_vertex.id().clone(); + dag.insert(parent_vertex).unwrap(); + + // Create child vertex with parent reference + let mut parents = HashSet::new(); + parents.insert(parent_id); + let child_data = TestData { value: 2 }; + let child_vertex = Vertex::from_data(child_data, 2, parents); + + dag.insert_checked(child_vertex).unwrap(); +} + +#[tokio::test] +async fn test_dag_invalid_parent() { + let mut dag: Dag<TestData> = Dag::new(0); + + let mut parents = HashSet::new(); + parents.insert("non_existent_parent".to_string());
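+    // This parent id was never inserted, so insert_checked must surface it through DagError::MissingParents.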
let data = TestData { value: 1 }; + let vertex = Vertex::from_data(data, 1, parents); + + match dag.insert_checked(vertex) { + Err(DagError::MissingParents(_)) => (), + _ => panic!("Expected MissingParents error"), + } +} + +#[tokio::test] +async fn test_dag_layer_operations() { + let mut dag: Dag<TestData> = Dag::new(0); + + // Insert vertices in different layers + for i in 1..=3 { + let data = TestData { value: i }; + let vertex = Vertex::from_data(data, i as u64, HashSet::new()); + dag.insert(vertex).unwrap(); + } + + // Test layer queries + assert_eq!(dag.layer_size(1), 1); + assert_eq!(dag.layer_size(2), 1); + assert_eq!(dag.layer_size(3), 1); + + let layer_2_vertices = dag.layer_vertices(2); + assert_eq!(layer_2_vertices.len(), 1); + assert_eq!(layer_2_vertices[0].data().value, 2); +} + +#[tokio::test] +async fn test_dag_multiple_parents() { + let mut dag: Dag<TestData> = Dag::new(0); + + // Create two parent vertices + let parent1_data = TestData { value: 1 }; + let parent2_data = TestData { value: 2 }; + let parent1_vertex = Vertex::from_data(parent1_data, 1, HashSet::new()); + let parent2_vertex = Vertex::from_data(parent2_data, 1, HashSet::new()); + + let parent1_id = parent1_vertex.id().clone(); + let parent2_id = parent2_vertex.id().clone(); + + dag.insert(parent1_vertex).unwrap(); + dag.insert(parent2_vertex).unwrap(); + + // Create child with multiple parents + let mut parents = HashSet::new(); + parents.insert(parent1_id); + parents.insert(parent2_id); + + let child_data = TestData { value: 3 }; + let child_vertex = Vertex::from_data(child_data, 2, parents); + + dag.insert_checked(child_vertex).unwrap(); + assert_eq!(dag.layer_size(2), 1); +} + +#[tokio::test] +async fn test_dag_cyclic_insertion_prevention() { + let mut dag: Dag<TestData> = Dag::new(0); + + // Create first vertex + let data1 = TestData { value: 1 }; + let vertex1 = Vertex::from_data(data1, 1, HashSet::new()); + let vertex1_id = vertex1.id().clone(); + dag.insert(vertex1).unwrap(); + + // Try to create a vertex in a lower layer referencing a higher layer + let mut parents = HashSet::new(); + parents.insert(vertex1_id); + let data2 = TestData { value: 2 }; + let vertex2 = Vertex::from_data(data2, 0, parents); + + assert!(dag.insert_checked(vertex2).is_err()); +} + +#[tokio::test] +async fn test_dag_complex_hierarchy() { + let mut dag: Dag<TestData> = Dag::new(0); + + // Layer 1: Two vertices + let vertex1_1 = Vertex::from_data(TestData { value: 11 }, 1, HashSet::new()); + let vertex1_2 = Vertex::from_data(TestData { value: 12 }, 1, HashSet::new()); + let id1_1 = vertex1_1.id().clone(); + let id1_2 = vertex1_2.id().clone(); + + dag.insert(vertex1_1).unwrap(); + dag.insert(vertex1_2).unwrap(); + + // Layer 2: Two vertices, each with one parent + let mut parents2_1 = HashSet::new(); + parents2_1.insert(id1_1.clone()); + let mut parents2_2 = HashSet::new(); + parents2_2.insert(id1_2.clone()); + + let vertex2_1 = Vertex::from_data(TestData { value: 21 }, 2, parents2_1); + let vertex2_2 = Vertex::from_data(TestData { value: 22 }, 2, parents2_2); + let id2_1 = vertex2_1.id().clone(); + let id2_2 = vertex2_2.id().clone(); + + dag.insert_checked(vertex2_1).unwrap(); + dag.insert_checked(vertex2_2).unwrap(); + + // Layer 3: One vertex with both layer 2 vertices as parents + let mut parents3 = HashSet::new(); + parents3.insert(id2_1); + parents3.insert(id2_2); + + let vertex3 = Vertex::from_data(TestData { value: 31 }, 3, parents3); + dag.insert_checked(vertex3).unwrap(); + + // Verify the structure + assert_eq!(dag.layer_size(1), 2);
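+    // Expected shape: two roots in layer 1, two mid vertices in layer 2, one join vertex in layer 3.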
assert_eq!(dag.layer_size(2), 2); + assert_eq!(dag.layer_size(3), 1); + assert_eq!(dag.height(), 3); +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index bd8294b..7279f0b 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -70,4 +70,8 @@ impl CircularBuffer { self.buffer = vec![None; self.size]; res } + + pub fn iter(&self) -> impl Iterator<Item = &T> { + self.buffer.iter().filter_map(|x| x.as_ref()) + } }
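+ +#[cfg(test)] +mod iter_sketch { + // Editor's illustrative sketch, not part of the original patch: contrasts the + // new `iter` (borrows the occupied slots in place) with the existing `drain` + // (returns the contents and resets the buffer), matching the semantics shown + // in test_header_builder_digest_buffer. + use super::CircularBuffer; + + #[test] + fn iter_does_not_consume() { + let mut buf: CircularBuffer<u8> = CircularBuffer::new(2); + buf.push(1); + buf.push(2); + assert_eq!(buf.iter().count(), 2); // iter only borrows... + assert_eq!(buf.iter().count(), 2); // ...so the contents survive + assert_eq!(buf.drain(), vec![1, 2]); // drain consumes and resets + assert_eq!(buf.iter().count(), 0); + } +}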