diff --git a/Cargo.lock b/Cargo.lock index 326be72083..459bfe7ef1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1269,6 +1269,7 @@ dependencies = [ "flate2", "flatten-serde-json", "fs2", + "futures", "glob", "indoc", "k256", diff --git a/crates/cheatnet/Cargo.toml b/crates/cheatnet/Cargo.toml index d328944110..66c651b585 100644 --- a/crates/cheatnet/Cargo.toml +++ b/crates/cheatnet/Cargo.toml @@ -45,6 +45,7 @@ shared.workspace = true rand.workspace = true [dev-dependencies] +futures.workspace = true ctor.workspace = true indoc.workspace = true rayon.workspace = true diff --git a/crates/cheatnet/src/forking/state.rs b/crates/cheatnet/src/forking/state.rs index 6c5358f874..9df4971de4 100644 --- a/crates/cheatnet/src/forking/state.rs +++ b/crates/cheatnet/src/forking/state.rs @@ -29,8 +29,10 @@ use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use std::cell::RefCell; use std::collections::HashMap; +use std::future::Future; use std::io::Read; -use tokio::runtime::Runtime; +use tokio::runtime::Handle; +use tokio::task; use universal_sierra_compiler_api::{compile_sierra, SierraType}; use url::Url; @@ -38,7 +40,6 @@ use url::Url; pub struct ForkStateReader { client: JsonRpcClient, block_number: BlockNumber, - runtime: Runtime, cache: RefCell, } @@ -51,12 +52,11 @@ impl ForkStateReader { ), client: JsonRpcClient::new(HttpTransport::new(url)), block_number, - runtime: Runtime::new().expect("Could not instantiate Runtime"), }) } pub fn chain_id(&self) -> Result { - let id = self.runtime.block_on(self.client.chain_id())?; + let id = sync(self.client.chain_id())?; let id = parse_cairo_short_string(&id)?; Ok(ChainId::from(id)) } @@ -85,10 +85,7 @@ impl BlockInfoReader for ForkStateReader { return Ok(cache_hit); } - match self - .runtime - .block_on(self.client.get_block_with_tx_hashes(self.block_id())) - { + match sync(self.client.get_block_with_tx_hashes(self.block_id())) { Ok(MaybePendingBlockWithTxHashes::Block(block)) => { let block_info = BlockInfo { block_number: BlockNumber(block.block_number), @@ -125,7 +122,7 @@ impl StateReader for ForkStateReader { return Ok(cache_hit); } - match self.runtime.block_on(self.client.get_storage_at( + match sync(self.client.get_storage_at( Felt::from_(contract_address), Felt::from_(*key.0.key()), self.block_id(), @@ -149,7 +146,7 @@ impl StateReader for ForkStateReader { return Ok(cache_hit); } - match self.runtime.block_on( + match sync( self.client .get_nonce(self.block_id(), Felt::from_(contract_address)), ) { @@ -175,7 +172,7 @@ impl StateReader for ForkStateReader { return Ok(cache_hit); } - match self.runtime.block_on( + match sync( self.client .get_class_hash_at(self.block_id(), Felt::from_(contract_address)), ) { @@ -206,7 +203,7 @@ impl StateReader for ForkStateReader { if let Some(cache_hit) = cache.get_compiled_contract_class(&class_hash) { Ok(cache_hit) } else { - match self.runtime.block_on( + match sync( self.client .get_class(self.block_id(), Felt::from_(class_hash)), ) { @@ -283,3 +280,7 @@ impl StateReader for ForkStateReader { )) } } + +fn sync(action: impl Future) -> R { + task::block_in_place(move || Handle::current().block_on(action)) +} diff --git a/crates/cheatnet/tests/cheatcodes/cheat_caller_address.rs b/crates/cheatnet/tests/cheatcodes/cheat_caller_address.rs index 46adf3cf14..b4ae226927 100644 --- a/crates/cheatnet/tests/cheatcodes/cheat_caller_address.rs +++ b/crates/cheatnet/tests/cheatcodes/cheat_caller_address.rs @@ -291,8 +291,8 @@ fn cheat_caller_address_one_then_all() { ); } -#[test] -fn cheat_caller_address_cairo0_callback() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn cheat_caller_address_cairo0_callback() { let temp_dir = TempDir::new().unwrap(); let cached_state = create_fork_cached_state_at(53_631, temp_dir.path().to_str().unwrap()); let mut test_env = TestEnvironment::new(); diff --git a/crates/cheatnet/tests/cheatcodes/replace_bytecode.rs b/crates/cheatnet/tests/cheatcodes/replace_bytecode.rs index c32585ed1c..33530291cd 100644 --- a/crates/cheatnet/tests/cheatcodes/replace_bytecode.rs +++ b/crates/cheatnet/tests/cheatcodes/replace_bytecode.rs @@ -28,8 +28,8 @@ impl ReplaceBytecodeTrait for TestEnvironment { } } -#[test] -fn fork() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn fork() { let cache_dir = TempDir::new().unwrap(); let mut test_env = TestEnvironment::new(); test_env.cached_state = create_fork_cached_state_at(53_300, cache_dir.path().to_str().unwrap()); diff --git a/crates/cheatnet/tests/cheatcodes/spy_events.rs b/crates/cheatnet/tests/cheatcodes/spy_events.rs index 2496e0ac3e..289688c4d1 100644 --- a/crates/cheatnet/tests/cheatcodes/spy_events.rs +++ b/crates/cheatnet/tests/cheatcodes/spy_events.rs @@ -318,8 +318,8 @@ fn test_emitted_by_emit_events_syscall() { ); } -#[test] -fn capture_cairo0_event() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn capture_cairo0_event() { let temp_dir = TempDir::new().unwrap(); let mut cached_state = create_fork_cached_state_at(53_626, temp_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); diff --git a/crates/cheatnet/tests/starknet/cheat_fork.rs b/crates/cheatnet/tests/starknet/cheat_fork.rs index 45d6c9c996..b64693d8b7 100644 --- a/crates/cheatnet/tests/starknet/cheat_fork.rs +++ b/crates/cheatnet/tests/starknet/cheat_fork.rs @@ -13,7 +13,8 @@ const CAIRO0_TESTER_ADDRESS: &str = #[test_case("return_caller_address"; "when common call")] #[test_case("return_proxied_caller_address"; "when library call")] -fn cheat_caller_address_cairo0_contract(selector: &str) { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn cheat_caller_address_cairo0_contract(selector: &str) { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -68,7 +69,8 @@ fn cheat_caller_address_cairo0_contract(selector: &str) { #[test_case("return_block_number"; "when common call")] #[test_case("return_proxied_block_number"; "when library call")] -fn cheat_block_number_cairo0_contract(selector: &str) { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn cheat_block_number_cairo0_contract(selector: &str) { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -122,7 +124,8 @@ fn cheat_block_number_cairo0_contract(selector: &str) { #[test_case("return_block_timestamp"; "when common call")] #[test_case("return_proxied_block_timestamp"; "when library call")] -fn cheat_block_timestamp_cairo0_contract(selector: &str) { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn cheat_block_timestamp_cairo0_contract(selector: &str) { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); diff --git a/crates/cheatnet/tests/starknet/forking.rs b/crates/cheatnet/tests/starknet/forking.rs index 82af0fc512..bbb45b4642 100644 --- a/crates/cheatnet/tests/starknet/forking.rs +++ b/crates/cheatnet/tests/starknet/forking.rs @@ -12,16 +12,17 @@ use cheatnet::state::{BlockInfoReader, CheatnetState, ExtendedStateReader}; use conversions::byte_array::ByteArray; use conversions::string::TryFromHexStr; use conversions::IntoConv; -use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use futures::future::join_all; use runtime::EnhancedHintError; use serde_json::Value; use starknet_api::block::BlockNumber; use starknet_api::core::ContractAddress; use starknet_types_core::felt::Felt; use tempfile::TempDir; +use tokio::runtime::Handle; -#[test] -fn fork_simple() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn fork_simple() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -61,8 +62,8 @@ fn fork_simple() { assert_success(output, &[Felt::from(100)]); } -#[test] -fn try_calling_nonexistent_contract() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn try_calling_nonexistent_contract() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -83,8 +84,8 @@ fn try_calling_nonexistent_contract() { assert_panic(output, &panic_data_felts); } -#[test] -fn try_deploying_undeclared_class() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn try_deploying_undeclared_class() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -103,8 +104,8 @@ fn try_deploying_undeclared_class() { }); } -#[test] -fn test_forking_at_block_number() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_forking_at_block_number() { let cache_dir = TempDir::new().unwrap(); { @@ -148,8 +149,8 @@ fn test_forking_at_block_number() { purge_cache(cache_dir.path().to_str().unwrap()); } -#[test] -fn call_forked_contract_from_other_contract() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn call_forked_contract_from_other_contract() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -176,8 +177,8 @@ fn call_forked_contract_from_other_contract() { assert_success(output, &[Felt::from(0)]); } -#[test] -fn library_call_on_forked_class_hash() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn library_call_on_forked_class_hash() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -222,8 +223,8 @@ fn library_call_on_forked_class_hash() { assert_success(output, &[Felt::from(100)]); } -#[test] -fn call_forked_contract_from_constructor() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn call_forked_contract_from_constructor() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap()); let mut cheatnet_state = CheatnetState::default(); @@ -255,8 +256,8 @@ fn call_forked_contract_from_constructor() { assert_success(output, &[Felt::from(0)]); } -#[test] -fn call_forked_contract_get_block_info_via_proxy() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn call_forked_contract_get_block_info_via_proxy() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state_at(53_655, cache_dir.path().to_str().unwrap()); @@ -314,8 +315,8 @@ fn call_forked_contract_get_block_info_via_proxy() { ); } -#[test] -fn call_forked_contract_get_block_info_via_libcall() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn call_forked_contract_get_block_info_via_libcall() { let cache_dir = TempDir::new().unwrap(); let mut cached_fork_state = create_fork_cached_state_at(53_669, cache_dir.path().to_str().unwrap()); @@ -374,8 +375,8 @@ fn call_forked_contract_get_block_info_via_libcall() { ); } -#[test] -fn using_specified_block_nb_is_cached() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn using_specified_block_nb_is_cached() { let cache_dir = TempDir::new().unwrap(); let run_test = || { let mut cached_state = @@ -457,8 +458,8 @@ fn using_specified_block_nb_is_cached() { purge_cache(cache_dir.path().to_str().unwrap()); } -#[test] -fn test_cache_merging() { +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn test_cache_merging() { fn run_test(cache_dir: &str, contract_address: &str, balance: u64) { let mut cached_state = create_fork_cached_state_at(53_680, cache_dir); let _ = cached_state.state.get_block_info().unwrap(); @@ -549,27 +550,27 @@ fn test_cache_merging() { "0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8" ); }; - let cache_dir_str = cache_dir.path().to_str().unwrap(); + let cache_dir_str = cache_dir.path().to_string_lossy().to_string(); - run_test(cache_dir_str, contract_1_address, 0); - run_test(cache_dir_str, contract_2_address, 0); + run_test(&cache_dir_str, contract_1_address, 0); + run_test(&cache_dir_str, contract_2_address, 0); assert_cache(); purge_cache(cache_dir.path().to_str().unwrap()); // Parallel execution - [ - (cache_dir_str, contract_1_address, 0), - (cache_dir_str, contract_2_address, 0), - ] - .par_iter() - .for_each(|param_tpl| run_test(param_tpl.0, param_tpl.1, param_tpl.2)); + let tasks = [contract_1_address, contract_2_address].map(move |address| { + let cache_dir_str = cache_dir_str.clone(); + Handle::current().spawn(async move { run_test(&cache_dir_str, address, 0) }) + }); + + join_all(tasks).await; assert_cache(); } -#[test] -fn test_cached_block_info_merging() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_cached_block_info_merging() { fn run_test(cache_dir: &str, balance: u64, call_get_block_info: bool) { let mut cached_state = create_fork_cached_state_at(53_680, cache_dir); if call_get_block_info { @@ -636,8 +637,8 @@ fn test_cached_block_info_merging() { assert_cached_block_info(true); } -#[test] -fn test_calling_nonexistent_url() { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_calling_nonexistent_url() { let temp_dir = TempDir::new().unwrap(); let nonexistent_url = "http://nonexistent-node-address.com".parse().unwrap(); let mut cached_fork_state = CachedState::new(ExtendedStateReader {