From 49e91b0e940f1a28754b813d5b5fa15fa757ba38 Mon Sep 17 00:00:00 2001
From: Nasr
Date: Thu, 16 Jan 2025 14:48:48 +0700
Subject: [PATCH] feat: finish up task manager

---
 crates/torii/indexer/src/engine.rs       | 153 +++++------------------
 crates/torii/indexer/src/task_manager.rs | 143 ++++++++++++++++-----
 crates/torii/sqlite/src/cache.rs         |  16 +--
 crates/torii/sqlite/src/erc.rs           |  90 +++++++------
 crates/torii/sqlite/src/lib.rs           |  10 +-
 5 files changed, 199 insertions(+), 213 deletions(-)

diff --git a/crates/torii/indexer/src/engine.rs b/crates/torii/indexer/src/engine.rs
index e1f18a7d30..12eef20880 100644
--- a/crates/torii/indexer/src/engine.rs
+++ b/crates/torii/indexer/src/engine.rs
@@ -1,15 +1,14 @@
 use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
 use std::fmt::Debug;
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::Hash;
 use std::sync::Arc;
 use std::time::Duration;
 
 use anyhow::Result;
 use bitflags::bitflags;
-use cainome::cairo_serde::CairoSerde;
 use dojo_utils::provider as provider_utils;
 use dojo_world::contracts::world::WorldContractReader;
-use futures_util::future::{join_all, try_join_all};
+use futures_util::future::join_all;
 use hashlink::LinkedHashMap;
 use starknet::core::types::{
     BlockHashAndNumber, BlockId, BlockTag, EmittedEvent, Event, EventFilter, EventsPage,
 };
 use starknet::core::utils::get_selector_from_name;
 use starknet::providers::Provider;
-use starknet_crypto::{poseidon_hash_many, Felt};
+use starknet_crypto::Felt;
 use tokio::sync::broadcast::Sender;
 use tokio::sync::mpsc::Sender as BoundedSender;
 use tokio::sync::Semaphore;
@@ -47,6 +46,7 @@ use crate::processors::upgrade_model::UpgradeModelProcessor;
 use crate::processors::{
     BlockProcessor, EventProcessor, EventProcessorConfig, TransactionProcessor,
 };
+use crate::task_manager::{ParallelizedEvent, TaskManager};
 
 type EventProcessorMap<P> = HashMap<Felt, Vec<Box<dyn EventProcessor<P>>>>;
@@ -233,17 +233,27 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
         let contracts = Arc::new(
             contracts.iter().map(|contract| (contract.address, contract.r#type)).collect(),
         );
+        let world = Arc::new(world);
+        let processors = Arc::new(processors);
+        let max_concurrent_tasks = config.max_concurrent_tasks;
+        let event_processor_config = config.event_processor_config.clone();
 
         Self {
-            world: Arc::new(world),
-            db,
+            world: world.clone(),
+            db: db.clone(),
             provider: Arc::new(provider),
-            processors: Arc::new(processors),
+            processors: processors.clone(),
             config,
             shutdown_tx,
             block_tx,
             contracts,
-            tasks: BTreeMap::new(),
+            task_manager: TaskManager::new(
+                db,
+                world,
+                processors,
+                max_concurrent_tasks,
+                event_processor_config,
+            ),
         }
     }
 
@@ -531,7 +541,7 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
         }
 
         // Process parallelized events
-        self.process_tasks().await?;
+        self.task_manager.process_tasks().await?;
 
         self.db.update_cursors(
             data.block_number - 1,
@@ -578,7 +588,7 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
         }
 
         // Process parallelized events
-        self.process_tasks().await?;
+        self.task_manager.process_tasks().await?;
 
         let last_block_timestamp =
             get_block_timestamp(&self.provider, data.latest_block_number).await?;
@@ -588,77 +598,6 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
         Ok(())
     }
 
-    async fn process_tasks(&mut self) -> Result<()> {
-        let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent_tasks));
-
-        // Process each priority level sequentially
-        for (priority, task_group) in std::mem::take(&mut self.tasks) {
-            let mut handles = Vec::new();
-
-            // Process all tasks within this priority level concurrently
-            for (task_id, events) in task_group {
-                let db = self.db.clone();
-                let world = self.world.clone();
-                let semaphore = semaphore.clone();
-                let processors = self.processors.clone();
-                let event_processor_config = self.config.event_processor_config.clone();
-
-                handles.push(tokio::spawn(async move {
-                    let _permit = semaphore.acquire().await?;
-                    let mut local_db = db.clone();
-
-                    // Process all events for this task sequentially
-                    for (contract_type, event) in events {
-                        let contract_processors = processors.get_event_processor(contract_type);
-                        if let Some(processors) = contract_processors.get(&event.event.keys[0]) {
-                            let processor = processors
-                                .iter()
-                                .find(|p| p.validate(&event.event))
-                                .expect("Must find at least one processor for the event");
-
-                            debug!(
-                                target: LOG_TARGET,
-                                event_name = processor.event_key(),
-                                task_id = %task_id,
-                                priority = %priority,
-                                "Processing parallelized event."
-                            );
-
-                            if let Err(e) = processor
-                                .process(
-                                    &world,
-                                    &mut local_db,
-                                    event.block_number,
-                                    event.block_timestamp,
-                                    &event.event_id,
-                                    &event.event,
-                                    &event_processor_config,
-                                )
-                                .await
-                            {
-                                error!(
-                                    target: LOG_TARGET,
-                                    event_name = processor.event_key(),
-                                    error = %e,
-                                    task_id = %task_id,
-                                    priority = %priority,
-                                    "Processing parallelized event."
-                                );
-                            }
-                        }
-                    }
-
-                    Ok::<_, anyhow::Error>(())
-                }));
-            }
-
-            // Wait for all tasks in this priority level to complete before moving to next priority
-            try_join_all(handles).await?;
-        }
-
-        Ok(())
-    }
-
     async fn process_transaction_with_events(
         &mut self,
         transaction_hash: Felt,
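Note: the per-priority fan-out removed here is not gone; it reappears as TaskManager::process_tasks in task_manager.rs further down in this patch. As a minimal, self-contained sketch of the concurrency pattern itself (priorities drained in ascending order, task groups within one priority spawned concurrently behind a Semaphore), with a placeholder handle_event and a plain String event type standing in for the real torii types:

    use std::collections::{BTreeMap, HashMap};
    use std::sync::Arc;

    use anyhow::Result;
    use futures_util::future::try_join_all;
    use tokio::sync::Semaphore;

    // Placeholder for the real event/processor machinery.
    async fn handle_event(task_id: u64, event: String) -> Result<()> {
        println!("task {task_id}: {event}");
        Ok(())
    }

    async fn process_tasks(
        tasks: &mut BTreeMap<usize, HashMap<u64, Vec<String>>>,
        max_concurrent_tasks: usize,
    ) -> Result<()> {
        let semaphore = Arc::new(Semaphore::new(max_concurrent_tasks));

        // Priority levels are drained lowest-first (BTreeMap iteration order).
        for (_priority, task_group) in std::mem::take(tasks) {
            let mut handles = Vec::new();
            for (task_id, events) in task_group {
                let semaphore = semaphore.clone();
                handles.push(tokio::spawn(async move {
                    // Bound how many task groups run at once.
                    let _permit = semaphore.acquire().await?;
                    // Events of one task group stay sequential.
                    for event in events {
                        handle_event(task_id, event).await?;
                    }
                    Ok::<_, anyhow::Error>(())
                }));
            }
            // Every task of this priority finishes before the next priority starts.
            for task_result in try_join_all(handles).await? {
                task_result?;
            }
        }
        Ok(())
    }

    #[tokio::main]
    async fn main() -> Result<()> {
        let mut tasks = BTreeMap::new();
        tasks.entry(0usize)
            .or_insert_with(HashMap::new)
            .entry(7u64)
            .or_insert_with(Vec::new)
            .push("ModelRegistered".to_string());
        process_tasks(&mut tasks, 4).await
    }
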
@@ -870,50 +809,20 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
             .find(|p| p.validate(event))
             .expect("Must find atleast one processor for the event");
 
-        let (task_priority, task_identifier) = match processor.event_key().as_str() {
-            "ModelRegistered" | "EventRegistered" => {
-                let mut hasher = DefaultHasher::new();
-                event.keys.iter().for_each(|k| k.hash(&mut hasher));
-                let hash = hasher.finish();
-                (0usize, hash) // Priority 0 (highest) for model/event registration
-            }
-            "StoreSetRecord" | "StoreUpdateRecord" | "StoreUpdateMember" | "StoreDelRecord" => {
-                let mut hasher = DefaultHasher::new();
-                event.keys[1].hash(&mut hasher);
-                event.keys[2].hash(&mut hasher);
-                let hash = hasher.finish();
-                (2usize, hash) // Priority 2 (lower) for store operations
-            }
-            "EventEmitted" => {
-                let mut hasher = DefaultHasher::new();
-
-                let keys = Vec::<Felt>::cairo_deserialize(&event.data, 0).unwrap_or_else(|e| {
-                    panic!("Expected EventEmitted keys to be well formed: {:?}", e);
-                });
-
-                // selector
-                event.keys[1].hash(&mut hasher);
-                // entity id
-                let entity_id = poseidon_hash_many(&keys);
-                entity_id.hash(&mut hasher);
-
-                let hash = hasher.finish();
-                (2usize, hash) // Priority 2 for event messages
-            }
-            _ => (0, 0), // No parallelization for other events
-        };
+        let (task_priority, task_identifier) = (processor.task_priority(), processor.task_identifier(event));
 
+        // if our event can be parallelized, we add it to the task manager
         if task_identifier != 0 {
-            self.tasks.entry(task_priority).or_default().entry(task_identifier).or_default().push(
-                (
+            self.task_manager.add_parallelized_event(
+                task_priority,
+                task_identifier,
+                ParallelizedEvent {
                     contract_type,
-                    ParallelizedEvent {
-                        event_id: event_id.to_string(),
-                        event: event.clone(),
-                        block_number,
-                        block_timestamp,
-                    },
-                ),
+                    event_id: event_id.to_string(),
+                    event: event.clone(),
+                    block_number,
+                    block_timestamp,
+                },
             );
         } else {
             // Process non-parallelized events immediately
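Note: task_priority() and task_identifier() now live on the individual event processors instead of the engine-side match removed above. As a hedged illustration of what one such implementation might look like for a store-record processor, mirroring the keys[1]/keys[2] hashing the removed match arm used (the trait below is a trimmed stand-in, not the real EventProcessor trait, which carries more methods and a P: Provider parameter):

    use std::hash::{DefaultHasher, Hash, Hasher};

    use starknet::core::types::{Event, Felt};

    // Trimmed stand-in for the two methods the engine now calls on a processor.
    trait TaskPartitioning {
        fn task_priority(&self) -> usize;
        fn task_identifier(&self, event: &Event) -> u64;
    }

    struct StoreSetRecordProcessor;

    impl TaskPartitioning for StoreSetRecordProcessor {
        // Store operations ran at priority 2 in the removed match arm.
        fn task_priority(&self) -> usize {
            2
        }

        // Group by model selector and entity id (keys[1] and keys[2]) so all
        // writes touching one entity land in the same task and stay ordered.
        fn task_identifier(&self, event: &Event) -> u64 {
            let mut hasher = DefaultHasher::new();
            event.keys[1].hash(&mut hasher);
            event.keys[2].hash(&mut hasher);
            hasher.finish()
        }
    }

    fn main() {
        let event = Event {
            from_address: Felt::ZERO,
            keys: vec![Felt::ZERO, Felt::ONE, Felt::TWO],
            data: vec![],
        };
        let processor = StoreSetRecordProcessor;
        println!("priority {} id {:#x}", processor.task_priority(), processor.task_identifier(&event));
    }
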
diff --git a/crates/torii/indexer/src/task_manager.rs b/crates/torii/indexer/src/task_manager.rs
index 1dbc1b7690..5dabe3fb20 100644
--- a/crates/torii/indexer/src/task_manager.rs
+++ b/crates/torii/indexer/src/task_manager.rs
@@ -2,11 +2,17 @@ use std::{
     collections::{BTreeMap, HashMap},
     sync::Arc,
 };
-
+use anyhow::Result;
+use dojo_world::contracts::WorldContractReader;
+use futures_util::future::try_join_all;
 use starknet::{core::types::Event, providers::Provider};
-use torii_sqlite::types::ContractType;
+use tokio::sync::Semaphore;
+use torii_sqlite::{types::ContractType, Sql};
+use tracing::{debug, error};
+
+use crate::{engine::Processors, processors::EventProcessorConfig};
 
-use crate::engine::Processors;
+const LOG_TARGET: &str = "torii::indexer::task_manager";
 
 pub type TaskId = u64;
 type TaskPriority = usize;
@@ -21,47 +27,116 @@ pub struct ParallelizedEvent {
 }
 
 pub struct TaskManager<P: Provider + Send + Sync + std::fmt::Debug + 'static> {
+    db: Sql,
+    world: Arc<WorldContractReader<P>>,
     tasks: BTreeMap<TaskPriority, HashMap<TaskId, Vec<ParallelizedEvent>>>,
     processors: Arc<Processors<P>>,
+    max_concurrent_tasks: usize,
+    event_processor_config: EventProcessorConfig,
 }
 
 impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> TaskManager<P> {
-    pub fn new(processors: Arc<Processors<P>>) -> Self {
-        Self { tasks: BTreeMap::new(), processors }
-    }
-
-    pub fn add_parallelized_event(&mut self, parallelized_event: ParallelizedEvent) -> TaskId {
-        let event_key = parallelized_event.event.keys[0];
-        let processor = self
-            .processors
-            .get_event_processor(parallelized_event.contract_type)
-            .get(&event_key)
-            .unwrap()
-            .iter()
-            .find(|p| p.validate(&parallelized_event.event))
-            .unwrap();
-        let priority = processor.task_priority();
-        let task_id = processor.task_identifier(&parallelized_event.event);
-
-        if task_id != 0 {
-            self.tasks
-                .entry(priority)
-                .or_default()
-                .entry(task_id)
-                .or_default()
-                .push(parallelized_event);
+    pub fn new(
+        db: Sql,
+        world: Arc<WorldContractReader<P>>,
+        processors: Arc<Processors<P>>,
+        max_concurrent_tasks: usize,
+        event_processor_config: EventProcessorConfig,
+    ) -> Self {
+        Self {
+            db,
+            world,
+            tasks: BTreeMap::new(),
+            processors,
+            max_concurrent_tasks,
+            event_processor_config,
         }
-
-        task_id
     }
 
-    pub fn take_tasks(
+    pub fn add_parallelized_event(
         &mut self,
-    ) -> BTreeMap<TaskPriority, HashMap<TaskId, Vec<ParallelizedEvent>>> {
-        std::mem::take(&mut self.tasks)
+        priority: TaskPriority,
+        task_identifier: TaskId,
+        parallelized_event: ParallelizedEvent,
+    ) {
+        self.tasks
+            .entry(priority)
+            .or_default()
+            .entry(task_identifier)
+            .or_default()
+            .push(parallelized_event);
     }
 
-    pub fn is_empty(&self) -> bool {
-        self.tasks.is_empty()
+    pub async fn process_tasks(
+        &mut self
+    ) -> Result<()> {
+        let semaphore = Arc::new(Semaphore::new(self.max_concurrent_tasks));
+
+        // Process each priority level sequentially
+        for (priority, task_group) in std::mem::take(&mut self.tasks) {
+            let mut handles = Vec::new();
+
+            // Process all tasks within this priority level concurrently
+            for (task_id, events) in task_group {
+                let db = self.db.clone();
+                let world = self.world.clone();
+                let semaphore = semaphore.clone();
+                let processors = self.processors.clone();
+                let event_processor_config = self.event_processor_config.clone();
+
+                handles.push(tokio::spawn(async move {
+                    let _permit = semaphore.acquire().await?;
+                    let mut local_db = db.clone();
+
+                    // Process all events for this task sequentially
+                    for ParallelizedEvent { contract_type, event, block_number, block_timestamp, event_id } in events {
+                        let contract_processors = processors.get_event_processor(contract_type);
+                        if let Some(processors) = contract_processors.get(&event.keys[0]) {
+                            let processor = processors
+                                .iter()
+                                .find(|p| p.validate(&event))
+                                .expect("Must find at least one processor for the event");
+
+                            debug!(
+                                target: LOG_TARGET,
+                                event_name = processor.event_key(),
+                                task_id = %task_id,
+                                priority = %priority,
+                                "Processing parallelized event."
+                            );
+
+                            if let Err(e) = processor
+                                .process(
+                                    &world,
+                                    &mut local_db,
+                                    block_number,
+                                    block_timestamp,
+                                    &event_id,
+                                    &event,
+                                    &event_processor_config,
+                                )
+                                .await
+                            {
+                                error!(
+                                    target: LOG_TARGET,
+                                    event_name = processor.event_key(),
+                                    error = %e,
+                                    task_id = %task_id,
+                                    priority = %priority,
+                                    "Processing parallelized event."
+                                );
+                            }
+                        }
+                    }
+
+                    Ok::<_, anyhow::Error>(())
+                }));
+            }
+
+            // Wait for all tasks in this priority level to complete before moving to next priority
+            try_join_all(handles).await?;
+        }
+
+        Ok(())
     }
 }
diff --git a/crates/torii/sqlite/src/cache.rs b/crates/torii/sqlite/src/cache.rs
index bbfad566db..5b4dd71a1f 100644
--- a/crates/torii/sqlite/src/cache.rs
+++ b/crates/torii/sqlite/src/cache.rs
@@ -118,13 +118,13 @@ impl ModelCache {
 
 #[derive(Debug)]
 pub struct LocalCache {
-    pub erc_cache: HashMap<(ContractType, String), I256>,
-    pub token_id_registry: HashSet<String>,
+    pub erc_cache: RwLock<HashMap<(ContractType, String), I256>>,
+    pub token_id_registry: RwLock<HashSet<String>>,
 }
 
 impl Clone for LocalCache {
     fn clone(&self) -> Self {
-        Self { erc_cache: HashMap::new(), token_id_registry: self.token_id_registry.clone() }
+        Self { erc_cache: RwLock::new(HashMap::new()), token_id_registry: RwLock::new(HashSet::new()) }
     }
 }
 
@@ -139,14 +139,14 @@ impl LocalCache {
         let token_id_registry =
             token_id_registry.into_iter().map(|token_id| token_id.0).collect();
 
-        Self { erc_cache: HashMap::new(), token_id_registry }
+        Self { erc_cache: RwLock::new(HashMap::new()), token_id_registry: RwLock::new(token_id_registry) }
     }
 
-    pub fn contains_token_id(&self, token_id: &str) -> bool {
-        self.token_id_registry.contains(token_id)
+    pub async fn contains_token_id(&self, token_id: &str) -> bool {
+        self.token_id_registry.read().await.contains(token_id)
     }
 
-    pub fn register_token_id(&mut self, token_id: String) {
-        self.token_id_registry.insert(token_id);
+    pub async fn register_token_id(&self, token_id: String) {
+        self.token_id_registry.write().await.insert(token_id);
     }
 }
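Note: with both LocalCache fields behind an async RwLock (the read()/write() calls in the hunk above match tokio::sync::RwLock, which is assumed here), a single cache instance can be shared across concurrently spawned tasks instead of being cloned empty. A small usage sketch under that assumption, with a simplified stand-in for the token-id half of the cache:

    use std::collections::HashSet;
    use std::sync::Arc;

    use tokio::sync::RwLock;

    // Simplified stand-in for the token-id side of `LocalCache`.
    struct TokenRegistry {
        token_id_registry: RwLock<HashSet<String>>,
    }

    impl TokenRegistry {
        async fn contains_token_id(&self, token_id: &str) -> bool {
            // Shared read lock: many concurrent readers are fine.
            self.token_id_registry.read().await.contains(token_id)
        }

        async fn register_token_id(&self, token_id: String) {
            // &self is enough now; mutability comes from the write lock.
            self.token_id_registry.write().await.insert(token_id);
        }
    }

    #[tokio::main]
    async fn main() {
        let cache = Arc::new(TokenRegistry { token_id_registry: RwLock::new(HashSet::new()) });

        // The same Arc is cloned into a spawned task, much like the
        // parallelized event tasks share one Sql/LocalCache.
        let writer = {
            let cache = cache.clone();
            tokio::spawn(async move { cache.register_token_id("0x1".to_string()).await })
        };
        writer.await.unwrap();

        assert!(cache.contains_token_id("0x1").await);
    }
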
diff --git a/crates/torii/sqlite/src/erc.rs b/crates/torii/sqlite/src/erc.rs
index f11f7988c1..427a49a8d8 100644
--- a/crates/torii/sqlite/src/erc.rs
+++ b/crates/torii/sqlite/src/erc.rs
@@ -36,7 +36,7 @@ impl Sql {
         // contract_address
         let token_id = felt_to_sql_string(&contract_address);
 
-        let token_exists: bool = self.local_cache.contains_token_id(&token_id);
+        let token_exists: bool = self.local_cache.contains_token_id(&token_id).await;
 
         if !token_exists {
             self.register_erc20_token_metadata(contract_address, &token_id, provider).await?;
@@ -52,26 +52,25 @@ impl Sql {
             event_id,
         )?;
 
-        if from_address != Felt::ZERO {
-            // from_address/contract_address/
-            let from_balance_id = felts_to_sql_string(&[from_address, contract_address]);
-            let from_balance = self
-                .local_cache
-                .erc_cache
-                .entry((ContractType::ERC20, from_balance_id))
-                .or_default();
-            *from_balance -= I256::from(amount);
-        }
-
-        if to_address != Felt::ZERO {
-            let to_balance_id = felts_to_sql_string(&[to_address, contract_address]);
-            let to_balance =
-                self.local_cache.erc_cache.entry((ContractType::ERC20, to_balance_id)).or_default();
-            *to_balance += I256::from(amount);
+        {
+            let mut erc_cache = self.local_cache.erc_cache.write().await;
+            if from_address != Felt::ZERO {
+                // from_address/contract_address/
+                let from_balance_id = felts_to_sql_string(&[from_address, contract_address]);
+                let from_balance =
+                    erc_cache.entry((ContractType::ERC20, from_balance_id)).or_default();
+                *from_balance -= I256::from(amount);
+            }
+
+            if to_address != Felt::ZERO {
+                let to_balance_id = felts_to_sql_string(&[to_address, contract_address]);
+                let to_balance = erc_cache.entry((ContractType::ERC20, to_balance_id)).or_default();
+                *to_balance += I256::from(amount);
+            }
         }
 
         let block_id = BlockId::Number(block_number);
-        if self.local_cache.erc_cache.len() >= 100000 {
+        if self.local_cache.erc_cache.read().await.len() >= 100000 {
             self.flush().await.with_context(|| "Failed to flush in handle_erc20_transfer")?;
             self.apply_cache_diff(block_id).await?;
         }
@@ -93,7 +92,7 @@ impl Sql {
         // contract_address:id
         let actual_token_id = token_id;
         let token_id = felt_and_u256_to_sql_string(&contract_address, &token_id);
-        let token_exists: bool = self.local_cache.contains_token_id(&token_id);
+        let token_exists: bool = self.local_cache.contains_token_id(&token_id).await;
 
         if !token_exists {
             self.register_erc721_token_metadata(contract_address, &token_id, actual_token_id)
@@ -111,30 +110,31 @@ impl Sql {
         )?;
 
         // from_address/contract_address:id
-        if from_address != Felt::ZERO {
-            let from_balance_id =
-                format!("{}{SQL_FELT_DELIMITER}{}", felt_to_sql_string(&from_address), &token_id);
-            let from_balance = self
-                .local_cache
-                .erc_cache
-                .entry((ContractType::ERC721, from_balance_id))
-                .or_default();
-            *from_balance -= I256::from(1u8);
+        {
+            let mut erc_cache = self.local_cache.erc_cache.write().await;
+            if from_address != Felt::ZERO {
+                let from_balance_id = format!(
+                    "{}{SQL_FELT_DELIMITER}{}",
+                    felt_to_sql_string(&from_address),
+                    &token_id
+                );
+                let from_balance =
+                    erc_cache.entry((ContractType::ERC721, from_balance_id)).or_default();
+                *from_balance -= I256::from(1u8);
+            }
+
+            if to_address != Felt::ZERO {
+                let to_balance_id =
+                    format!("{}{SQL_FELT_DELIMITER}{}", felt_to_sql_string(&to_address), &token_id);
+                let to_balance =
+                    erc_cache.entry((ContractType::ERC721, to_balance_id)).or_default();
+                *to_balance += I256::from(1u8);
+            }
         }
 
-        if to_address != Felt::ZERO {
-            let to_balance_id =
-                format!("{}{SQL_FELT_DELIMITER}{}", felt_to_sql_string(&to_address), &token_id);
-            let to_balance = self
-                .local_cache
-                .erc_cache
-                .entry((ContractType::ERC721, to_balance_id))
-                .or_default();
-            *to_balance += I256::from(1u8);
-        }
-
         let block_id = BlockId::Number(block_number);
-        if self.local_cache.erc_cache.len() >= 100000 {
+        if self.local_cache.erc_cache.read().await.len() >= 100000 {
             self.flush().await.with_context(|| "Failed to flush in handle_erc721_transfer")?;
             self.apply_cache_diff(block_id).await?;
         }
 
@@ -215,7 +215,7 @@ impl Sql {
             }),
         ))?;
 
-        self.local_cache.register_token_id(token_id.to_string());
+        self.local_cache.register_token_id(token_id.to_string()).await;
 
         Ok(())
     }
@@ -240,7 +240,7 @@ impl Sql {
         // this cache is used while applying the cache diff
         // so we need to make sure that all RegisterErc*Token queries
        // are applied before the cache diff is applied
-        self.local_cache.register_token_id(token_id.to_string());
+        self.local_cache.register_token_id(token_id.to_string()).await;
 
         Ok(())
     }
@@ -279,15 +279,13 @@ impl Sql {
     }
 
     pub async fn apply_cache_diff(&mut self, block_id: BlockId) -> Result<()> {
-        if !self.local_cache.erc_cache.is_empty() {
+        if !self.local_cache.erc_cache.read().await.is_empty() {
+            let mut erc_cache = self.local_cache.erc_cache.write().await;
             self.executor.send(QueryMessage::new(
                 "".to_string(),
                 vec![],
                 QueryType::ApplyBalanceDiff(ApplyBalanceDiffQuery {
-                    erc_cache: mem::replace(
-                        &mut self.local_cache.erc_cache,
-                        HashMap::with_capacity(64),
-                    ),
+                    erc_cache: mem::replace(&mut erc_cache, HashMap::with_capacity(64)),
                     block_id,
                 }),
             ))?;
diff --git a/crates/torii/sqlite/src/lib.rs b/crates/torii/sqlite/src/lib.rs
index d8bbcc4dfa..4349b88718 100644
--- a/crates/torii/sqlite/src/lib.rs
+++ b/crates/torii/sqlite/src/lib.rs
@@ -43,8 +43,7 @@ pub struct Sql {
     pub pool: Pool<Sqlite>,
     pub executor: UnboundedSender<QueryMessage>,
     model_cache: Arc<ModelCache>,
-    // when SQL struct is cloned a empty local_cache is created
-    local_cache: LocalCache,
+    local_cache: Arc<LocalCache>,
 }
 
 #[derive(Debug, Clone)]
@@ -75,7 +74,12 @@ impl Sql {
         }
 
         let local_cache = LocalCache::new(pool.clone()).await;
-        let db = Self { pool: pool.clone(), executor, model_cache, local_cache };
+        let db = Self {
+            pool: pool.clone(),
+            executor,
+            model_cache,
+            local_cache: Arc::new(local_cache),
+        };
 
         db.execute().await?;
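Note: the tasks map in TaskManager stays keyed by priority in a BTreeMap, which is what guarantees that priority 0 work (model/event registration, per the comments removed from engine.rs) is always drained before priority 2 store and event-message work. A tiny illustration of that ordering property:

    use std::collections::BTreeMap;

    fn main() {
        // Same shape as `tasks`, with &str standing in for the event groups.
        let mut tasks: BTreeMap<usize, Vec<&str>> = BTreeMap::new();
        tasks.entry(2).or_default().push("StoreSetRecord");
        tasks.entry(0).or_default().push("ModelRegistered");

        // BTreeMap iterates keys in ascending order, so registrations come first.
        let order: Vec<usize> = tasks.keys().copied().collect();
        assert_eq!(order, vec![0, 2]);
    }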