Skip to content

Commit

Permalink
Remove second tokio runtime (#2831)
Browse files Browse the repository at this point in the history
Close #901
  • Loading branch information
Draggu authored Jan 9, 2025
1 parent 6aaeb5c commit 9967fd1
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 57 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/cheatnet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ shared.workspace = true
rand.workspace = true

[dev-dependencies]
futures.workspace = true
ctor.workspace = true
indoc.workspace = true
rayon.workspace = true
Expand Down
25 changes: 13 additions & 12 deletions crates/cheatnet/src/forking/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@ use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;
use std::cell::RefCell;
use std::collections::HashMap;
use std::future::Future;
use std::io::Read;
use tokio::runtime::Runtime;
use tokio::runtime::Handle;
use tokio::task;
use universal_sierra_compiler_api::{compile_sierra, SierraType};
use url::Url;

#[derive(Debug)]
pub struct ForkStateReader {
client: JsonRpcClient<HttpTransport>,
block_number: BlockNumber,
runtime: Runtime,
cache: RefCell<ForkCache>,
}

Expand All @@ -51,12 +52,11 @@ impl ForkStateReader {
),
client: JsonRpcClient::new(HttpTransport::new(url)),
block_number,
runtime: Runtime::new().expect("Could not instantiate Runtime"),
})
}

pub fn chain_id(&self) -> Result<ChainId> {
let id = self.runtime.block_on(self.client.chain_id())?;
let id = sync(self.client.chain_id())?;
let id = parse_cairo_short_string(&id)?;
Ok(ChainId::from(id))
}
Expand Down Expand Up @@ -85,10 +85,7 @@ impl BlockInfoReader for ForkStateReader {
return Ok(cache_hit);
}

match self
.runtime
.block_on(self.client.get_block_with_tx_hashes(self.block_id()))
{
match sync(self.client.get_block_with_tx_hashes(self.block_id())) {
Ok(MaybePendingBlockWithTxHashes::Block(block)) => {
let block_info = BlockInfo {
block_number: BlockNumber(block.block_number),
Expand Down Expand Up @@ -125,7 +122,7 @@ impl StateReader for ForkStateReader {
return Ok(cache_hit);
}

match self.runtime.block_on(self.client.get_storage_at(
match sync(self.client.get_storage_at(
Felt::from_(contract_address),
Felt::from_(*key.0.key()),
self.block_id(),
Expand All @@ -149,7 +146,7 @@ impl StateReader for ForkStateReader {
return Ok(cache_hit);
}

match self.runtime.block_on(
match sync(
self.client
.get_nonce(self.block_id(), Felt::from_(contract_address)),
) {
Expand All @@ -175,7 +172,7 @@ impl StateReader for ForkStateReader {
return Ok(cache_hit);
}

match self.runtime.block_on(
match sync(
self.client
.get_class_hash_at(self.block_id(), Felt::from_(contract_address)),
) {
Expand Down Expand Up @@ -206,7 +203,7 @@ impl StateReader for ForkStateReader {
if let Some(cache_hit) = cache.get_compiled_contract_class(&class_hash) {
Ok(cache_hit)
} else {
match self.runtime.block_on(
match sync(
self.client
.get_class(self.block_id(), Felt::from_(class_hash)),
) {
Expand Down Expand Up @@ -283,3 +280,7 @@ impl StateReader for ForkStateReader {
))
}
}

fn sync<R>(action: impl Future<Output = R>) -> R {
task::block_in_place(move || Handle::current().block_on(action))
}
4 changes: 2 additions & 2 deletions crates/cheatnet/tests/cheatcodes/cheat_caller_address.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ fn cheat_caller_address_one_then_all() {
);
}

#[test]
fn cheat_caller_address_cairo0_callback() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn cheat_caller_address_cairo0_callback() {
let temp_dir = TempDir::new().unwrap();
let cached_state = create_fork_cached_state_at(53_631, temp_dir.path().to_str().unwrap());
let mut test_env = TestEnvironment::new();
Expand Down
4 changes: 2 additions & 2 deletions crates/cheatnet/tests/cheatcodes/replace_bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ impl ReplaceBytecodeTrait for TestEnvironment {
}
}

#[test]
fn fork() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fork() {
let cache_dir = TempDir::new().unwrap();
let mut test_env = TestEnvironment::new();
test_env.cached_state = create_fork_cached_state_at(53_300, cache_dir.path().to_str().unwrap());
Expand Down
4 changes: 2 additions & 2 deletions crates/cheatnet/tests/cheatcodes/spy_events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ fn test_emitted_by_emit_events_syscall() {
);
}

#[test]
fn capture_cairo0_event() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn capture_cairo0_event() {
let temp_dir = TempDir::new().unwrap();
let mut cached_state = create_fork_cached_state_at(53_626, temp_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand Down
9 changes: 6 additions & 3 deletions crates/cheatnet/tests/starknet/cheat_fork.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ const CAIRO0_TESTER_ADDRESS: &str =

#[test_case("return_caller_address"; "when common call")]
#[test_case("return_proxied_caller_address"; "when library call")]
fn cheat_caller_address_cairo0_contract(selector: &str) {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn cheat_caller_address_cairo0_contract(selector: &str) {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand Down Expand Up @@ -68,7 +69,8 @@ fn cheat_caller_address_cairo0_contract(selector: &str) {

#[test_case("return_block_number"; "when common call")]
#[test_case("return_proxied_block_number"; "when library call")]
fn cheat_block_number_cairo0_contract(selector: &str) {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn cheat_block_number_cairo0_contract(selector: &str) {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand Down Expand Up @@ -122,7 +124,8 @@ fn cheat_block_number_cairo0_contract(selector: &str) {

#[test_case("return_block_timestamp"; "when common call")]
#[test_case("return_proxied_block_timestamp"; "when library call")]
fn cheat_block_timestamp_cairo0_contract(selector: &str) {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn cheat_block_timestamp_cairo0_contract(selector: &str) {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand Down
73 changes: 37 additions & 36 deletions crates/cheatnet/tests/starknet/forking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@ use cheatnet::state::{BlockInfoReader, CheatnetState, ExtendedStateReader};
use conversions::byte_array::ByteArray;
use conversions::string::TryFromHexStr;
use conversions::IntoConv;
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use futures::future::join_all;
use runtime::EnhancedHintError;
use serde_json::Value;
use starknet_api::block::BlockNumber;
use starknet_api::core::ContractAddress;
use starknet_types_core::felt::Felt;
use tempfile::TempDir;
use tokio::runtime::Handle;

#[test]
fn fork_simple() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fork_simple() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand Down Expand Up @@ -61,8 +62,8 @@ fn fork_simple() {
assert_success(output, &[Felt::from(100)]);
}

#[test]
fn try_calling_nonexistent_contract() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn try_calling_nonexistent_contract() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand All @@ -83,8 +84,8 @@ fn try_calling_nonexistent_contract() {
assert_panic(output, &panic_data_felts);
}

#[test]
fn try_deploying_undeclared_class() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn try_deploying_undeclared_class() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand All @@ -103,8 +104,8 @@ fn try_deploying_undeclared_class() {
});
}

#[test]
fn test_forking_at_block_number() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_forking_at_block_number() {
let cache_dir = TempDir::new().unwrap();

{
Expand Down Expand Up @@ -148,8 +149,8 @@ fn test_forking_at_block_number() {
purge_cache(cache_dir.path().to_str().unwrap());
}

#[test]
fn call_forked_contract_from_other_contract() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_forked_contract_from_other_contract() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand All @@ -176,8 +177,8 @@ fn call_forked_contract_from_other_contract() {
assert_success(output, &[Felt::from(0)]);
}

#[test]
fn library_call_on_forked_class_hash() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn library_call_on_forked_class_hash() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand Down Expand Up @@ -222,8 +223,8 @@ fn library_call_on_forked_class_hash() {
assert_success(output, &[Felt::from(100)]);
}

#[test]
fn call_forked_contract_from_constructor() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_forked_contract_from_constructor() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state = create_fork_cached_state(cache_dir.path().to_str().unwrap());
let mut cheatnet_state = CheatnetState::default();
Expand Down Expand Up @@ -255,8 +256,8 @@ fn call_forked_contract_from_constructor() {
assert_success(output, &[Felt::from(0)]);
}

#[test]
fn call_forked_contract_get_block_info_via_proxy() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_forked_contract_get_block_info_via_proxy() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state =
create_fork_cached_state_at(53_655, cache_dir.path().to_str().unwrap());
Expand Down Expand Up @@ -314,8 +315,8 @@ fn call_forked_contract_get_block_info_via_proxy() {
);
}

#[test]
fn call_forked_contract_get_block_info_via_libcall() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn call_forked_contract_get_block_info_via_libcall() {
let cache_dir = TempDir::new().unwrap();
let mut cached_fork_state =
create_fork_cached_state_at(53_669, cache_dir.path().to_str().unwrap());
Expand Down Expand Up @@ -374,8 +375,8 @@ fn call_forked_contract_get_block_info_via_libcall() {
);
}

#[test]
fn using_specified_block_nb_is_cached() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn using_specified_block_nb_is_cached() {
let cache_dir = TempDir::new().unwrap();
let run_test = || {
let mut cached_state =
Expand Down Expand Up @@ -457,8 +458,8 @@ fn using_specified_block_nb_is_cached() {
purge_cache(cache_dir.path().to_str().unwrap());
}

#[test]
fn test_cache_merging() {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_cache_merging() {
fn run_test(cache_dir: &str, contract_address: &str, balance: u64) {
let mut cached_state = create_fork_cached_state_at(53_680, cache_dir);
let _ = cached_state.state.get_block_info().unwrap();
Expand Down Expand Up @@ -549,27 +550,27 @@ fn test_cache_merging() {
"0x1176a1bd84444c89232ec27754698e5d2e7e1a7f1539f12027f28b23ec9f3d8"
);
};
let cache_dir_str = cache_dir.path().to_str().unwrap();
let cache_dir_str = cache_dir.path().to_string_lossy().to_string();

run_test(cache_dir_str, contract_1_address, 0);
run_test(cache_dir_str, contract_2_address, 0);
run_test(&cache_dir_str, contract_1_address, 0);
run_test(&cache_dir_str, contract_2_address, 0);
assert_cache();

purge_cache(cache_dir.path().to_str().unwrap());

// Parallel execution
[
(cache_dir_str, contract_1_address, 0),
(cache_dir_str, contract_2_address, 0),
]
.par_iter()
.for_each(|param_tpl| run_test(param_tpl.0, param_tpl.1, param_tpl.2));
let tasks = [contract_1_address, contract_2_address].map(move |address| {
let cache_dir_str = cache_dir_str.clone();
Handle::current().spawn(async move { run_test(&cache_dir_str, address, 0) })
});

join_all(tasks).await;

assert_cache();
}

#[test]
fn test_cached_block_info_merging() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_cached_block_info_merging() {
fn run_test(cache_dir: &str, balance: u64, call_get_block_info: bool) {
let mut cached_state = create_fork_cached_state_at(53_680, cache_dir);
if call_get_block_info {
Expand Down Expand Up @@ -636,8 +637,8 @@ fn test_cached_block_info_merging() {
assert_cached_block_info(true);
}

#[test]
fn test_calling_nonexistent_url() {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_calling_nonexistent_url() {
let temp_dir = TempDir::new().unwrap();
let nonexistent_url = "http://nonexistent-node-address.com".parse().unwrap();
let mut cached_fork_state = CachedState::new(ExtendedStateReader {
Expand Down

0 comments on commit 9967fd1

Please sign in to comment.