From e0c6add2037cba1c2633bc1618c5b49f0aa15d2d Mon Sep 17 00:00:00 2001 From: bennyhodl Date: Tue, 26 Nov 2024 13:13:14 -0500 Subject: [PATCH] async manager --- ddk-manager/Cargo.toml | 4 +- ddk-manager/src/lib.rs | 5 +- ddk-manager/src/manager.rs | 140 +++-- ddk-manager/tests/channel_execution_tests.rs | 526 ++++++++++--------- ddk-manager/tests/manager_execution_tests.rs | 399 +++++++------- ddk-manager/tests/test_utils.rs | 64 +-- mocks/Cargo.toml | 3 +- mocks/src/mock_oracle_provider.rs | 5 +- p2pd-oracle-client/Cargo.toml | 5 +- p2pd-oracle-client/src/lib.rs | 48 +- sample/src/cli.rs | 186 ++++--- sample/src/main.rs | 11 +- 12 files changed, 767 insertions(+), 629 deletions(-) diff --git a/ddk-manager/Cargo.toml b/ddk-manager/Cargo.toml index 3c1cf80a..456589bb 100644 --- a/ddk-manager/Cargo.toml +++ b/ddk-manager/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Crypto Garage", "benny b "] description = "Creation and handling of Discrete Log Contracts (DLC)." -edition = "2018" +edition = "2021" homepage = "https://github.com/bennyhodl/rust-dlc" license-file = "../LICENSE" name = "ddk-manager" @@ -21,6 +21,7 @@ bitcoin = { version = "0.32.2", default-features = false } ddk-dlc = { version = "0.7.0", default-features = false, path = "../ddk-dlc" } ddk-messages = { version = "0.7.0", default-features = false, path = "../ddk-messages" } ddk-trie = { version = "0.7.0", default-features = false, path = "../ddk-trie" } +futures = "0.3.31" hex = { package = "hex-conservative", version = "0.1" } lightning = { version = "0.0.125", default-features = false, features = ["grind_signatures"] } log = "0.4.14" @@ -43,6 +44,7 @@ secp256k1-zkp = {version = "0.11.0", features = ["hashes", "rand", "rand-std", " serde = "1.0" serde_json = "1.0" simple-wallet = {path = "../simple-wallet"} +tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread", "test-util"] } [[bench]] harness = false diff --git a/ddk-manager/src/lib.rs b/ddk-manager/src/lib.rs index 62670a1b..bca5d42b 100644 --- a/ddk-manager/src/lib.rs +++ b/ddk-manager/src/lib.rs @@ -225,14 +225,15 @@ pub trait Storage { fn get_chain_monitor(&self) -> Result, Error>; } +#[async_trait::async_trait] /// Oracle trait provides access to oracle information. pub trait Oracle { /// Returns the public key of the oracle. fn get_public_key(&self) -> XOnlyPublicKey; /// Returns the announcement for the event with the given id if found. - fn get_announcement(&self, event_id: &str) -> Result; + async fn get_announcement(&self, event_id: &str) -> Result; /// Returns the attestation for the event with the given id if found. - fn get_attestation(&self, event_id: &str) -> Result; + async fn get_attestation(&self, event_id: &str) -> Result; } /// Represents a UTXO. diff --git a/ddk-manager/src/manager.rs b/ddk-manager/src/manager.rs index 860549ed..8778784b 100644 --- a/ddk-manager/src/manager.rs +++ b/ddk-manager/src/manager.rs @@ -31,6 +31,9 @@ use ddk_messages::channel::{ }; use ddk_messages::oracle_msgs::{OracleAnnouncement, OracleAttestation}; use ddk_messages::{AcceptDlc, Message as DlcMessage, OfferDlc, SignDlc}; +use futures::stream; +use futures::stream::FuturesUnordered; +use futures::{StreamExt, TryStreamExt}; use hex::DisplayHex; use lightning::chain::chaininterface::FeeEstimator; use lightning::ln::chan_utils::{ @@ -269,16 +272,13 @@ where /// and an OfferDlc message returned. /// /// This function will fetch the oracle announcements from the oracle. - pub fn send_offer( + pub async fn send_offer( &self, contract_input: &ContractInput, counter_party: PublicKey, ) -> Result { - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + // If the oracle announcement fails to retrieve, then log and continue. + let oracle_announcements = self.oracle_announcements(contract_input).await?; self.send_offer_with_announcements(contract_input, counter_party, oracle_announcements) } @@ -375,13 +375,13 @@ where /// Function to call to check the state of the currently executing DLCs and /// update them if possible. - pub fn periodic_check(&self, check_channels: bool) -> Result<(), Error> { + pub async fn periodic_check(&self, check_channels: bool) -> Result<(), Error> { self.check_signed_contracts()?; - self.check_confirmed_contracts()?; + self.check_confirmed_contracts().await?; self.check_preclosed_contracts()?; if check_channels { - self.channel_checks()?; + self.channel_checks().await?; } Ok(()) @@ -470,7 +470,7 @@ where Ok(()) } - fn get_oracle_announcements( + async fn get_oracle_announcements( &self, oracle_inputs: &OracleInput, ) -> Result, Error> { @@ -480,7 +480,8 @@ where .oracles .get(pubkey) .ok_or_else(|| Error::InvalidParameters("Unknown oracle public key".to_string()))?; - announcements.push(oracle.get_announcement(&oracle_inputs.event_id)?.clone()); + let announcement = oracle.get_announcement(&oracle_inputs.event_id).await?; + announcements.push(announcement); } Ok(announcements) @@ -547,13 +548,13 @@ where Ok(()) } - fn check_confirmed_contracts(&self) -> Result<(), Error> { + async fn check_confirmed_contracts(&self) -> Result<(), Error> { for c in self.store.get_confirmed_contracts()? { // Confirmed contracts from channel are processed in channel specific methods. if c.channel_id.is_some() { continue; } - if let Err(e) = self.check_confirmed_contract(&c) { + if let Err(e) = self.check_confirmed_contract(&c).await { error!( "Error checking confirmed contract {}: {}", c.accepted_contract.get_contract_id_string(), @@ -565,7 +566,7 @@ where Ok(()) } - fn get_closable_contract_info<'a>( + async fn get_closable_contract_info<'a>( &'a self, contract: &'a SignedContract, ) -> ClosableContractInfo<'a> { @@ -581,26 +582,54 @@ where .enumerate() .collect(); if matured.len() >= contract_info.threshold { - let attestations: Vec<_> = matured - .iter() - .filter_map(|(i, announcement)| { - let oracle = self.oracles.get(&announcement.oracle_public_key)?; - let attestation = oracle + let attestations = stream::iter(matured.iter()) + .map(|(i, announcement)| async move { + // First try to get the oracle + let oracle = match self.oracles.get(&announcement.oracle_public_key) { + Some(oracle) => oracle, + None => { + log::debug!( + "Oracle not found for key: {}", + announcement.oracle_public_key + ); + return None; + } + }; + + // Then try to get the attestation + let attestation = match oracle .get_attestation(&announcement.oracle_event.event_id) - .ok()?; - attestation - .validate(&self.secp, announcement) - .map_err(|_| { + .await + { + Ok(attestation) => attestation, + Err(e) => { log::error!( - "Oracle attestation is not valid. pubkey={} event_id={}", - announcement.oracle_public_key, - announcement.oracle_event.event_id - ) - }) - .ok()?; + "Attestation not found for event. id={} error={}", + announcement.oracle_event.event_id, + e.to_string() + ); + return None; + } + }; + + // Validate the attestation + if let Err(e) = attestation.validate(&self.secp, announcement) { + log::error!( + "Oracle attestation is not valid. pubkey={} event_id={}, error={:?}", + announcement.oracle_public_key, + announcement.oracle_event.event_id, + e + ); + return None; + } + Some((*i, attestation)) }) - .collect(); + .collect::>() + .await + .filter_map(|result| async move { result }) // Filter out None values + .collect::>() + .await; if attestations.len() >= contract_info.threshold { return Some((contract_info, adaptor_info, attestations)); } @@ -609,8 +638,8 @@ where None } - fn check_confirmed_contract(&self, contract: &SignedContract) -> Result<(), Error> { - let closable_contract_info = self.get_closable_contract_info(contract); + async fn check_confirmed_contract(&self, contract: &SignedContract) -> Result<(), Error> { + let closable_contract_info = self.get_closable_contract_info(contract).await; if let Some((contract_info, adaptor_info, attestations)) = closable_contract_info { let offer = &contract.accepted_contract.offered_contract; let signer = self.signer_provider.derive_contract_signer(offer.keys_id)?; @@ -910,16 +939,12 @@ where { /// Create a new channel offer and return the [`dlc_messages::channel::OfferChannel`] /// message to be sent to the `counter_party`. - pub fn offer_channel( + pub async fn offer_channel( &self, contract_input: &ContractInput, counter_party: PublicKey, ) -> Result { - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + let oracle_announcements = self.oracle_announcements(contract_input).await?; let (offered_channel, offered_contract) = crate::channel_updater::offer_channel( &self.secp, @@ -1092,7 +1117,7 @@ where /// Returns a [`RenewOffer`] message as well as the [`PublicKey`] of the /// counter party's node to offer the establishment of a new contract in the /// channel. - pub fn renew_offer( + pub async fn renew_offer( &self, channel_id: &ChannelId, counter_payout: u64, @@ -1101,11 +1126,7 @@ where let mut signed_channel = get_channel_in_state!(self, channel_id, Signed, None as Option)?; - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + let oracle_announcements = self.oracle_announcements(contract_input).await?; let (msg, offered_contract) = crate::channel_updater::renew_offer( &self.secp, @@ -1300,7 +1321,7 @@ where Ok(()) } - fn try_finalize_closing_established_channel( + async fn try_finalize_closing_established_channel( &self, signed_channel: SignedChannel, ) -> Result<(), Error> { @@ -1327,6 +1348,7 @@ where let (contract_info, adaptor_info, attestations) = self .get_closable_contract_info(&confirmed_contract) + .await .ok_or_else(|| { Error::InvalidState("Could not get information to close contract".to_string()) })?; @@ -2087,13 +2109,13 @@ where Ok(()) } - fn channel_checks(&self) -> Result<(), Error> { + async fn channel_checks(&self) -> Result<(), Error> { let established_closing_channels = self .store .get_signed_channels(Some(SignedChannelStateType::Closing))?; for channel in established_closing_channels { - if let Err(e) = self.try_finalize_closing_established_channel(channel) { + if let Err(e) = self.try_finalize_closing_established_channel(channel).await { error!("Error trying to close established channel: {}", e); } } @@ -2590,6 +2612,30 @@ where pnl, }) } + + async fn oracle_announcements( + &self, + contract_input: &ContractInput, + ) -> Result>, Error> { + let announcements = stream::iter(contract_input.contract_infos.iter()) + .map(|x| { + let future = self.get_oracle_announcements(&x.oracles); + async move { + match future.await { + Ok(result) => Ok(result), + Err(e) => { + log::error!("Failed to get oracle announcements: {}", e); + Err(e) + } + } + } + }) + .collect::>() + .await + .try_collect::>() + .await?; + Ok(announcements) + } } #[cfg(test)] diff --git a/ddk-manager/tests/channel_execution_tests.rs b/ddk-manager/tests/channel_execution_tests.rs index abcfeae0..6a281008 100644 --- a/ddk-manager/tests/channel_execution_tests.rs +++ b/ddk-manager/tests/channel_execution_tests.rs @@ -22,18 +22,15 @@ use secp256k1_zkp::EcdsaAdaptorSignature; use simple_wallet::SimpleWallet; use test_utils::{get_enum_test_params, TestParams}; -use std::sync::mpsc::{sync_channel, Receiver, Sender}; -use std::thread; - -use std::time::Duration; use std::{ collections::HashMap, sync::{ atomic::{AtomicBool, Ordering}, - mpsc::channel, - Arc, Mutex, + Arc, }, }; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::Mutex; use crate::test_utils::{refresh_wallet, EVENT_MATURITY}; @@ -57,10 +54,13 @@ type DlcParty = Arc< >, >; -fn get_established_channel_contract_id(dlc_party: &DlcParty, channel_id: &ChannelId) -> ContractId { +async fn get_established_channel_contract_id( + dlc_party: &DlcParty, + channel_id: &ChannelId, +) -> ContractId { let channel = dlc_party .lock() - .unwrap() + .await .get_store() .get_channel(channel_id) .unwrap() @@ -82,11 +82,11 @@ fn alter_adaptor_sig(input: &EcdsaAdaptorSignature) -> EcdsaAdaptorSignature { /// We wrap updating the state of the chain monitor and calling the /// `Manager::periodic_check` because the latter will only be aware of /// newly confirmed transactions if the former processes new blocks. -fn periodic_check(dlc_party: DlcParty) { - let dlc_manager = dlc_party.lock().unwrap(); +async fn periodic_check(dlc_party: DlcParty) { + let dlc_manager = dlc_party.lock().await; dlc_manager.periodic_chain_monitor().unwrap(); - dlc_manager.periodic_check(true).unwrap(); + dlc_manager.periodic_check(true).await.unwrap(); } #[derive(Eq, PartialEq, Clone)] @@ -115,180 +115,192 @@ enum TestPath { CancelOffer, } -#[test] +#[tokio::test] #[ignore] -fn channel_established_close_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::Close); +async fn channel_established_close_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::Close).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_bad_accept_buffer_adaptor_test() { +async fn channel_bad_accept_buffer_adaptor_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::BadAcceptBufferAdaptorSignature, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_bad_sign_buffer_adaptor_test() { +async fn channel_bad_sign_buffer_adaptor_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::BadSignBufferAdaptorSignature, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settled_close_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleClose); +async fn channel_settled_close_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleClose).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_punish_buffer_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::BufferCheat); +async fn channel_punish_buffer_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::BufferCheat).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_close_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewedClose); +async fn channel_renew_close_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewedClose).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_established_close_test() { +async fn channel_renew_established_close_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewEstablishedClose, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_cheat_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleCheat); +async fn channel_settle_cheat_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleCheat).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_collaborative_close_test() { +async fn channel_collaborative_close_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::CollaborativeClose, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_renew_settle_test() { +async fn channel_settle_renew_settle_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleRenewSettle, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_offer_timeout_test() { +async fn channel_settle_offer_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleOfferTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_accept_timeout_test() { +async fn channel_settle_accept_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleAcceptTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_confirm_timeout_test() { +async fn channel_settle_confirm_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleConfirmTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_reject_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleReject); +async fn channel_settle_reject_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleReject).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_race_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleRace); +async fn channel_settle_race_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleRace).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_offer_timeout_test() { +async fn channel_renew_offer_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewOfferTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_accept_timeout_test() { +async fn channel_renew_accept_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewAcceptTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_confirm_timeout_test() { +async fn channel_renew_confirm_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewConfirmTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_finalize_timeout_test() { +async fn channel_renew_finalize_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewFinalizeTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_reject_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewReject); +async fn channel_renew_reject_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewReject).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_race_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewRace); +async fn channel_renew_race_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewRace).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_offer_reject_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::CancelOffer); +async fn channel_offer_reject_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::CancelOffer).await; } -fn channel_execution_test(test_params: TestParams, path: TestPath) { +async fn channel_execution_test(test_params: TestParams, path: TestPath) { env_logger::init(); - let (alice_send, bob_receive) = channel::>(); - let (bob_send, alice_receive) = channel::>(); - let (alice_sync_send, alice_sync_receive) = sync_channel::<()>(0); - let (bob_sync_send, bob_sync_receive) = sync_channel::<()>(0); + let (alice_send, mut bob_receive) = channel::>(100); + let (bob_send, mut alice_receive) = channel::>(100); + let (alice_sync_send, mut alice_sync_receive) = channel::<()>(100); + let (bob_sync_send, mut bob_sync_receive) = channel::<()>(100); let (_, _, sink_rpc) = init_clients(); @@ -476,43 +488,51 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { let offer_msg = bob_manager_send .lock() - .unwrap() + .await .offer_channel( &test_params.contract_input, "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" .parse() .unwrap(), ) + .await .expect("Send offer error"); let temporary_channel_id = offer_msg.temporary_channel_id; bob_send .send(Some(Message::OfferChannel(offer_msg))) + .await .unwrap(); assert_channel_state!(bob_manager_send, temporary_channel_id, Offered); - alice_sync_receive.recv().expect("Error synchronizing"); + alice_sync_receive + .recv() + .await + .expect("Error synchronizing"); assert_channel_state!(alice_manager_send, temporary_channel_id, Offered); if let TestPath::CancelOffer = path { let (reject_msg, _) = alice_manager_send .lock() - .unwrap() + .await .reject_channel(&temporary_channel_id) .expect("Error rejecting contract offer"); assert_channel_state!(alice_manager_send, temporary_channel_id, Cancelled); - alice_send.send(Some(Message::Reject(reject_msg))).unwrap(); + alice_send + .send(Some(Message::Reject(reject_msg))) + .await + .unwrap(); - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(bob_manager_send, temporary_channel_id, Cancelled); return; } let (mut accept_msg, channel_id, contract_id, _) = alice_manager_send .lock() - .unwrap() + .await .accept_channel(&temporary_channel_id) .expect("Error accepting contract offer"); assert_channel_state!(alice_manager_send, channel_id, Accepted); @@ -524,30 +544,39 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { bob_expect_error.store(true, Ordering::Relaxed); alice_send .send(Some(Message::AcceptChannel(accept_msg))) + .await .unwrap(); - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(bob_manager_send, temporary_channel_id, FailedAccept); } TestPath::BadSignBufferAdaptorSignature => { alice_expect_error.store(true, Ordering::Relaxed); alice_send .send(Some(Message::AcceptChannel(accept_msg))) + .await .unwrap(); // Bob receives accept message - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); // Alice receives sign message - alice_sync_receive.recv().expect("Error synchronizing"); + alice_sync_receive + .recv() + .await + .expect("Error synchronizing"); assert_channel_state!(alice_manager_send, channel_id, FailedSign); } _ => { alice_send .send(Some(Message::AcceptChannel(accept_msg))) + .await .unwrap(); - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(bob_manager_send, channel_id, Signed, Established); - alice_sync_receive.recv().expect("Error synchronizing"); + alice_sync_receive + .recv() + .await + .expect("Error synchronizing"); assert_channel_state!(alice_manager_send, channel_id, Signed, Established); @@ -555,9 +584,9 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { mocks::mock_time::set_time((EVENT_MATURITY as u64) + 1); - periodic_check(alice_manager_send.clone()); + periodic_check(alice_manager_send.clone()).await; - periodic_check(bob_manager_send.clone()); + periodic_check(bob_manager_send.clone()).await; assert_contract_state!(alice_manager_send, contract_id, Confirmed); assert_contract_state!(bob_manager_send, contract_id, Confirmed); @@ -568,25 +597,25 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { ( alice_manager_send, &alice_send, - &alice_sync_receive, + &mut alice_sync_receive, bob_manager_send, &bob_send, - &bob_sync_receive, + &mut bob_sync_receive, ) } else { ( bob_manager_send, &bob_send, - &bob_sync_receive, + &mut bob_sync_receive, alice_manager_send, &alice_send, - &alice_sync_receive, + &mut alice_sync_receive, ) }; match path { TestPath::Close => { - close_established_channel(first, second, channel_id, &generate_blocks); + close_established_channel(first, second, channel_id, &generate_blocks).await; } TestPath::CollaborativeClose => { collaborative_close( @@ -596,7 +625,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { channel_id, second_receive, &generate_blocks, - ); + ) + .await; } TestPath::SettleOfferTimeout | TestPath::SettleAcceptTimeout @@ -610,7 +640,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_receive, channel_id, path, - ); + ) + .await; } TestPath::SettleReject => { settle_reject( @@ -621,7 +652,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } TestPath::SettleRace => { settle_race( @@ -632,7 +664,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } _ => { // Shuffle positions @@ -657,7 +690,7 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { ) }; - first.lock().unwrap().get_store().save(); + first.lock().await.get_store().save(); if let TestPath::RenewEstablishedClose = path { } else { @@ -669,7 +702,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } match path { @@ -682,12 +716,12 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { closer .lock() - .unwrap() + .await .force_close_channel(&channel_id) .expect("to be able to unilaterally close the channel."); } TestPath::BufferCheat => { - cheat_punish(first, second, channel_id, &generate_blocks, true); + cheat_punish(first, second, channel_id, &generate_blocks, true).await; } TestPath::RenewOfferTimeout | TestPath::RenewAcceptTimeout @@ -703,7 +737,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { &test_params.contract_input, path, &generate_blocks, - ); + ) + .await; } TestPath::RenewReject => { renew_reject( @@ -715,7 +750,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_receive, channel_id, &test_params.contract_input, - ); + ) + .await; } TestPath::RenewRace => { renew_race( @@ -727,12 +763,13 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_receive, channel_id, &test_params.contract_input, - ); + ) + .await; } TestPath::RenewedClose | TestPath::SettleCheat | TestPath::RenewEstablishedClose => { - first.lock().unwrap().get_store().save(); + first.lock().await.get_store().save(); let check_prev_contract_close = if let TestPath::RenewEstablishedClose = path { @@ -751,7 +788,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { channel_id, &test_params.contract_input, check_prev_contract_close, - ); + ) + .await; if let TestPath::RenewedClose = path { close_established_channel( @@ -759,9 +797,11 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second, channel_id, &generate_blocks, - ); + ) + .await; } else if let TestPath::SettleCheat = path { - cheat_punish(first, second, channel_id, &generate_blocks, false); + cheat_punish(first, second, channel_id, &generate_blocks, false) + .await; } } TestPath::SettleRenewSettle => { @@ -775,7 +815,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { channel_id, &test_params.contract_input, false, - ); + ) + .await; settle_channel( first, @@ -785,7 +826,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } _ => (), } @@ -794,14 +836,14 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { } } - alice_send.send(None).unwrap(); - bob_send.send(None).unwrap(); + alice_send.send(None).await.unwrap(); + bob_send.send(None).await.unwrap(); - alice_handle.join().unwrap(); - bob_handle.join().unwrap(); + alice_handle.await.unwrap(); + bob_handle.await.unwrap(); } -fn close_established_channel( +async fn close_established_channel( first: DlcParty, second: DlcParty, channel_id: ChannelId, @@ -811,31 +853,31 @@ fn close_established_channel( { first .lock() - .unwrap() + .await .force_close_channel(&channel_id) .expect("to be able to unilaterally close."); assert_channel_state!(first, channel_id, Signed, Closing); - let contract_id = get_established_channel_contract_id(&first, &channel_id); + let contract_id = get_established_channel_contract_id(&first, &channel_id).await; - periodic_check(first.clone()); + periodic_check(first.clone()).await; let wait = ddk_manager::manager::CET_NSEQUENCE; generate_blocks(10); - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, Signed, Closing); - periodic_check(first.clone()); + periodic_check(first.clone()).await; // Should not have changed state before the CET is spendable. assert_channel_state!(first, channel_id, Signed, Closing); generate_blocks(wait as u64 - 9); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Closed); @@ -843,70 +885,71 @@ fn close_established_channel( generate_blocks(1); - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, CounterClosed); assert_contract_state!(second, contract_id, PreClosed); generate_blocks(5); - periodic_check(first.clone()); - periodic_check(second.clone()); + periodic_check(first.clone()).await; + periodic_check(second.clone()).await; assert_contract_state!(first, contract_id, Closed); assert_contract_state!(second, contract_id, Closed); } -fn cheat_punish( +async fn cheat_punish( first: DlcParty, second: DlcParty, channel_id: ChannelId, generate_blocks: &F, established: bool, ) { - first.lock().unwrap().get_store().rollback(); + first.lock().await.get_store().rollback(); if established { first .lock() - .unwrap() + .await .force_close_channel(&channel_id) .expect("the cheater to be able to close on established"); } else { first .lock() - .unwrap() + .await .force_close_channel(&channel_id) .expect("the cheater to be able to close on settled"); } generate_blocks(2); - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, ClosedPunished); } -fn settle_channel( +async fn settle_channel( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, SettledOffered); @@ -914,46 +957,48 @@ fn settle_channel( let (settle_accept, _) = second .lock() - .unwrap() + .await .accept_settle_offer(&channel_id) .expect("to be able to accept a settlement offer"); second_send .send(Some(Message::SettleAccept(settle_accept))) + .await .unwrap(); // Process Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); // Process Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); // Process Finalize - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, Settled); assert_channel_state!(second, channel_id, Signed, Settled); } -fn settle_reject( +async fn settle_reject( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to reject a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, SettledOffered); @@ -961,128 +1006,126 @@ fn settle_reject( let (settle_reject, _) = second .lock() - .unwrap() + .await .reject_settle_offer(&channel_id) .expect("to be able to reject a settlement offer"); second_send .send(Some(Message::Reject(settle_reject))) + .await .unwrap(); - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, Established); assert_channel_state!(second, channel_id, Signed, Established); } -fn settle_race( +async fn settle_race( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); let (settle_offer_2, _) = second .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); second_send .send(Some(Message::SettleOffer(settle_offer_2))) + .await .unwrap(); // Process 2 offers + 2 rejects - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 1"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 2"); - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 3"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 4"); + first_receive.recv().await.expect("Error synchronizing 1"); + second_receive.recv().await.expect("Error synchronizing 2"); + first_receive.recv().await.expect("Error synchronizing 3"); + second_receive.recv().await.expect("Error synchronizing 4"); assert_channel_state!(first, channel_id, Signed, Established); assert_channel_state!(second, channel_id, Signed, Established); } -fn renew_channel( +async fn renew_channel( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, check_prev_contract_close: bool, ) { let prev_contract_id = if check_prev_contract_close { - Some(get_established_channel_contract_id(&first, &channel_id)) + Some(get_established_channel_contract_id(&first, &channel_id).await) } else { None }; let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::ACCEPT_COLLATERAL, contract_input) + .await .expect("to be able to renew channel contract"); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .expect("to be able to send the renew offer"); // Process Renew Offer - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, RenewOffered); assert_channel_state!(second, channel_id, Signed, RenewOffered); let (accept_renew, _) = second .lock() - .unwrap() + .await .accept_renew_offer(&channel_id) .expect("to be able to accept the renewal"); second_send .send(Some(Message::RenewAccept(accept_renew))) + .await .expect("to be able to send the accept renew"); // Process Renew Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, RenewConfirmed); // Process Renew Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); // Process Renew Finalize - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); // Process Renew Revoke - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); if let Some(prev_contract_id) = prev_contract_id { assert_contract_state!(first, prev_contract_id, Closed); assert_contract_state!(second, prev_contract_id, Closed); } - let new_contract_id = get_established_channel_contract_id(&first, &channel_id); + let new_contract_id = get_established_channel_contract_id(&first, &channel_id).await; assert_channel_state!(first, channel_id, Signed, Established); assert_contract_state!(first, new_contract_id, Confirmed); @@ -1090,62 +1133,66 @@ fn renew_channel( assert_contract_state!(second, new_contract_id, Confirmed); } -fn renew_reject( +async fn renew_reject( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, ) { let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::ACCEPT_COLLATERAL, contract_input) + .await .expect("to be able to renew channel contract"); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .expect("to be able to send the renew offer"); // Process Renew Offer - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, RenewOffered); assert_channel_state!(second, channel_id, Signed, RenewOffered); let (renew_reject, _) = second .lock() - .unwrap() + .await .reject_renew_offer(&channel_id) .expect("to be able to reject the renewal"); second_send .send(Some(Message::Reject(renew_reject))) + .await .expect("to be able to send the renew reject"); // Process Renew Reject - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, Settled); assert_channel_state!(second, channel_id, Signed, Settled); } -fn renew_race( +async fn renew_race( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, ) { let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::OFFER_COLLATERAL, contract_input) + .await .expect("to be able to renew channel contract"); let mut contract_input_2 = contract_input.clone(); @@ -1154,61 +1201,57 @@ fn renew_race( let (renew_offer_2, _) = second .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::OFFER_COLLATERAL, &contract_input_2) + .await .expect("to be able to renew channel contract"); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .expect("to be able to send the renew offer"); second_send .send(Some(Message::RenewOffer(renew_offer_2))) + .await .expect("to be able to send the renew offer"); // Process 2 offers + 2 rejects - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 1"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 2"); - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 3"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 4"); + first_receive.recv().await.expect("Error synchronizing 1"); + second_receive.recv().await.expect("Error synchronizing 2"); + first_receive.recv().await.expect("Error synchronizing 3"); + second_receive.recv().await.expect("Error synchronizing 4"); assert_channel_state!(first, channel_id, Signed, Settled); assert_channel_state!(second, channel_id, Signed, Settled); } -fn collaborative_close( +async fn collaborative_close( first: DlcParty, first_send: &Sender>, second: DlcParty, channel_id: ChannelId, - sync_receive: &Receiver<()>, + sync_receive: &mut Receiver<()>, generate_blocks: &F, ) { - let contract_id = get_established_channel_contract_id(&first, &channel_id); + let contract_id = get_established_channel_contract_id(&first, &channel_id).await; let close_offer = first .lock() - .unwrap() + .await .offer_collaborative_close(&channel_id, 100000000) .expect("to be able to propose a collaborative close"); first_send .send(Some(Message::CollaborativeCloseOffer(close_offer))) + .await .expect("to be able to send collaborative close"); - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, CollaborativeCloseOffered); assert_channel_state!(second, channel_id, Signed, CollaborativeCloseOffered); second .lock() - .unwrap() + .await .accept_collaborative_close(&channel_id) .expect("to be able to accept a collaborative close"); @@ -1217,19 +1260,19 @@ fn collaborative_close( generate_blocks(2); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, CollaborativelyClosed); assert_contract_state!(first, contract_id, Closed); } -fn renew_timeout( +async fn renew_timeout( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, path: TestPath, @@ -1238,64 +1281,67 @@ fn renew_timeout( { let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::ACCEPT_COLLATERAL, contract_input) + .await .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); if let TestPath::RenewOfferTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Closed); } else { let (renew_accept, _) = second .lock() - .unwrap() + .await .accept_renew_offer(&channel_id) .expect("to be able to accept a settlement offer"); second_send .send(Some(Message::RenewAccept(renew_accept))) + .await .unwrap(); // Process Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); if let TestPath::RenewAcceptTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, Closed); } else if let TestPath::RenewConfirmTimeout = path { // Process Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Closed); } else if let TestPath::RenewFinalizeTimeout = path { //Process confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); // Process Finalize - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(second.clone()); + periodic_check(second.clone()).await; generate_blocks(289); - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, Closed); } @@ -1303,69 +1349,71 @@ fn renew_timeout( } } -fn settle_timeout( +async fn settle_timeout( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, path: TestPath, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); if let TestPath::SettleOfferTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Signed, Closing); } else { let (settle_accept, _) = second .lock() - .unwrap() + .await .accept_settle_offer(&channel_id) .expect("to be able to accept a settlement offer"); second_send .send(Some(Message::SettleAccept(settle_accept))) + .await .unwrap(); // Process Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); if let TestPath::SettleAcceptTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(second.clone()); + periodic_check(second.clone()).await; second .lock() - .unwrap() + .await .get_store() .get_channel(&channel_id) .unwrap(); assert_channel_state!(second, channel_id, Signed, Closing); } else if let TestPath::SettleConfirmTimeout = path { // Process Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Signed, Closing); } diff --git a/ddk-manager/tests/manager_execution_tests.rs b/ddk-manager/tests/manager_execution_tests.rs index 8ea4509b..f1555c05 100644 --- a/ddk-manager/tests/manager_execution_tests.rs +++ b/ddk-manager/tests/manager_execution_tests.rs @@ -30,11 +30,10 @@ use serde_json::{from_str, to_writer_pretty}; use std::collections::HashMap; use std::sync::{ atomic::{AtomicBool, Ordering}, - mpsc::channel, - Arc, Mutex, + Arc, }; -use std::thread; - +use tokio::sync::mpsc::channel; +use tokio::sync::Mutex; #[derive(serde::Serialize, serde::Deserialize)] struct TestVectorPart { message: T, @@ -90,15 +89,16 @@ fn create_test_vector() { macro_rules! periodic_check { ($d:expr, $id:expr, $p:ident) => { $d.lock() - .unwrap() + .await .periodic_check(true) + .await .expect("Periodic check error"); assert_contract_state!($d, $id, $p); }; } -fn numerical_common( +async fn numerical_common( nb_oracles: usize, threshold: usize, payout_function_pieces_cb: F, @@ -124,10 +124,11 @@ fn numerical_common( ), TestPath::Close, manual_close, - ); + ) + .await; } -fn numerical_polynomial_common( +async fn numerical_polynomial_common( nb_oracles: usize, threshold: usize, difference_params: Option, @@ -139,10 +140,11 @@ fn numerical_polynomial_common( get_polynomial_payout_curve_pieces, difference_params, manual_close, - ); + ) + .await; } -fn numerical_common_diff_nb_digits( +async fn numerical_common_diff_nb_digits( nb_oracles: usize, threshold: usize, difference_params: Option, @@ -171,7 +173,8 @@ fn numerical_common_diff_nb_digits( ), TestPath::Close, manual_close, - ); + ) + .await; } #[derive(Eq, PartialEq, Clone)] @@ -184,330 +187,344 @@ enum TestPath { BadSignRefundSignature, } -#[test] +#[tokio::test] #[ignore] -fn single_oracle_numerical_test() { - numerical_polynomial_common(1, 1, None, false); +async fn single_oracle_numerical_test() { + numerical_polynomial_common(1, 1, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn single_oracle_numerical_manual_test() { - numerical_polynomial_common(1, 1, None, true); +async fn single_oracle_numerical_manual_test() { + numerical_polynomial_common(1, 1, None, true).await; } -#[test] +#[tokio::test] #[ignore] -fn single_oracle_numerical_hyperbola_test() { - numerical_common(1, 1, get_hyperbola_payout_curve_pieces, None, false); +async fn single_oracle_numerical_hyperbola_test() { + numerical_common(1, 1, get_hyperbola_payout_curve_pieces, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_three_oracle_numerical_test() { - numerical_polynomial_common(3, 3, None, false); +async fn three_of_three_oracle_numerical_test() { + numerical_polynomial_common(3, 3, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_test() { - numerical_polynomial_common(5, 2, None, false); +async fn two_of_five_oracle_numerical_test() { + numerical_polynomial_common(5, 2, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_manual_test() { - numerical_polynomial_common(5, 2, None, true); +async fn two_of_five_oracle_numerical_manual_test() { + numerical_polynomial_common(5, 2, None, true).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_three_oracle_numerical_with_diff_test() { - numerical_polynomial_common(3, 3, Some(get_difference_params()), false); +async fn three_of_three_oracle_numerical_with_diff_test() { + numerical_polynomial_common(3, 3, Some(get_difference_params()), false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_test() { - numerical_polynomial_common(5, 2, Some(get_difference_params()), false); +async fn two_of_five_oracle_numerical_with_diff_test() { + numerical_polynomial_common(5, 2, Some(get_difference_params()), false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_five_oracle_numerical_with_diff_test() { - numerical_polynomial_common(5, 3, Some(get_difference_params()), false); +async fn three_of_five_oracle_numerical_with_diff_test() { + numerical_polynomial_common(5, 3, Some(get_difference_params()), false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_five_oracle_numerical_with_diff_manual_test() { - numerical_polynomial_common(5, 3, Some(get_difference_params()), true); +async fn three_of_five_oracle_numerical_with_diff_manual_test() { + numerical_polynomial_common(5, 3, Some(get_difference_params()), true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_test() { - manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, false); +async fn enum_single_oracle_test() { + manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, false).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_manual_test() { - manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, true); +async fn enum_single_oracle_manual_test() { + manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_3_test() { - manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, false); +async fn enum_3_of_3_test() { + manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, false).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_3_manual_test() { - manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, true); +async fn enum_3_of_3_manual_test() { + manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_5_test() { - manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, false); +async fn enum_3_of_5_test() { + manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, false).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_5_manual_test() { - manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, true); +async fn enum_3_of_5_manual_test() { + manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_3_of_5_test() { +async fn enum_and_numerical_with_diff_3_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, true, Some(get_difference_params())), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_3_of_5_manual_test() { +async fn enum_and_numerical_with_diff_3_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, true, Some(get_difference_params())), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_5_of_5_test() { +async fn enum_and_numerical_with_diff_5_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, true, Some(get_difference_params())), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_5_of_5_manual_test() { +async fn enum_and_numerical_with_diff_5_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, true, Some(get_difference_params())), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_3_of_5_test() { +async fn enum_and_numerical_3_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, false, None), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_3_of_5_manual_test() { +async fn enum_and_numerical_3_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, false, None), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_5_of_5_test() { +async fn enum_and_numerical_5_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, false, None), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_5_of_5_manual_test() { +async fn enum_and_numerical_5_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, false, None), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_refund_test() { +async fn enum_single_oracle_refund_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::Refund, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_refund_manual_test() { +async fn enum_single_oracle_refund_manual_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::Refund, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_accept_cet_sig_test() { +async fn enum_single_oracle_bad_accept_cet_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadAcceptCetSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_accept_refund_sig_test() { +async fn enum_single_oracle_bad_accept_refund_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadAcceptRefundSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_sign_cet_sig_test() { +async fn enum_single_oracle_bad_sign_cet_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadSignCetSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_sign_refund_sig_test() { +async fn enum_single_oracle_bad_sign_refund_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadSignRefundSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_diff_nb_digits_test() { - numerical_common_diff_nb_digits(2, 2, None, false, false); +async fn two_of_two_oracle_numerical_diff_nb_digits_test() { + numerical_common_diff_nb_digits(2, 2, None, false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_diff_nb_digits_manual_test() { - numerical_common_diff_nb_digits(2, 2, None, false, true); +async fn two_of_two_oracle_numerical_diff_nb_digits_manual_test() { + numerical_common_diff_nb_digits(2, 2, None, false, true).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_test() { - numerical_common_diff_nb_digits(5, 2, None, false, false); +async fn two_of_five_oracle_numerical_diff_nb_digits_test() { + numerical_common_diff_nb_digits(5, 2, None, false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_manual_test() { - numerical_common_diff_nb_digits(5, 2, None, false, true); +async fn two_of_five_oracle_numerical_diff_nb_digits_manual_test() { + numerical_common_diff_nb_digits(5, 2, None, false, true).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_test() { - numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), false, false); +async fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_test() { + numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_three_oracle_numerical_with_diff_diff_nb_digits_test() { - numerical_common_diff_nb_digits(3, 3, Some(get_difference_params()), false, false); +async fn three_of_three_oracle_numerical_with_diff_diff_nb_digits_test() { + numerical_common_diff_nb_digits(3, 3, Some(get_difference_params()), false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_test() { - numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), false, false); +async fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_test() { + numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), true, false); +async fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_three_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(3, 2, Some(get_difference_params()), true, false); +async fn two_of_three_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(3, 2, Some(get_difference_params()), true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, false); +async fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_manual_test() { - numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, true); +async fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_manual_test() { + numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, true).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(2, 2, None, true, false); +async fn two_of_two_oracle_numerical_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(2, 2, None, true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_three_oracle_numerical_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(3, 2, None, true, false); +async fn two_of_three_oracle_numerical_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(3, 2, None, true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(5, 2, None, true, false); +async fn two_of_five_oracle_numerical_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(5, 2, None, true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_max_value_manual_test() { - numerical_common_diff_nb_digits(5, 2, None, true, true); +async fn two_of_five_oracle_numerical_diff_nb_digits_max_value_manual_test() { + numerical_common_diff_nb_digits(5, 2, None, true, true).await; } fn alter_adaptor_sig(input: &mut CetAdaptorSignatures) { @@ -530,24 +547,21 @@ fn alter_refund_sig(refund_signature: &Signature) -> Signature { Signature::from_compact(©).unwrap() } -fn get_attestations(test_params: &TestParams) -> Vec<(usize, OracleAttestation)> { +async fn get_attestations(test_params: &TestParams) -> Vec<(usize, OracleAttestation)> { + let mut attestations = Vec::new(); for contract_info in test_params.contract_input.contract_infos.iter() { - let attestations: Vec<_> = contract_info - .oracles - .public_keys - .iter() - .enumerate() - .filter_map(|(i, pk)| { - let oracle = test_params - .oracles - .iter() - .find(|x| x.get_public_key() == *pk); - - oracle - .and_then(|o| o.get_attestation(&contract_info.oracles.event_id).ok()) - .map(|a| (i, a)) - }) - .collect(); + attestations.clear(); + for (i, pk) in contract_info.oracles.public_keys.iter().enumerate() { + let oracle = test_params + .oracles + .iter() + .find(|x| x.get_public_key() == *pk); + if let Some(o) = oracle { + if let Ok(attestation) = o.get_attestation(&contract_info.oracles.event_id).await { + attestations.push((i, attestation)); + } + } + } if attestations.len() >= contract_info.oracles.threshold as usize { return attestations; } @@ -556,11 +570,11 @@ fn get_attestations(test_params: &TestParams) -> Vec<(usize, OracleAttestation)> panic!("No attestations found"); } -fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: bool) { +async fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: bool) { env_logger::try_init().ok(); - let (alice_send, bob_receive) = channel::>(); - let (bob_send, alice_receive) = channel::>(); - let (sync_send, sync_receive) = channel::<()>(); + let (alice_send, mut bob_receive) = channel::>(100); + let (bob_send, mut alice_receive) = channel::>(100); + let (sync_send, mut sync_receive) = channel::<()>(100); let alice_sync_send = sync_send.clone(); let bob_sync_send = sync_send; let (_, _, sink_rpc) = init_clients(); @@ -734,28 +748,32 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: let offer_msg = bob_manager_send .lock() - .unwrap() + .await .send_offer( &test_params.contract_input, "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" .parse() .unwrap(), ) + .await .expect("Send offer error"); write_message("offer_message", offer_msg.clone()); let temporary_contract_id = offer_msg.temporary_contract_id; - bob_send.send(Some(Message::Offer(offer_msg))).unwrap(); + bob_send + .send(Some(Message::Offer(offer_msg))) + .await + .unwrap(); assert_contract_state!(bob_manager_send, temporary_contract_id, Offered); - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(alice_manager_send, temporary_contract_id, Offered); let (contract_id, _, mut accept_msg) = alice_manager_send .lock() - .unwrap() + .await .accept_contract_offer(&temporary_contract_id) .expect("Error accepting contract offer"); @@ -775,29 +793,38 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: _ => {} }; bob_expect_error.store(true, Ordering::Relaxed); - alice_send.send(Some(Message::Accept(accept_msg))).unwrap(); - sync_receive.recv().expect("Error synchronizing"); + alice_send + .send(Some(Message::Accept(accept_msg))) + .await + .unwrap(); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(bob_manager_send, temporary_contract_id, FailedAccept); } TestPath::BadSignCetSignature | TestPath::BadSignRefundSignature => { alice_expect_error.store(true, Ordering::Relaxed); - alice_send.send(Some(Message::Accept(accept_msg))).unwrap(); + alice_send + .send(Some(Message::Accept(accept_msg))) + .await + .unwrap(); // Bob receives accept message - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); // Alice receives sign message - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(alice_manager_send, contract_id, FailedSign); } TestPath::Close | TestPath::Refund => { - alice_send.send(Some(Message::Accept(accept_msg))).unwrap(); - sync_receive.recv().expect("Error synchronizing"); + alice_send + .send(Some(Message::Accept(accept_msg))) + .await + .unwrap(); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(bob_manager_send, contract_id, Signed); // Should not change state and should not error periodic_check!(bob_manager_send, contract_id, Signed); - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(alice_manager_send, contract_id, Signed); @@ -831,15 +858,15 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: if manual_close { periodic_check!(first, contract_id, Confirmed); - let attestations = get_attestations(&test_params); + let attestations = get_attestations(&test_params).await; - let f = first.lock().unwrap(); + let f = first.lock().await; let contract = f .close_confirmed_contract(&contract_id, attestations) .expect("Error closing contract"); if let Contract::PreClosed(contract) = contract { - let mut s = second.lock().unwrap(); + let mut s = second.lock().await; let second_contract = s.get_store().get_contract(&contract_id).unwrap().unwrap(); if let Contract::Confirmed(signed) = second_contract { @@ -899,11 +926,11 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: } } - alice_send.send(None).unwrap(); - bob_send.send(None).unwrap(); + alice_send.send(None).await.unwrap(); + bob_send.send(None).await.unwrap(); - alice_handle.join().unwrap(); - bob_handle.join().unwrap(); + alice_handle.await.unwrap(); + bob_handle.await.unwrap(); create_test_vector(); } diff --git a/ddk-manager/tests/test_utils.rs b/ddk-manager/tests/test_utils.rs index 9f87b6b3..9c44a12a 100644 --- a/ddk-manager/tests/test_utils.rs +++ b/ddk-manager/tests/test_utils.rs @@ -46,39 +46,41 @@ pub const ROUNDING_MOD: u64 = 1; #[macro_export] macro_rules! receive_loop { ($receive:expr, $manager:expr, $send:expr, $expect_err:expr, $sync_send:expr, $rcv_callback: expr, $msg_callback: expr) => { - thread::spawn(move || loop { - match $receive.recv() { - Ok(Some(msg)) => match $manager.lock().unwrap().on_dlc_message( - &msg, - "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" - .parse() - .unwrap(), - ) { - Ok(opt) => { - if $expect_err.load(Ordering::Relaxed) != false { - panic!("Expected error not raised"); - } - match opt { - Some(msg) => { - let msg_opt = $rcv_callback(msg); - if let Some(msg) = msg_opt { - #[allow(clippy::redundant_closure_call)] - $msg_callback(&msg); - (&$send).send(Some(msg)).expect("Error sending"); + tokio::spawn(async move { + loop { + match $receive.recv().await { + Some(Some(msg)) => match $manager.lock().await.on_dlc_message( + &msg, + "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" + .parse() + .unwrap(), + ) { + Ok(opt) => { + if $expect_err.load(Ordering::Relaxed) != false { + panic!("Expected error not raised"); + } + match opt { + Some(msg) => { + let msg_opt = $rcv_callback(msg); + if let Some(msg) = msg_opt { + #[allow(clippy::redundant_closure_call)] + $msg_callback(&msg); + (&$send).send(Some(msg)).await.expect("Error sending"); + } } + None => {} } - None => {} } - } - Err(e) => { - if $expect_err.load(Ordering::Relaxed) != true { - panic!("Unexpected error {}", e); + Err(e) => { + if $expect_err.load(Ordering::Relaxed) != true { + panic!("Unexpected error {}", e); + } } - } - }, - Ok(None) | Err(_) => return, - }; - $sync_send.send(()).expect("Error syncing"); + }, + None | Some(None) => return, + }; + $sync_send.send(()).await.expect("Error syncing"); + } }) }; } @@ -104,7 +106,7 @@ macro_rules! assert_contract_state { ($d:expr, $id:expr, $p:ident) => { let res = $d .lock() - .unwrap() + .await .get_store() .get_contract(&$id) .expect("Could not retrieve contract"); @@ -146,7 +148,7 @@ macro_rules! write_channel { #[macro_export] macro_rules! assert_channel_state { ($d:expr, $id:expr, $p:ident $(, $s: ident)?) => {{ - assert_channel_state_unlocked!($d.lock().unwrap(), $id, $p $(, $s)?) + assert_channel_state_unlocked!($d.lock().await, $id, $p $(, $s)?) }}; } diff --git a/mocks/Cargo.toml b/mocks/Cargo.toml index 1a8e9f7d..28f40eeb 100644 --- a/mocks/Cargo.toml +++ b/mocks/Cargo.toml @@ -1,10 +1,11 @@ [package] authors = ["Crypto Garage"] -edition = "2018" +edition = "2021" name = "mocks" version = "0.1.0" [dependencies] +async-trait = "0.1.83" bitcoin = "0.32.2" ddk-dlc = {path = "../ddk-dlc"} ddk-manager = {path = "../ddk-manager"} diff --git a/mocks/src/mock_oracle_provider.rs b/mocks/src/mock_oracle_provider.rs index 678430c3..80dacafd 100644 --- a/mocks/src/mock_oracle_provider.rs +++ b/mocks/src/mock_oracle_provider.rs @@ -55,12 +55,13 @@ impl Default for MockOracle { } } +#[async_trait::async_trait] impl Oracle for MockOracle { fn get_public_key(&self) -> XOnlyPublicKey { XOnlyPublicKey::from_keypair(&self.key_pair).0 } - fn get_announcement(&self, event_id: &str) -> Result { + async fn get_announcement(&self, event_id: &str) -> Result { let res = self .announcements .get(event_id) @@ -68,7 +69,7 @@ impl Oracle for MockOracle { Ok(res.clone()) } - fn get_attestation(&self, event_id: &str) -> Result { + async fn get_attestation(&self, event_id: &str) -> Result { let res = self .attestations .get(event_id) diff --git a/p2pd-oracle-client/Cargo.toml b/p2pd-oracle-client/Cargo.toml index bfa531e7..db811d52 100644 --- a/p2pd-oracle-client/Cargo.toml +++ b/p2pd-oracle-client/Cargo.toml @@ -6,14 +6,17 @@ license-file = "../LICENSE" name = "p2pd-oracle-client" repository = "https://github.com/p2pderivatives/rust-dlc/tree/master/p2pd-oracle-client" version = "0.1.0" +edition = "2021" [dependencies] +async-trait = "0.1.83" chrono = {version = "0.4.19", features = ["serde"]} ddk-manager = {path = "../ddk-manager"} ddk-messages = {path = "../ddk-messages", features = ["use-serde"]} -reqwest = {version = "0.11", features = ["blocking", "json"]} +reqwest = {version = "0.11", features = ["json"]} secp256k1-zkp = {version = "0.11.0" } serde = {version = "*", features = ["derive"]} [dev-dependencies] mockito = "0.31.0" +tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread", "test-util"] } diff --git a/p2pd-oracle-client/src/lib.rs b/p2pd-oracle-client/src/lib.rs index 7b82fb9c..b2e6f3b3 100644 --- a/p2pd-oracle-client/src/lib.rs +++ b/p2pd-oracle-client/src/lib.rs @@ -69,17 +69,19 @@ struct AttestationResponse { values: Vec, } -fn get(path: &str) -> Result +async fn get(path: &str) -> Result where T: serde::de::DeserializeOwned, { - reqwest::blocking::get(path) + reqwest::get(path) + .await .map_err(|x| { ddk_manager::error::Error::IOError( std::io::Error::new(std::io::ErrorKind::Other, x).into(), ) })? .json::() + .await .map_err(|e| ddk_manager::error::Error::OracleError(e.to_string())) } @@ -109,7 +111,7 @@ impl P2PDOracleClient { /// Try to create an instance of an oracle client connecting to the provided /// host. Returns an error if the host could not be reached. Panics if the /// oracle uses an incompatible format. - pub fn new(host: &str) -> Result { + pub async fn new(host: &str) -> Result { if host.is_empty() { return Err(DlcManagerError::InvalidParameters( "Invalid host".to_string(), @@ -121,7 +123,7 @@ impl P2PDOracleClient { host.to_string() }; let path = pubkey_path(&host); - let public_key = get::(&path)?.public_key; + let public_key = get::(&path).await?.public_key; Ok(P2PDOracleClient { host, public_key }) } } @@ -144,19 +146,23 @@ fn parse_event_id(event_id: &str) -> Result<(String, DateTime), DlcManagerE Ok((asset_id.to_string(), date_time)) } +#[async_trait::async_trait] impl Oracle for P2PDOracleClient { fn get_public_key(&self) -> XOnlyPublicKey { self.public_key } - fn get_announcement(&self, event_id: &str) -> Result { + async fn get_announcement( + &self, + event_id: &str, + ) -> Result { let (asset_id, date_time) = parse_event_id(event_id)?; let path = announcement_path(&self.host, &asset_id, &date_time); - let announcement = get(&path)?; + let announcement = get(&path).await?; Ok(announcement) } - fn get_attestation( + async fn get_attestation( &self, event_id: &str, ) -> Result { @@ -166,7 +172,7 @@ impl Oracle for P2PDOracleClient { event_id: _, signatures, values, - } = get::(&path)?; + } = get::(&path).await?; Ok(OracleAttestation { event_id: event_id.to_string(), @@ -202,8 +208,8 @@ mod tests { ).create() } - #[test] - fn get_public_key_test() { + #[tokio::test] + async fn get_public_key_test() { let url = &mockito::server_url(); let _m = pubkey_mock(); let expected_pk: XOnlyPublicKey = @@ -211,13 +217,15 @@ mod tests { .parse() .unwrap(); - let client = P2PDOracleClient::new(url).expect("Error creating client instance."); + let client = P2PDOracleClient::new(url) + .await + .expect("Error creating client instance."); assert_eq!(expected_pk, client.get_public_key()); } - #[test] - fn get_announcement_test() { + #[tokio::test] + async fn get_announcement_test() { let url = &mockito::server_url(); let _pubkey_mock = pubkey_mock(); let path: &str = &announcement_path( @@ -229,15 +237,18 @@ mod tests { ); let _m = mock("GET", path).with_body(r#"{"announcementSignature":"f83db0ca25e4c209b55156737b0c65470a9702fe9d1d19a129994786384289397895e403ff37710095a04a0841a95738e3e8bc35bdef6bce50bf34eeb182bd9b","oraclePublicKey":"10dc8cf51ae3ee1c7967ffb9c9633a5ab06206535d8e1319f005a01ba33bc05d","oracleEvent":{"oracleNonces":["aca32fc8dead13983c655638ef921f1d38ef2f5286e58b2a1dab32b6e086e208","89603f8179830590fdce45eb17ba8bdf74e295a4633b58b46c9ede8274774164","5f3fcdfbba9ec75cb0868e04ec1f97089b4153fb2076bd1e017048e9df633aa1","8436d00f7331491dc6512e560a1f2414be42e893992eccb495642eefc7c5bf37","0d2593764c9c27eba0be3ca6c71a2de4e49a5f4aa1ce1e2cc379be3939547501","414318491e96919e67583db7a47eb1f8b4f1194bcb5b5dcc4fd10492d89926e4","b9a5ded7295e0343f385e5abedfd9e5f4137de8f67de0afa9396f7e0f996ef79","badf0bfe230ed605161630d8e3a092d7448461042db38912bc6c6a0ab195ff71","6e4780213cd7ed9de1300146079b897cae89dec7800065f615974193f58aa6db","7b12b48ad95634ee4ca476dd57e634fddc328e10276e71d27e0ae626fad7d699","a8058604adf590a1c38f8be19aa44175eb2d1130eb4d7f39a34f89f0a3fbed27","ffc3208f60b585cdc778be1290b352c34c22652d5348a87885816bcf17a80116","cb34c13f80b49e729e863035f30e1f8ea7777618eedb6d666c3b1c85a5b8a637","5000991f4631c0bba5d026f02125fdbe77e019dde57d31ce7f23ae3601a18623","094433a2432b81bbb6d6b7d65dc3498e2a7c9de5f35672d67097d54d920eadd2","11dff6b40b0938e1943c7888633d88871c2a2a1c16f412b22b80ba7ed8af8788","d5957f1a199b4abbc06894479c722ad0c4f120f0d5afeb76d589127213e33170","80e09bb453e6a0a444ec3ba222a62ecd59540b9dd8280566a17bebdfdfbd7a9e","0fe775b79b2172cb961e7c1aa54d521360903680680aaa55ea8be0404ee3768c","bfcdbb2cbcffba41048149d4bcf2a41cd5fd0a713df6f48104ade3022c284575"],"eventMaturityEpoch":1653865200,"eventDescriptor":{"digitDecompositionEvent":{"base":2,"isSigned":false,"unit":"usd/btc","precision":0,"nbDigits":20}},"eventId":"btcusd1653865200"}}"#).create(); - let client = P2PDOracleClient::new(url).expect("Error creating client instance"); + let client = P2PDOracleClient::new(url) + .await + .expect("Error creating client instance"); client .get_announcement("btcusd1624943400") + .await .expect("Error getting announcement"); } - #[test] - fn get_attestation_test() { + #[tokio::test] + async fn get_attestation_test() { let url = &mockito::server_url(); let _pubkey_mock = pubkey_mock(); let path: &str = &attestation_path( @@ -250,10 +261,13 @@ mod tests { let _m = mock("GET", path).with_body(r#"{"eventId":"btcusd1653517020","signatures":["ee05b1211d5f974732b10107dd302da062be47cd18f061c5080a50743412f9fd590cad90cfea762472e6fe865c4223bd388c877b7881a27892e15843ff1ac360","59ab83597089b48f5b3c2fd07c11edffa6b1180bdb6d9e7d6924979292d9c53fe79396ceb0782d5941c284d1642377136c06b2d9c2b85bda5969a773a971b5b0","d1f8c31a83bb34433da5b9808bb3692dd212b9022b7bc8f269fc817e96a7195db18262e934bebd4e68a3f2c96550826a5530350662df4c86c004f5cf1121ca67","e5cec554c39c4dd544d70175128271eecad77c1e3eaa6994c657e257d5c1c9dcd19b041ea8030e75448245b7f91705ad914c32761671a6172f928904b439ea6b","a209116d20f0931113c0880e8cd22d3f003609a32322ff8df241ef16e7d4efd1a9b723f582a22073e21188635f09f41f270f3126014542861be14b62b09c0ecc","f1da0b482f08f545a92338392b71cec33d948a5e5732ee4d5c0a87bd6b6cc12feeb1498da7afd93ae48ec4ce581ee79c0e92f338d3777c2ef06578e4ec1a853c","d9ab68244a3b47cc8cbd5a972f2f5059fc6b9711dba8d4a7a23607a99b9655593bab3abc1d3b02402cd0809c3c7016c741742efb363227de2bcfdcf290a053b3","c1146c1767a947f77794d05a2f58e50af824e3c8d70adde883e58d2dc1ddb157323b0aaf8cfb5b076a12395756bdcda64ab5d4799e43c88a41993659e6d49471","0d29d9383c9ee41055e1cb40104c9ca75280162779c0162cb6bf9aca2b223aba17de4b3f0f29ae6b749f22ba467b7e9f05456e8abb3ec328f62b7a924c6d4828","2bcc54002ceb271a940f24bc6dd0562b99c2d76cfb8f145f42ac37bc34fd3e94adba1194c5be91932b818c5715c73f287e066e228d796a373c4aec67fd777070","a91f77e3435c577682ff744d6f7da66c865a42e8645276dedf2ed2b8bc4c80285dff4b553b2231592e0fa8b4f242acb6888519fe82c457cc5204e5d9d511303a","546409d6bcdcfd5bef39957c8b1b09f7805b08ec2311bc73cf6927ae11f3567ffe8428aa7faa661518e9c02a702212ab05e494aab84624c3dd1a710f8c4c369b","9d601ee8a3d28dcdfdd05581f1b24d6e5a576f0b5544eb7c9921cb87a23fdb293c1edca89b43b5b84c1e305fbe52facbe6b03575aed8f95b4faccc90e0eb45ef","636b8028e9cd6cba6be5b3c1789b62aecfc17e9c28d7a621cfad2c3cf751046528028e1dbd6cee050d5d570cf5a3d8986471d73e7edca4093e36fc8e1097fb65","57c6337b52dc7fd8f49b29105f168fc9b4cb88ed2ba5f0e9a80a21e20836f87f875c3fe92afb437dd5647630b54eda6ba1be76ba6df8b641eb2e8be8ff1182dc","9e8843e32f9de4cd6d5bb9e938fd014babe11bb1faf35fc411d754259bc374f34dd841ed91f6bb3f030bc55a4791cdc41471c33b3f05fd35b9d1768fd381f953","97da4963747ab5e50534b93274065cba4fd24e6b7a9d3310db2596af24f70961fb03535e2a5ae272f7ea14e86daafa57073631596fecf7ceadf4ae3e6941b69e","94a414569743f87f1462a503be8cff1f229096d190b8b1349519c612b74eea872d5d763570aaaa54fad0605a43d742203bce489deea5570750030191e293c253","4d7117b89aad73eca7b341749bd54ffdd459b9b8b4ff128344d09273f66a3d2c01d2c86b61f7642d6e81f488580b456685cd68660458cff83b8858a05c9a1f4d","b12153a393a4fddac3079c1878cb89afccfe0ac8f539743c0608049f445e49ac7c89e33fcf832cda8d7e8a4f4dae94a303170f16c697feed8b78015873bd5ffc"],"values":["0","0","0","0","0","1","1","1","0","1","0","0","0","0","1","1","1","0","1","0"]}"#).create(); - let client = P2PDOracleClient::new(url).expect("Error creating client instance"); + let client = P2PDOracleClient::new(url) + .await + .expect("Error creating client instance"); client .get_attestation("btcusd1624943400") + .await .expect("Error getting attestation"); } } diff --git a/sample/src/cli.rs b/sample/src/cli.rs index b51f93b9..0d91e515 100644 --- a/sample/src/cli.rs +++ b/sample/src/cli.rs @@ -19,8 +19,9 @@ use std::io; use std::io::{BufRead, Write}; use std::net::{SocketAddr, ToSocketAddrs}; use std::str::SplitWhitespace; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Duration; +use tokio::sync::Mutex; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] @@ -97,7 +98,7 @@ pub(crate) async fn poll_for_user_input( print!("> "); io::stdout().flush().unwrap(); // Without flushing, the `>` doesn't print for line in stdin.lock().lines() { - process_incoming_messages(&peer_manager, &dlc_manager, &dlc_message_handler); + process_incoming_messages(&peer_manager, &dlc_manager, &dlc_message_handler).await; let line = line.unwrap(); let mut words = line.split_whitespace(); if let Some(word) = words.next() { @@ -164,32 +165,30 @@ pub(crate) async fn poll_for_user_input( .expect("Error deserializing contract input."); let manager_clone = dlc_manager.clone(); let is_contract = o == "offercontract"; - let offer = tokio::task::spawn_blocking(move || { - if is_contract { - DlcMessage::Offer( - manager_clone - .lock() - .unwrap() - .send_offer(&contract_input, pubkey) - .expect("Error sending offer"), - ) - } else { - DlcMessage::OfferChannel( - manager_clone - .lock() - .unwrap() - .offer_channel(&contract_input, pubkey) - .expect("Error sending offer channel"), - ) - } - }) - .await - .unwrap(); + let offer = if is_contract { + DlcMessage::Offer( + manager_clone + .lock() + .await + .send_offer(&contract_input, pubkey) + .await + .expect("Error sending offer"), + ) + } else { + DlcMessage::OfferChannel( + manager_clone + .lock() + .await + .offer_channel(&contract_input, pubkey) + .await + .expect("Error sending offer channel"), + ) + }; dlc_message_handler.send_message(pubkey, offer); peer_manager.process_events(); } "listoffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for offer in locked_manager .get_store() .get_contract_offers() @@ -213,7 +212,7 @@ pub(crate) async fn poll_for_user_input( let (_, node_id, msg) = dlc_manager .lock() - .unwrap() + .await .accept_contract_offer(&contract_id) .expect("Error accepting contract."); dlc_message_handler.send_message(node_id, DlcMessage::Accept(msg)); @@ -222,62 +221,60 @@ pub(crate) async fn poll_for_user_input( "listcontracts" => { let manager_clone = dlc_manager.clone(); // Because the oracle client is currently blocking we need to use `spawn_blocking` here. - tokio::task::spawn_blocking(move || { - manager_clone - .lock() - .unwrap() - .periodic_check(true) - .expect("Error doing periodic check."); - let contracts = manager_clone - .lock() - .unwrap() - .get_store() - .get_contracts() - .expect("Error retrieving contract list."); - for contract in contracts { - let id = hex_str(&contract.get_id()); - match contract { - Contract::Offered(_) => { - println!("Offered contract: {}", id); - } - Contract::Accepted(_) => { - println!("Accepted contract: {}", id); - } - Contract::Confirmed(_) => { - println!("Confirmed contract: {}", id); - } - Contract::Signed(_) => { - println!("Signed contract: {}", id); - } - Contract::Closed(closed) => { - println!("Closed contract: {}", id); - if let Some(attestations) = closed.attestations { - println!( - "Outcomes: {:?}", - attestations - .iter() - .map(|x| x.outcomes.clone()) - .collect::>() - ); - } - println!("PnL: {} sats", closed.pnl) - } - Contract::Refunded(_) => { - println!("Refunded contract: {}", id); - } - Contract::FailedAccept(_) | Contract::FailedSign(_) => { - println!("Failed contract: {}", id); + + manager_clone + .lock() + .await + .periodic_check(true) + .await + .expect("Error doing periodic check."); + let contracts = manager_clone + .lock() + .await + .get_store() + .get_contracts() + .expect("Error retrieving contract list."); + for contract in contracts { + let id = hex_str(&contract.get_id()); + match contract { + Contract::Offered(_) => { + println!("Offered contract: {}", id); + } + Contract::Accepted(_) => { + println!("Accepted contract: {}", id); + } + Contract::Confirmed(_) => { + println!("Confirmed contract: {}", id); + } + Contract::Signed(_) => { + println!("Signed contract: {}", id); + } + Contract::Closed(closed) => { + println!("Closed contract: {}", id); + if let Some(attestations) = closed.attestations { + println!( + "Outcomes: {:?}", + attestations + .iter() + .map(|x| x.outcomes.clone()) + .collect::>() + ); } - Contract::Rejected(_) => println!("Rejected contract: {}", id), - Contract::PreClosed(_) => println!("Pre-closed contract: {}", id), + println!("PnL: {} sats", closed.pnl) + } + Contract::Refunded(_) => { + println!("Refunded contract: {}", id); + } + Contract::FailedAccept(_) | Contract::FailedSign(_) => { + println!("Failed contract: {}", id); } + Contract::Rejected(_) => println!("Rejected contract: {}", id), + Contract::PreClosed(_) => println!("Pre-closed contract: {}", id), } - }) - .await - .expect("Error listing contract info"); + } } "listchanneloffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for offer in locked_manager .get_store() .get_offered_channels() @@ -305,7 +302,7 @@ pub(crate) async fn poll_for_user_input( let (msg, _, _, node_id) = dlc_manager .lock() - .unwrap() + .await .accept_channel(&channel_id) .expect("Error accepting channel."); dlc_message_handler.send_message(node_id, DlcMessage::AcceptChannel(msg)); @@ -323,7 +320,7 @@ pub(crate) async fn poll_for_user_input( let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .settle_offer(&channel_id, counter_payout) .expect("Error getting settle offer message."); dlc_message_handler.send_message(node_id, DlcMessage::SettleOffer(msg)); @@ -333,7 +330,7 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .accept_settle_offer(&channel_id) .expect("Error accepting settle channel offer."); dlc_message_handler.send_message(node_id, DlcMessage::SettleAccept(msg)); @@ -343,14 +340,14 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .reject_settle_offer(&channel_id) .expect("Error rejecting settle channel offer."); dlc_message_handler.send_message(node_id, DlcMessage::Reject(msg)); peer_manager.process_events(); } "listsettlechanneloffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for channel in locked_manager .get_store() .get_signed_channels(Some(SignedChannelStateType::SettledReceived)) @@ -380,20 +377,17 @@ pub(crate) async fn poll_for_user_input( let contract_input: ContractInput = serde_json::from_str(&contract_input_str) .expect("Error deserializing contract input."); let manager_clone = dlc_manager.clone(); - let (renew_offer, node_id) = tokio::task::spawn_blocking(move || { - manager_clone - .lock() - .unwrap() - .renew_offer(&channel_id, counter_payout, &contract_input) - .expect("Error sending offer") - }) - .await - .unwrap(); + let (renew_offer, node_id) = manager_clone + .lock() + .await + .renew_offer(&channel_id, counter_payout, &contract_input) + .await + .expect("Error sending offer"); dlc_message_handler.send_message(node_id, DlcMessage::RenewOffer(renew_offer)); peer_manager.process_events(); } "listrenewchanneloffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for channel in locked_manager .get_store() .get_signed_channels(Some(SignedChannelStateType::RenewOffered)) @@ -426,7 +420,7 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .accept_renew_offer(&channel_id) .expect("Error accepting channel."); dlc_message_handler.send_message(node_id, DlcMessage::RenewAccept(msg)); @@ -436,14 +430,14 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .reject_renew_offer(&channel_id) .expect("Error rejecting settle channel offer."); dlc_message_handler.send_message(node_id, DlcMessage::Reject(msg)); peer_manager.process_events(); } "listsignedchannels" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for channel in locked_manager .get_store() .get_signed_channels(None) @@ -589,7 +583,7 @@ pub(crate) fn parse_peer_info( Ok((pubkey.unwrap(), peer_addr.unwrap().unwrap())) } -fn process_incoming_messages( +async fn process_incoming_messages( peer_manager: &Arc, dlc_manager: &Arc>, dlc_message_handler: &Arc, @@ -601,7 +595,7 @@ fn process_incoming_messages( println!("Processing message from {}", node_id); let resp = dlc_manager .lock() - .unwrap() + .await .on_dlc_message(&message, node_id) .expect("Error processing message"); if let Some(msg) = resp { diff --git a/sample/src/main.rs b/sample/src/main.rs index e725e26b..8418219c 100644 --- a/sample/src/main.rs +++ b/sample/src/main.rs @@ -18,8 +18,9 @@ use p2pd_oracle_client::P2PDOracleClient; use std::collections::hash_map::HashMap; use std::env; use std::fs; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::SystemTime; +use tokio::sync::Mutex; pub(crate) type PeerManager = LdkPeerManager< SocketDescriptor, @@ -72,11 +73,9 @@ async fn main() { // client uses reqwest in blocking mode to satisfy the non async oracle interface // so we need to use `spawn_blocking`. let oracle_host = config.oracle_config.host; - let oracle = tokio::task::spawn_blocking(move || { - P2PDOracleClient::new(&oracle_host).expect("Error creating oracle client") - }) - .await - .unwrap(); + let oracle = P2PDOracleClient::new(&oracle_host) + .await + .expect("Error creating oracle client"); let mut oracles = HashMap::new(); oracles.insert(oracle.get_public_key(), Box::new(oracle));