diff --git a/bitcoin-rpc-provider/Cargo.toml b/bitcoin-rpc-provider/Cargo.toml index 34cec4e8..4965c683 100644 --- a/bitcoin-rpc-provider/Cargo.toml +++ b/bitcoin-rpc-provider/Cargo.toml @@ -5,6 +5,7 @@ name = "bitcoin-rpc-provider" version = "0.1.0" [dependencies] +async-trait = "0.1.83" bitcoin = {version = "0.32.2"} bitcoincore-rpc = {version = "0.19.0"} bitcoincore-rpc-json = {version = "0.19.0"} diff --git a/bitcoin-rpc-provider/src/lib.rs b/bitcoin-rpc-provider/src/lib.rs index 5791846b..782aca54 100644 --- a/bitcoin-rpc-provider/src/lib.rs +++ b/bitcoin-rpc-provider/src/lib.rs @@ -396,8 +396,9 @@ impl Wallet for BitcoinCoreProvider { } } +#[async_trait::async_trait] impl Blockchain for BitcoinCoreProvider { - fn send_transaction(&self, transaction: &Transaction) -> Result<(), ManagerError> { + async fn send_transaction(&self, transaction: &Transaction) -> Result<(), ManagerError> { self.client .lock() .unwrap() @@ -418,7 +419,7 @@ impl Blockchain for BitcoinCoreProvider { Ok(network) } - fn get_blockchain_height(&self) -> Result { + async fn get_blockchain_height(&self) -> Result { self.client .lock() .unwrap() @@ -426,7 +427,7 @@ impl Blockchain for BitcoinCoreProvider { .map_err(rpc_err_to_manager_err) } - fn get_block_at_height(&self, height: u64) -> Result { + async fn get_block_at_height(&self, height: u64) -> Result { let client = self.client.lock().unwrap(); let hash = client .get_block_hash(height) @@ -434,7 +435,7 @@ impl Blockchain for BitcoinCoreProvider { client.get_block(&hash).map_err(rpc_err_to_manager_err) } - fn get_transaction(&self, tx_id: &Txid) -> Result { + async fn get_transaction(&self, tx_id: &Txid) -> Result { let tx_info = self .client .lock() @@ -446,7 +447,7 @@ impl Blockchain for BitcoinCoreProvider { Ok(tx) } - fn get_transaction_confirmations(&self, tx_id: &Txid) -> Result { + async fn get_transaction_confirmations(&self, tx_id: &Txid) -> Result { let tx_info_res = self.client.lock().unwrap().get_transaction(tx_id, None); match tx_info_res { Ok(tx_info) => Ok(tx_info.info.confirmations as u32), diff --git a/ddk-manager/Cargo.toml b/ddk-manager/Cargo.toml index 3c1cf80a..0e1eff20 100644 --- a/ddk-manager/Cargo.toml +++ b/ddk-manager/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Crypto Garage", "benny b "] description = "Creation and handling of Discrete Log Contracts (DLC)." -edition = "2018" +edition = "2021" homepage = "https://github.com/bennyhodl/rust-dlc" license-file = "../LICENSE" name = "ddk-manager" @@ -21,6 +21,7 @@ bitcoin = { version = "0.32.2", default-features = false } ddk-dlc = { version = "0.7.0", default-features = false, path = "../ddk-dlc" } ddk-messages = { version = "0.7.0", default-features = false, path = "../ddk-messages" } ddk-trie = { version = "0.7.0", default-features = false, path = "../ddk-trie" } +futures = "0.3.31" hex = { package = "hex-conservative", version = "0.1" } lightning = { version = "0.0.125", default-features = false, features = ["grind_signatures"] } log = "0.4.14" @@ -43,6 +44,7 @@ secp256k1-zkp = {version = "0.11.0", features = ["hashes", "rand", "rand-std", " serde = "1.0" serde_json = "1.0" simple-wallet = {path = "../simple-wallet"} +tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread", "test-util", "fs"] } [[bench]] harness = false diff --git a/ddk-manager/src/channel_updater.rs b/ddk-manager/src/channel_updater.rs index 3faec908..927adc27 100644 --- a/ddk-manager/src/channel_updater.rs +++ b/ddk-manager/src/channel_updater.rs @@ -70,7 +70,7 @@ pub(crate) use get_signed_channel_state; /// Creates an [`OfferedChannel`] and an associated [`OfferedContract`] using /// the given parameter. -pub fn offer_channel( +pub async fn offer_channel( secp: &Secp256k1, contract: &ContractInput, counter_party: &PublicKey, @@ -99,7 +99,8 @@ where wallet, &signer, blockchain, - )?; + ) + .await?; let party_points = crate::utils::get_party_base_points(secp, signer_provider)?; let offered_contract = OfferedContract::new( @@ -142,7 +143,7 @@ where /// Move the given [`OfferedChannel`] and [`OfferedContract`] to an [`AcceptedChannel`] /// and [`AcceptedContract`], returning them as well as the [`AcceptChannel`] /// message to be sent to the counter party. -pub fn accept_channel_offer( +pub async fn accept_channel_offer( secp: &Secp256k1, offered_channel: &OfferedChannel, offered_contract: &OfferedContract, @@ -167,7 +168,8 @@ where wallet, &signer, blockchain, - )?; + ) + .await?; let per_update_seed = signer_provider.get_new_secret_key()?; diff --git a/ddk-manager/src/contract_updater.rs b/ddk-manager/src/contract_updater.rs index 64c6860c..7fcc06fb 100644 --- a/ddk-manager/src/contract_updater.rs +++ b/ddk-manager/src/contract_updater.rs @@ -27,7 +27,14 @@ use crate::{ /// Creates an [`OfferedContract`] and [`OfferDlc`] message from the provided /// contract and oracle information. -pub fn offer_contract( +pub async fn offer_contract< + W: Deref, + B: Deref, + T: Deref, + X: ContractSigner, + SP: Deref, + C: Signing, +>( secp: &Secp256k1, contract_input: &ContractInput, oracle_announcements: Vec>, @@ -56,7 +63,8 @@ where wallet, &signer, blockchain, - )?; + ) + .await?; let offered_contract = OfferedContract::new( id, @@ -77,7 +85,7 @@ where /// Creates an [`AcceptedContract`] and produces /// the accepting party's cet adaptor signatures. -pub fn accept_contract( +pub async fn accept_contract( secp: &Secp256k1, offered_contract: &OfferedContract, wallet: &W, @@ -99,7 +107,8 @@ where wallet, &signer, blockchain, - )?; + ) + .await?; let dlc_transactions = ddk_dlc::create_dlc_transactions( &offered_contract.offer_params, @@ -765,8 +774,8 @@ mod tests { use mocks::ddk_manager::contract::offered_contract::OfferedContract; use secp256k1_zkp::PublicKey; - #[test] - fn accept_contract_test() { + #[tokio::test] + async fn accept_contract_test() { let offer_dlc = serde_json::from_str(include_str!("../test_inputs/offer_contract.json")).unwrap(); let dummy_pubkey: PublicKey = @@ -780,10 +789,8 @@ mod tests { let utxo_value: u64 = offered_contract.total_collateral - offered_contract.offer_params.collateral + crate::utils::get_half_common_fee(fee_rate).unwrap(); - let wallet = Rc::new(mocks::mock_wallet::MockWallet::new( - &blockchain, - &[utxo_value, 10000], - )); + let wallet = + Rc::new(mocks::mock_wallet::MockWallet::new(&blockchain, &[utxo_value, 10000]).await); mocks::ddk_manager::contract_updater::accept_contract( secp256k1_zkp::SECP256K1, @@ -792,6 +799,7 @@ mod tests { &wallet, &blockchain, ) + .await .expect("Not to fail"); } } diff --git a/ddk-manager/src/lib.rs b/ddk-manager/src/lib.rs index 62670a1b..cd372523 100644 --- a/ddk-manager/src/lib.rs +++ b/ddk-manager/src/lib.rs @@ -167,20 +167,21 @@ pub trait Wallet { fn unreserve_utxos(&self, outpoints: &[OutPoint]) -> Result<(), Error>; } +#[async_trait::async_trait] /// Blockchain trait provides access to the bitcoin blockchain. pub trait Blockchain { /// Broadcast the given transaction to the bitcoin network. - fn send_transaction(&self, transaction: &Transaction) -> Result<(), Error>; + async fn send_transaction(&self, transaction: &Transaction) -> Result<(), Error>; /// Returns the network currently used (mainnet, testnet or regtest). fn get_network(&self) -> Result; /// Returns the height of the blockchain - fn get_blockchain_height(&self) -> Result; + async fn get_blockchain_height(&self) -> Result; /// Returns the block at given height - fn get_block_at_height(&self, height: u64) -> Result; + async fn get_block_at_height(&self, height: u64) -> Result; /// Get the transaction with given id. - fn get_transaction(&self, tx_id: &Txid) -> Result; + async fn get_transaction(&self, tx_id: &Txid) -> Result; /// Get the number of confirmation for the transaction with given id. - fn get_transaction_confirmations(&self, tx_id: &Txid) -> Result; + async fn get_transaction_confirmations(&self, tx_id: &Txid) -> Result; } /// Storage trait provides functionalities to store and retrieve DLCs. @@ -225,14 +226,15 @@ pub trait Storage { fn get_chain_monitor(&self) -> Result, Error>; } +#[async_trait::async_trait] /// Oracle trait provides access to oracle information. pub trait Oracle { /// Returns the public key of the oracle. fn get_public_key(&self) -> XOnlyPublicKey; /// Returns the announcement for the event with the given id if found. - fn get_announcement(&self, event_id: &str) -> Result; + async fn get_announcement(&self, event_id: &str) -> Result; /// Returns the attestation for the event with the given id if found. - fn get_attestation(&self, event_id: &str) -> Result; + async fn get_attestation(&self, event_id: &str) -> Result; } /// Represents a UTXO. diff --git a/ddk-manager/src/manager.rs b/ddk-manager/src/manager.rs index 860549ed..50792530 100644 --- a/ddk-manager/src/manager.rs +++ b/ddk-manager/src/manager.rs @@ -31,6 +31,9 @@ use ddk_messages::channel::{ }; use ddk_messages::oracle_msgs::{OracleAnnouncement, OracleAttestation}; use ddk_messages::{AcceptDlc, Message as DlcMessage, OfferDlc, SignDlc}; +use futures::stream; +use futures::stream::FuturesUnordered; +use futures::{StreamExt, TryStreamExt}; use hex::DisplayHex; use lightning::chain::chaininterface::FeeEstimator; use lightning::ln::chan_utils::{ @@ -137,7 +140,7 @@ macro_rules! check_for_timed_out_channels { if let SignedChannelState::$state { timeout, .. } = channel.state { let is_timed_out = timeout < $manager.time.unix_time_now(); if is_timed_out { - match $manager.force_close_channel_internal(channel, true) { + match $manager.force_close_channel_internal(channel, true).await { Err(e) => error!("Error force closing channel {}", e), _ => {} } @@ -159,7 +162,7 @@ where F::Target: FeeEstimator, { /// Create a new Manager struct. - pub fn new( + pub async fn new( wallet: W, signer_provider: SP, blockchain: B, @@ -168,7 +171,7 @@ where time: T, fee_estimator: F, ) -> Result { - let init_height = blockchain.get_blockchain_height()?; + let init_height = blockchain.get_blockchain_height().await?; let chain_monitor = Mutex::new( store .get_chain_monitor()? @@ -196,7 +199,7 @@ where } /// Function called to pass a DlcMessage to the Manager. - pub fn on_dlc_message( + pub async fn on_dlc_message( &self, msg: &DlcMessage, counter_party: PublicKey, @@ -208,7 +211,7 @@ where } DlcMessage::Accept(a) => Ok(Some(self.on_accept_message(a, &counter_party)?)), DlcMessage::Sign(s) => { - self.on_sign_message(s, &counter_party)?; + self.on_sign_message(s, &counter_party).await?; Ok(None) } DlcMessage::OfferChannel(o) => { @@ -219,7 +222,7 @@ where self.on_accept_channel(a, &counter_party)?, ))), DlcMessage::SignChannel(s) => { - self.on_sign_channel(s, &counter_party)?; + self.on_sign_channel(s, &counter_party).await?; Ok(None) } DlcMessage::SettleOffer(s) => match self.on_settle_offer(s, &counter_party)? { @@ -269,18 +272,16 @@ where /// and an OfferDlc message returned. /// /// This function will fetch the oracle announcements from the oracle. - pub fn send_offer( + pub async fn send_offer( &self, contract_input: &ContractInput, counter_party: PublicKey, ) -> Result { - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + // If the oracle announcement fails to retrieve, then log and continue. + let oracle_announcements = self.oracle_announcements(contract_input).await?; self.send_offer_with_announcements(contract_input, counter_party, oracle_announcements) + .await } /// Function called to create a new DLC. The offered contract will be stored @@ -288,7 +289,7 @@ where /// /// This function allows to pass the oracle announcements directly instead of /// fetching them from the oracle. - pub fn send_offer_with_announcements( + pub async fn send_offer_with_announcements( &self, contract_input: &ContractInput, counter_party: PublicKey, @@ -304,7 +305,8 @@ where &self.blockchain, &self.time, &self.signer_provider, - )?; + ) + .await?; offered_contract.validate()?; @@ -314,7 +316,7 @@ where } /// Function to call to accept a DLC for which an offer was received. - pub fn accept_contract_offer( + pub async fn accept_contract_offer( &self, contract_id: &ContractId, ) -> Result<(ContractId, PublicKey, AcceptDlc), Error> { @@ -329,7 +331,8 @@ where &self.wallet, &self.signer_provider, &self.blockchain, - )?; + ) + .await?; self.wallet.import_address(&Address::p2wsh( &accepted_contract.dlc_transactions.funding_script_pubkey, @@ -349,8 +352,8 @@ where /// /// Consumers **MUST** call this periodically in order to /// determine when pending transactions reach confirmation. - pub fn periodic_chain_monitor(&self) -> Result<(), Error> { - let cur_height = self.blockchain.get_blockchain_height()?; + pub async fn periodic_chain_monitor(&self) -> Result<(), Error> { + let cur_height = self.blockchain.get_blockchain_height().await?; let last_height = self.chain_monitor.lock().unwrap().last_height; // TODO(luckysori): We could end up reprocessing a block at @@ -362,7 +365,7 @@ where } for height in last_height + 1..=cur_height { - let block = self.blockchain.get_block_at_height(height)?; + let block = self.blockchain.get_block_at_height(height).await?; self.chain_monitor .lock() @@ -375,13 +378,13 @@ where /// Function to call to check the state of the currently executing DLCs and /// update them if possible. - pub fn periodic_check(&self, check_channels: bool) -> Result<(), Error> { - self.check_signed_contracts()?; - self.check_confirmed_contracts()?; - self.check_preclosed_contracts()?; + pub async fn periodic_check(&self, check_channels: bool) -> Result<(), Error> { + self.check_signed_contracts().await?; + self.check_confirmed_contracts().await?; + self.check_preclosed_contracts().await?; if check_channels { - self.channel_checks()?; + self.channel_checks().await?; } Ok(()) @@ -448,7 +451,11 @@ where Ok(DlcMessage::Sign(signed_msg)) } - fn on_sign_message(&self, sign_message: &SignDlc, peer_id: &PublicKey) -> Result<(), Error> { + async fn on_sign_message( + &self, + sign_message: &SignDlc, + peer_id: &PublicKey, + ) -> Result<(), Error> { let accepted_contract = get_contract_in_state!(self, &sign_message.contract_id, Accepted, Some(*peer_id))?; @@ -465,12 +472,12 @@ where self.store .update_contract(&Contract::Signed(signed_contract))?; - self.blockchain.send_transaction(&fund_tx)?; + self.blockchain.send_transaction(&fund_tx).await?; Ok(()) } - fn get_oracle_announcements( + async fn get_oracle_announcements( &self, oracle_inputs: &OracleInput, ) -> Result, Error> { @@ -480,7 +487,8 @@ where .oracles .get(pubkey) .ok_or_else(|| Error::InvalidParameters("Unknown oracle public key".to_string()))?; - announcements.push(oracle.get_announcement(&oracle_inputs.event_id)?.clone()); + let announcement = oracle.get_announcement(&oracle_inputs.event_id).await?; + announcements.push(announcement); } Ok(announcements) @@ -518,14 +526,17 @@ where Err(e) } - fn check_signed_contract(&self, contract: &SignedContract) -> Result<(), Error> { - let confirmations = self.blockchain.get_transaction_confirmations( - &contract - .accepted_contract - .dlc_transactions - .fund - .compute_txid(), - )?; + async fn check_signed_contract(&self, contract: &SignedContract) -> Result<(), Error> { + let confirmations = self + .blockchain + .get_transaction_confirmations( + &contract + .accepted_contract + .dlc_transactions + .fund + .compute_txid(), + ) + .await?; if confirmations >= NB_CONFIRMATIONS { self.store .update_contract(&Contract::Confirmed(contract.clone()))?; @@ -533,9 +544,9 @@ where Ok(()) } - fn check_signed_contracts(&self) -> Result<(), Error> { + async fn check_signed_contracts(&self) -> Result<(), Error> { for c in self.store.get_signed_contracts()? { - if let Err(e) = self.check_signed_contract(&c) { + if let Err(e) = self.check_signed_contract(&c).await { error!( "Error checking confirmed contract {}: {}", c.accepted_contract.get_contract_id_string(), @@ -547,13 +558,13 @@ where Ok(()) } - fn check_confirmed_contracts(&self) -> Result<(), Error> { + async fn check_confirmed_contracts(&self) -> Result<(), Error> { for c in self.store.get_confirmed_contracts()? { // Confirmed contracts from channel are processed in channel specific methods. if c.channel_id.is_some() { continue; } - if let Err(e) = self.check_confirmed_contract(&c) { + if let Err(e) = self.check_confirmed_contract(&c).await { error!( "Error checking confirmed contract {}: {}", c.accepted_contract.get_contract_id_string(), @@ -565,7 +576,7 @@ where Ok(()) } - fn get_closable_contract_info<'a>( + async fn get_closable_contract_info<'a>( &'a self, contract: &'a SignedContract, ) -> ClosableContractInfo<'a> { @@ -581,26 +592,54 @@ where .enumerate() .collect(); if matured.len() >= contract_info.threshold { - let attestations: Vec<_> = matured - .iter() - .filter_map(|(i, announcement)| { - let oracle = self.oracles.get(&announcement.oracle_public_key)?; - let attestation = oracle + let attestations = stream::iter(matured.iter()) + .map(|(i, announcement)| async move { + // First try to get the oracle + let oracle = match self.oracles.get(&announcement.oracle_public_key) { + Some(oracle) => oracle, + None => { + log::debug!( + "Oracle not found for key: {}", + announcement.oracle_public_key + ); + return None; + } + }; + + // Then try to get the attestation + let attestation = match oracle .get_attestation(&announcement.oracle_event.event_id) - .ok()?; - attestation - .validate(&self.secp, announcement) - .map_err(|_| { + .await + { + Ok(attestation) => attestation, + Err(e) => { log::error!( - "Oracle attestation is not valid. pubkey={} event_id={}", - announcement.oracle_public_key, - announcement.oracle_event.event_id - ) - }) - .ok()?; + "Attestation not found for event. id={} error={}", + announcement.oracle_event.event_id, + e.to_string() + ); + return None; + } + }; + + // Validate the attestation + if let Err(e) = attestation.validate(&self.secp, announcement) { + log::error!( + "Oracle attestation is not valid. pubkey={} event_id={}, error={:?}", + announcement.oracle_public_key, + announcement.oracle_event.event_id, + e + ); + return None; + } + Some((*i, attestation)) }) - .collect(); + .collect::>() + .await + .filter_map(|result| async move { result }) // Filter out None values + .collect::>() + .await; if attestations.len() >= contract_info.threshold { return Some((contract_info, adaptor_info, attestations)); } @@ -609,8 +648,8 @@ where None } - fn check_confirmed_contract(&self, contract: &SignedContract) -> Result<(), Error> { - let closable_contract_info = self.get_closable_contract_info(contract); + async fn check_confirmed_contract(&self, contract: &SignedContract) -> Result<(), Error> { + let closable_contract_info = self.get_closable_contract_info(contract).await; if let Some((contract_info, adaptor_info, attestations)) = closable_contract_info { let offer = &contract.accepted_contract.offered_contract; let signer = self.signer_provider.derive_contract_signer(offer.keys_id)?; @@ -622,11 +661,14 @@ where &attestations, &signer, )?; - match self.close_contract( - contract, - cet, - attestations.iter().map(|x| x.1.clone()).collect(), - ) { + match self + .close_contract( + contract, + cet, + attestations.iter().map(|x| x.1.clone()).collect(), + ) + .await + { Ok(closed_contract) => { self.store.update_contract(&closed_contract)?; return Ok(()); @@ -642,13 +684,13 @@ where } } - self.check_refund(contract)?; + self.check_refund(contract).await?; Ok(()) } /// Manually close a contract with the oracle attestations. - pub fn close_confirmed_contract( + pub async fn close_confirmed_contract( &self, contract_id: &ContractId, attestations: Vec<(usize, OracleAttestation)>, @@ -684,8 +726,9 @@ where // Check that the lock time has passed let time = bitcoin::absolute::Time::from_consensus(self.time.unix_time_now() as u32) .expect("Time is not in valid range. This should never happen."); - let height = Height::from_consensus(self.blockchain.get_blockchain_height()? as u32) - .expect("Height is not in valid range. This should never happen."); + let height = + Height::from_consensus(self.blockchain.get_blockchain_height().await? as u32) + .expect("Height is not in valid range. This should never happen."); let locktime = cet.lock_time; if !locktime.is_satisfied_by(height, time) { @@ -694,11 +737,14 @@ where )); } - match self.close_contract( - &contract, - cet, - attestations.into_iter().map(|x| x.1).collect(), - ) { + match self + .close_contract( + &contract, + cet, + attestations.into_iter().map(|x| x.1).collect(), + ) + .await + { Ok(closed_contract) => { self.store.update_contract(&closed_contract)?; Ok(closed_contract) @@ -718,9 +764,9 @@ where } } - fn check_preclosed_contracts(&self) -> Result<(), Error> { + async fn check_preclosed_contracts(&self) -> Result<(), Error> { for c in self.store.get_preclosed_contracts()? { - if let Err(e) = self.check_preclosed_contract(&c) { + if let Err(e) = self.check_preclosed_contract(&c).await { error!( "Error checking pre-closed contract {}: {}", c.signed_contract.accepted_contract.get_contract_id_string(), @@ -732,11 +778,12 @@ where Ok(()) } - fn check_preclosed_contract(&self, contract: &PreClosedContract) -> Result<(), Error> { + async fn check_preclosed_contract(&self, contract: &PreClosedContract) -> Result<(), Error> { let broadcasted_txid = contract.signed_cet.compute_txid(); let confirmations = self .blockchain - .get_transaction_confirmations(&broadcasted_txid)?; + .get_transaction_confirmations(&broadcasted_txid) + .await?; if confirmations >= NB_CONFIRMATIONS { let closed_contract = ClosedContract { attestations: contract.attestations.clone(), @@ -764,7 +811,7 @@ where Ok(()) } - fn close_contract( + async fn close_contract( &self, contract: &SignedContract, signed_cet: Transaction, @@ -772,14 +819,15 @@ where ) -> Result { let confirmations = self .blockchain - .get_transaction_confirmations(&signed_cet.compute_txid())?; + .get_transaction_confirmations(&signed_cet.compute_txid()) + .await?; if confirmations < 1 { // TODO(tibo): if this fails because another tx is already in // mempool or blockchain, we might have been cheated. There is // not much to be done apart from possibly extracting a fraud // proof but ideally it should be handled. - self.blockchain.send_transaction(&signed_cet)?; + self.blockchain.send_transaction(&signed_cet).await?; let preclosed_contract = PreClosedContract { signed_contract: contract.clone(), @@ -810,7 +858,7 @@ where Ok(Contract::Closed(closed_contract)) } - fn check_refund(&self, contract: &SignedContract) -> Result<(), Error> { + async fn check_refund(&self, contract: &SignedContract) -> Result<(), Error> { // TODO(tibo): should check for confirmation of refund before updating state if contract .accepted_contract @@ -824,13 +872,14 @@ where let refund = accepted_contract.dlc_transactions.refund.clone(); let confirmations = self .blockchain - .get_transaction_confirmations(&refund.compute_txid())?; + .get_transaction_confirmations(&refund.compute_txid()) + .await?; if confirmations == 0 { let offer = &contract.accepted_contract.offered_contract; let signer = self.signer_provider.derive_contract_signer(offer.keys_id)?; let refund = crate::contract_updater::get_signed_refund(&self.secp, contract, &signer)?; - self.blockchain.send_transaction(&refund)?; + self.blockchain.send_transaction(&refund).await?; } self.store @@ -910,16 +959,12 @@ where { /// Create a new channel offer and return the [`dlc_messages::channel::OfferChannel`] /// message to be sent to the `counter_party`. - pub fn offer_channel( + pub async fn offer_channel( &self, contract_input: &ContractInput, counter_party: PublicKey, ) -> Result { - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + let oracle_announcements = self.oracle_announcements(contract_input).await?; let (offered_channel, offered_contract) = crate::channel_updater::offer_channel( &self.secp, @@ -933,7 +978,8 @@ where &self.blockchain, &self.time, crate::utils::get_new_temporary_id(), - )?; + ) + .await?; let msg = offered_channel.get_offer_channel_msg(&offered_contract); @@ -979,7 +1025,7 @@ where /// Accept a channel that was offered. Returns the [`dlc_messages::channel::AcceptChannel`] /// message to be sent, the updated [`crate::ChannelId`] and [`crate::ContractId`], /// as well as the public key of the offering node. - pub fn accept_channel( + pub async fn accept_channel( &self, channel_id: &ChannelId, ) -> Result<(AcceptChannel, ChannelId, ContractId, PublicKey), Error> { @@ -1007,7 +1053,8 @@ where &self.wallet, &self.signer_provider, &self.blockchain, - )?; + ) + .await?; self.wallet.import_address(&Address::p2wsh( &accepted_contract.dlc_transactions.funding_script_pubkey, @@ -1027,10 +1074,10 @@ where } /// Force close the channel with given [`crate::ChannelId`]. - pub fn force_close_channel(&self, channel_id: &ChannelId) -> Result<(), Error> { + pub async fn force_close_channel(&self, channel_id: &ChannelId) -> Result<(), Error> { let channel = get_channel_in_state!(self, channel_id, Signed, None as Option)?; - self.force_close_channel_internal(channel, true) + self.force_close_channel_internal(channel, true).await } /// Offer to settle the balance of a channel so that the counter party gets @@ -1092,7 +1139,7 @@ where /// Returns a [`RenewOffer`] message as well as the [`PublicKey`] of the /// counter party's node to offer the establishment of a new contract in the /// channel. - pub fn renew_offer( + pub async fn renew_offer( &self, channel_id: &ChannelId, counter_payout: u64, @@ -1101,11 +1148,7 @@ where let mut signed_channel = get_channel_in_state!(self, channel_id, Signed, None as Option)?; - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + let oracle_announcements = self.oracle_announcements(contract_input).await?; let (msg, offered_contract) = crate::channel_updater::renew_offer( &self.secp, @@ -1259,7 +1302,7 @@ where /// Accept an offer to collaboratively close the channel. The close transaction /// will be broadcast and the state of the channel updated. - pub fn accept_collaborative_close(&self, channel_id: &ChannelId) -> Result<(), Error> { + pub async fn accept_collaborative_close(&self, channel_id: &ChannelId) -> Result<(), Error> { let signed_channel = get_channel_in_state!(self, channel_id, Signed, None as Option)?; @@ -1288,7 +1331,7 @@ where &self.signer_provider, )?; - self.blockchain.send_transaction(&close_tx)?; + self.blockchain.send_transaction(&close_tx).await?; self.store.upsert_channel(closed_channel, None)?; @@ -1300,7 +1343,7 @@ where Ok(()) } - fn try_finalize_closing_established_channel( + async fn try_finalize_closing_established_channel( &self, signed_channel: SignedChannel, ) -> Result<(), Error> { @@ -1314,7 +1357,8 @@ where if self .blockchain - .get_transaction_confirmations(&buffer_tx.compute_txid())? + .get_transaction_confirmations(&buffer_tx.compute_txid()) + .await? >= CET_NSEQUENCE { log::info!( @@ -1327,6 +1371,7 @@ where let (contract_info, adaptor_info, attestations) = self .get_closable_contract_info(&confirmed_contract) + .await .ok_or_else(|| { Error::InvalidState("Could not get information to close contract".to_string()) })?; @@ -1343,11 +1388,13 @@ where is_initiator, )?; - let closed_contract = self.close_contract( - &confirmed_contract, - signed_cet, - attestations.iter().map(|x| &x.1).cloned().collect(), - )?; + let closed_contract = self + .close_contract( + &confirmed_contract, + signed_cet, + attestations.iter().map(|x| &x.1).cloned().collect(), + ) + .await?; self.chain_monitor .lock() @@ -1479,7 +1526,7 @@ where Ok(sign_channel) } - fn on_sign_channel( + async fn on_sign_channel( &self, sign_channel: &SignChannel, peer_id: &PublicKey, @@ -1534,7 +1581,7 @@ where unreachable!(); } - self.blockchain.send_transaction(&signed_fund_tx)?; + self.blockchain.send_transaction(&signed_fund_tx).await?; self.store.upsert_channel( Channel::Signed(signed_channel), @@ -2087,24 +2134,24 @@ where Ok(()) } - fn channel_checks(&self) -> Result<(), Error> { + async fn channel_checks(&self) -> Result<(), Error> { let established_closing_channels = self .store .get_signed_channels(Some(SignedChannelStateType::Closing))?; for channel in established_closing_channels { - if let Err(e) = self.try_finalize_closing_established_channel(channel) { + if let Err(e) = self.try_finalize_closing_established_channel(channel).await { error!("Error trying to close established channel: {}", e); } } - if let Err(e) = self.check_for_timed_out_channels() { + if let Err(e) = self.check_for_timed_out_channels().await { error!("Error checking timed out channels {}", e); } - self.check_for_watched_tx() + self.check_for_watched_tx().await } - fn check_for_timed_out_channels(&self) -> Result<(), Error> { + async fn check_for_timed_out_channels(&self) -> Result<(), Error> { check_for_timed_out_channels!(self, RenewOffered); check_for_timed_out_channels!(self, RenewAccepted); check_for_timed_out_channels!(self, RenewConfirmed); @@ -2115,7 +2162,7 @@ where Ok(()) } - pub(crate) fn process_watched_txs( + pub(crate) async fn process_watched_txs( &self, watched_txs: Vec<(Transaction, ChannelInfo)>, ) -> Result<(), Error> { @@ -2290,7 +2337,7 @@ where } }; - self.blockchain.send_transaction(&signed_tx)?; + self.blockchain.send_transaction(&signed_tx).await?; let closed_channel = Channel::ClosedPunished(ClosedPunishedChannel { counter_party: signed_channel.counter_party, @@ -2420,10 +2467,10 @@ where Ok(()) } - fn check_for_watched_tx(&self) -> Result<(), Error> { + async fn check_for_watched_tx(&self) -> Result<(), Error> { let confirmed_txs = self.chain_monitor.lock().unwrap().confirmed_txs(); - self.process_watched_txs(confirmed_txs)?; + self.process_watched_txs(confirmed_txs).await?; self.get_store() .persist_chain_monitor(&self.chain_monitor.lock().unwrap())?; @@ -2431,7 +2478,7 @@ where Ok(()) } - fn force_close_channel_internal( + async fn force_close_channel_internal( &self, mut channel: SignedChannel, is_initiator: bool, @@ -2450,6 +2497,7 @@ where counter_buffer_adaptor_signature, buffer_transaction, ) + .await } SignedChannelState::RenewFinalized { buffer_transaction, @@ -2464,8 +2512,11 @@ where offer_buffer_adaptor_signature, buffer_transaction, ) + .await + } + SignedChannelState::Settled { .. } => { + self.close_settled_channel(channel, is_initiator).await } - SignedChannelState::Settled { .. } => self.close_settled_channel(channel, is_initiator), SignedChannelState::SettledOffered { .. } | SignedChannelState::SettledReceived { .. } | SignedChannelState::SettledAccepted { .. } @@ -2478,7 +2529,8 @@ where .roll_back_state .take() .expect("to have a rollback state"); - self.force_close_channel_internal(channel, is_initiator) + let channel_clone = channel.clone(); // Clone the channel to avoid moving it + Box::pin(self.force_close_channel_internal(channel_clone, is_initiator)).await } SignedChannelState::Closing { .. } => Err(Error::InvalidState( "Channel is already closing.".to_string(), @@ -2487,7 +2539,7 @@ where } /// Initiate the unilateral closing of a channel that has been established. - fn initiate_unilateral_close_established_channel( + async fn initiate_unilateral_close_established_channel( &self, mut signed_channel: SignedChannel, is_initiator: bool, @@ -2511,7 +2563,7 @@ where let buffer_transaction = get_signed_channel_state!(signed_channel, Closing, ref buffer_transaction)?; - self.blockchain.send_transaction(buffer_transaction)?; + self.blockchain.send_transaction(buffer_transaction).await?; self.chain_monitor .lock() @@ -2528,7 +2580,7 @@ where } /// Unilaterally close a channel that has been settled. - fn close_settled_channel( + async fn close_settled_channel( &self, signed_channel: SignedChannel, is_initiator: bool, @@ -2543,10 +2595,11 @@ where if self .blockchain .get_transaction_confirmations(&settle_tx.compute_txid()) + .await .unwrap_or(0) == 0 { - self.blockchain.send_transaction(&settle_tx)?; + self.blockchain.send_transaction(&settle_tx).await?; } self.chain_monitor @@ -2590,6 +2643,30 @@ where pnl, }) } + + async fn oracle_announcements( + &self, + contract_input: &ContractInput, + ) -> Result>, Error> { + let announcements = stream::iter(contract_input.contract_infos.iter()) + .map(|x| { + let future = self.get_oracle_announcements(&x.oracles); + async move { + match future.await { + Ok(result) => Ok(result), + Err(e) => { + log::error!("Failed to get oracle announcements: {}", e); + Err(e) + } + } + } + }) + .collect::>() + .await + .try_collect::>() + .await?; + Ok(announcements) + } } #[cfg(test)] @@ -2617,13 +2694,16 @@ mod test { SimpleSigner, >; - fn get_manager() -> TestManager { + async fn get_manager() -> TestManager { let blockchain = Rc::new(MockBlockchain::new()); let store = Rc::new(MemoryStorage::new()); - let wallet = Rc::new(MockWallet::new( - &blockchain, - &(0..100).map(|x| x as u64 * 1000000).collect::>(), - )); + let wallet = Rc::new( + MockWallet::new( + &blockchain, + &(0..100).map(|x| x as u64 * 1000000).collect::>(), + ) + .await, + ); let oracle_list = (0..5).map(|_| MockOracle::new()).collect::>(); let oracles: HashMap = oracle_list @@ -2643,6 +2723,7 @@ mod test { time, blockchain, ) + .await .unwrap() } @@ -2652,37 +2733,41 @@ mod test { .unwrap() } - #[test] - fn reject_offer_with_existing_contract_id() { + #[tokio::test] + async fn reject_offer_with_existing_contract_id() { let offer_message = Message::Offer( serde_json::from_str(include_str!("../test_inputs/offer_contract.json")).unwrap(), ); - let manager = get_manager(); + let manager = get_manager().await; manager .on_dlc_message(&offer_message, pubkey()) + .await .expect("To accept the first offer message"); manager .on_dlc_message(&offer_message, pubkey()) + .await .expect_err("To reject the second offer message"); } - #[test] - fn reject_channel_offer_with_existing_channel_id() { + #[tokio::test] + async fn reject_channel_offer_with_existing_channel_id() { let offer_message = Message::OfferChannel( serde_json::from_str(include_str!("../test_inputs/offer_channel.json")).unwrap(), ); - let manager = get_manager(); + let manager = get_manager().await; manager .on_dlc_message(&offer_message, pubkey()) + .await .expect("To accept the first offer message"); manager .on_dlc_message(&offer_message, pubkey()) + .await .expect_err("To reject the second offer message"); } } diff --git a/ddk-manager/src/utils.rs b/ddk-manager/src/utils.rs index eed67a27..cd4dee12 100644 --- a/ddk-manager/src/utils.rs +++ b/ddk-manager/src/utils.rs @@ -89,7 +89,7 @@ pub(crate) fn compute_id( res } -pub(crate) fn get_party_params( +pub(crate) async fn get_party_params( secp: &Secp256k1, own_collateral: u64, fee_rate: u64, @@ -120,7 +120,7 @@ where let mut funding_tx_info: Vec = Vec::new(); let mut total_input = Amount::ZERO; for utxo in utxos { - let prev_tx = blockchain.get_transaction(&utxo.outpoint.txid)?; + let prev_tx = blockchain.get_transaction(&utxo.outpoint.txid).await?; let mut writer = Vec::new(); prev_tx.consensus_encode(&mut writer)?; let prev_tx_vout = utxo.outpoint.vout; diff --git a/ddk-manager/tests/channel_execution_tests.rs b/ddk-manager/tests/channel_execution_tests.rs index abcfeae0..0579b846 100644 --- a/ddk-manager/tests/channel_execution_tests.rs +++ b/ddk-manager/tests/channel_execution_tests.rs @@ -21,19 +21,19 @@ use secp256k1_zkp::rand::{thread_rng, RngCore}; use secp256k1_zkp::EcdsaAdaptorSignature; use simple_wallet::SimpleWallet; use test_utils::{get_enum_test_params, TestParams}; +use tokio::time::sleep; -use std::sync::mpsc::{sync_channel, Receiver, Sender}; -use std::thread; - -use std::time::Duration; +use std::future::Future; +use std::pin::Pin; use std::{ collections::HashMap, sync::{ atomic::{AtomicBool, Ordering}, - mpsc::channel, - Arc, Mutex, + Arc, }, }; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::Mutex; use crate::test_utils::{refresh_wallet, EVENT_MATURITY}; @@ -57,10 +57,13 @@ type DlcParty = Arc< >, >; -fn get_established_channel_contract_id(dlc_party: &DlcParty, channel_id: &ChannelId) -> ContractId { +async fn get_established_channel_contract_id( + dlc_party: &DlcParty, + channel_id: &ChannelId, +) -> ContractId { let channel = dlc_party .lock() - .unwrap() + .await .get_store() .get_channel(channel_id) .unwrap() @@ -82,11 +85,11 @@ fn alter_adaptor_sig(input: &EcdsaAdaptorSignature) -> EcdsaAdaptorSignature { /// We wrap updating the state of the chain monitor and calling the /// `Manager::periodic_check` because the latter will only be aware of /// newly confirmed transactions if the former processes new blocks. -fn periodic_check(dlc_party: DlcParty) { - let dlc_manager = dlc_party.lock().unwrap(); +async fn periodic_check(dlc_party: DlcParty) { + let dlc_manager = dlc_party.lock().await; - dlc_manager.periodic_chain_monitor().unwrap(); - dlc_manager.periodic_check(true).unwrap(); + dlc_manager.periodic_chain_monitor().await.unwrap(); + dlc_manager.periodic_check(true).await.unwrap(); } #[derive(Eq, PartialEq, Clone)] @@ -115,182 +118,195 @@ enum TestPath { CancelOffer, } -#[test] +#[tokio::test] #[ignore] -fn channel_established_close_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::Close); +async fn channel_established_close_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::Close).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_bad_accept_buffer_adaptor_test() { +async fn channel_bad_accept_buffer_adaptor_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::BadAcceptBufferAdaptorSignature, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_bad_sign_buffer_adaptor_test() { +async fn channel_bad_sign_buffer_adaptor_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::BadSignBufferAdaptorSignature, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settled_close_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleClose); +async fn channel_settled_close_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleClose).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_punish_buffer_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::BufferCheat); +async fn channel_punish_buffer_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::BufferCheat).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_close_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewedClose); +async fn channel_renew_close_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewedClose).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_established_close_test() { +async fn channel_renew_established_close_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewEstablishedClose, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_cheat_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleCheat); +async fn channel_settle_cheat_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleCheat).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_collaborative_close_test() { +async fn channel_collaborative_close_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::CollaborativeClose, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_renew_settle_test() { +async fn channel_settle_renew_settle_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleRenewSettle, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_offer_timeout_test() { +async fn channel_settle_offer_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleOfferTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_accept_timeout_test() { +async fn channel_settle_accept_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleAcceptTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_confirm_timeout_test() { +async fn channel_settle_confirm_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::SettleConfirmTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_reject_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleReject); +async fn channel_settle_reject_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleReject).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_settle_race_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleRace); +async fn channel_settle_race_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::SettleRace).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_offer_timeout_test() { +async fn channel_renew_offer_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewOfferTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_accept_timeout_test() { +async fn channel_renew_accept_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewAcceptTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_confirm_timeout_test() { +async fn channel_renew_confirm_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewConfirmTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_finalize_timeout_test() { +async fn channel_renew_finalize_timeout_test() { channel_execution_test( get_enum_test_params(1, 1, None), TestPath::RenewFinalizeTimeout, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_reject_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewReject); +async fn channel_renew_reject_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewReject).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_renew_race_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewRace); +async fn channel_renew_race_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::RenewRace).await; } -#[test] +#[tokio::test] #[ignore] -fn channel_offer_reject_test() { - channel_execution_test(get_enum_test_params(1, 1, None), TestPath::CancelOffer); +async fn channel_offer_reject_test() { + channel_execution_test(get_enum_test_params(1, 1, None), TestPath::CancelOffer).await; } -fn channel_execution_test(test_params: TestParams, path: TestPath) { +async fn channel_execution_test(test_params: TestParams, path: TestPath) { env_logger::init(); - let (alice_send, bob_receive) = channel::>(); - let (bob_send, alice_receive) = channel::>(); - let (alice_sync_send, alice_sync_receive) = sync_channel::<()>(0); - let (bob_sync_send, bob_sync_receive) = sync_channel::<()>(0); + let (alice_send, mut bob_receive) = channel::>(100); + let (bob_send, mut alice_receive) = channel::>(100); + let (alice_sync_send, mut alice_sync_receive) = channel::<()>(100); + let (bob_sync_send, mut bob_sync_receive) = channel::<()>(100); let (_, _, sink_rpc) = init_clients(); + let sink = Arc::new(sink_rpc); let mut alice_oracles = HashMap::with_capacity(1); let mut bob_oracles = HashMap::with_capacity(1); @@ -306,10 +322,16 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { let mock_time = Arc::new(mocks::mock_time::MockTime {}); mocks::mock_time::set_time((EVENT_MATURITY as u64) - 1); - let electrs = Arc::new(ElectrsBlockchainProvider::new( - "http://localhost:3004/".to_string(), - bitcoin::Network::Regtest, - )); + let electrs = tokio::task::spawn_blocking(|| { + Arc::new(ElectrsBlockchainProvider::new( + "http://localhost:3004/".to_string(), + bitcoin::Network::Regtest, + )) + }) + .await + .unwrap(); + + println!("couldnt create electrs"); let alice_wallet = Arc::new(SimpleWallet::new( electrs.clone(), @@ -323,58 +345,66 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { bitcoin::Network::Regtest, )); + println!("getting address"); let alice_fund_address = alice_wallet.get_new_address().unwrap(); let bob_fund_address = bob_wallet.get_new_address().unwrap(); - sink_rpc - .send_to_address( - &alice_fund_address, - Amount::from_btc(2.0).unwrap(), - None, - None, - None, - None, - None, - None, - ) - .unwrap(); - - sink_rpc - .send_to_address( - &bob_fund_address, - Amount::from_btc(2.0).unwrap(), - None, - None, - None, - None, - None, - None, - ) - .unwrap(); - - let generate_blocks = |nb_blocks: u64| { - let prev_blockchain_height = electrs.get_blockchain_height().unwrap(); - - let sink_address = sink_rpc - .get_new_address(None, None) - .expect("RPC Error") - .assume_checked(); - sink_rpc - .generate_to_address(nb_blocks, &sink_address) - .expect("RPC Error"); - - // Wait for electrs to have processed the new blocks - let mut cur_blockchain_height = prev_blockchain_height; - while cur_blockchain_height < prev_blockchain_height + nb_blocks { - std::thread::sleep(std::time::Duration::from_millis(200)); - cur_blockchain_height = electrs.get_blockchain_height().unwrap(); - } + println!("funded address"); + sink.send_to_address( + &alice_fund_address, + Amount::from_btc(2.0).unwrap(), + None, + None, + None, + None, + None, + None, + ) + .unwrap(); + + println!("prolly not her."); + + sink.send_to_address( + &bob_fund_address, + Amount::from_btc(2.0).unwrap(), + None, + None, + None, + None, + None, + None, + ) + .unwrap(); + + let generate_blocks = |nb_blocks: u64| -> Pin + Send>> { + let electrs_clone = electrs.clone(); + let sink_clone = sink.clone(); + println!("failing"); + Box::pin(async move { + let prev_blockchain_height = electrs_clone.get_blockchain_height().await.unwrap(); + println!("Got block {}", prev_blockchain_height); + let sink_address = sink_clone + .get_new_address(None, None) + .expect("RPC Error") + .assume_checked(); + println!("got address"); + sink_clone + .generate_to_address(nb_blocks, &sink_address) + .expect("RPC Error"); + // Wait for electrs to have processed the new blocks + let mut cur_blockchain_height = prev_blockchain_height; + while cur_blockchain_height < prev_blockchain_height + nb_blocks { + sleep(std::time::Duration::from_millis(200)).await; + cur_blockchain_height = electrs_clone.get_blockchain_height().await.unwrap(); + } + }) }; - generate_blocks(6); - - refresh_wallet(&alice_wallet, 200000000); - refresh_wallet(&bob_wallet, 200000000); + println!("generating."); + generate_blocks(6).await; + println!("agh generate."); + refresh_wallet(&alice_wallet, 200000000).await; + refresh_wallet(&bob_wallet, 200000000).await; let alice_manager = Arc::new(Mutex::new( Manager::new( @@ -386,6 +416,7 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { Arc::clone(&mock_time), Arc::clone(&electrs), ) + .await .unwrap(), )); @@ -402,6 +433,7 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { Arc::clone(&mock_time), Arc::clone(&electrs), ) + .await .unwrap(), )); @@ -476,44 +508,53 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { let offer_msg = bob_manager_send .lock() - .unwrap() + .await .offer_channel( &test_params.contract_input, "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" .parse() .unwrap(), ) + .await .expect("Send offer error"); let temporary_channel_id = offer_msg.temporary_channel_id; bob_send .send(Some(Message::OfferChannel(offer_msg))) + .await .unwrap(); assert_channel_state!(bob_manager_send, temporary_channel_id, Offered); - alice_sync_receive.recv().expect("Error synchronizing"); + alice_sync_receive + .recv() + .await + .expect("Error synchronizing"); assert_channel_state!(alice_manager_send, temporary_channel_id, Offered); if let TestPath::CancelOffer = path { let (reject_msg, _) = alice_manager_send .lock() - .unwrap() + .await .reject_channel(&temporary_channel_id) .expect("Error rejecting contract offer"); assert_channel_state!(alice_manager_send, temporary_channel_id, Cancelled); - alice_send.send(Some(Message::Reject(reject_msg))).unwrap(); + alice_send + .send(Some(Message::Reject(reject_msg))) + .await + .unwrap(); - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(bob_manager_send, temporary_channel_id, Cancelled); return; } let (mut accept_msg, channel_id, contract_id, _) = alice_manager_send .lock() - .unwrap() + .await .accept_channel(&temporary_channel_id) + .await .expect("Error accepting contract offer"); assert_channel_state!(alice_manager_send, channel_id, Accepted); @@ -524,40 +565,49 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { bob_expect_error.store(true, Ordering::Relaxed); alice_send .send(Some(Message::AcceptChannel(accept_msg))) + .await .unwrap(); - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(bob_manager_send, temporary_channel_id, FailedAccept); } TestPath::BadSignBufferAdaptorSignature => { alice_expect_error.store(true, Ordering::Relaxed); alice_send .send(Some(Message::AcceptChannel(accept_msg))) + .await .unwrap(); // Bob receives accept message - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); // Alice receives sign message - alice_sync_receive.recv().expect("Error synchronizing"); + alice_sync_receive + .recv() + .await + .expect("Error synchronizing"); assert_channel_state!(alice_manager_send, channel_id, FailedSign); } _ => { alice_send .send(Some(Message::AcceptChannel(accept_msg))) + .await .unwrap(); - bob_sync_receive.recv().expect("Error synchronizing"); + bob_sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(bob_manager_send, channel_id, Signed, Established); - alice_sync_receive.recv().expect("Error synchronizing"); + alice_sync_receive + .recv() + .await + .expect("Error synchronizing"); assert_channel_state!(alice_manager_send, channel_id, Signed, Established); - generate_blocks(6); + generate_blocks(6).await; mocks::mock_time::set_time((EVENT_MATURITY as u64) + 1); - periodic_check(alice_manager_send.clone()); + periodic_check(alice_manager_send.clone()).await; - periodic_check(bob_manager_send.clone()); + periodic_check(bob_manager_send.clone()).await; assert_contract_state!(alice_manager_send, contract_id, Confirmed); assert_contract_state!(bob_manager_send, contract_id, Confirmed); @@ -568,25 +618,25 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { ( alice_manager_send, &alice_send, - &alice_sync_receive, + &mut alice_sync_receive, bob_manager_send, &bob_send, - &bob_sync_receive, + &mut bob_sync_receive, ) } else { ( bob_manager_send, &bob_send, - &bob_sync_receive, + &mut bob_sync_receive, alice_manager_send, &alice_send, - &alice_sync_receive, + &mut alice_sync_receive, ) }; match path { TestPath::Close => { - close_established_channel(first, second, channel_id, &generate_blocks); + close_established_channel(first, second, channel_id, &generate_blocks).await; } TestPath::CollaborativeClose => { collaborative_close( @@ -596,7 +646,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { channel_id, second_receive, &generate_blocks, - ); + ) + .await; } TestPath::SettleOfferTimeout | TestPath::SettleAcceptTimeout @@ -610,7 +661,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_receive, channel_id, path, - ); + ) + .await; } TestPath::SettleReject => { settle_reject( @@ -621,7 +673,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } TestPath::SettleRace => { settle_race( @@ -632,7 +685,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } _ => { // Shuffle positions @@ -657,7 +711,7 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { ) }; - first.lock().unwrap().get_store().save(); + first.lock().await.get_store().save(); if let TestPath::RenewEstablishedClose = path { } else { @@ -669,7 +723,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } match path { @@ -682,12 +737,13 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { closer .lock() - .unwrap() + .await .force_close_channel(&channel_id) + .await .expect("to be able to unilaterally close the channel."); } TestPath::BufferCheat => { - cheat_punish(first, second, channel_id, &generate_blocks, true); + cheat_punish(first, second, channel_id, &generate_blocks, true).await; } TestPath::RenewOfferTimeout | TestPath::RenewAcceptTimeout @@ -703,7 +759,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { &test_params.contract_input, path, &generate_blocks, - ); + ) + .await; } TestPath::RenewReject => { renew_reject( @@ -715,7 +772,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_receive, channel_id, &test_params.contract_input, - ); + ) + .await; } TestPath::RenewRace => { renew_race( @@ -727,12 +785,13 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_receive, channel_id, &test_params.contract_input, - ); + ) + .await; } TestPath::RenewedClose | TestPath::SettleCheat | TestPath::RenewEstablishedClose => { - first.lock().unwrap().get_store().save(); + first.lock().await.get_store().save(); let check_prev_contract_close = if let TestPath::RenewEstablishedClose = path { @@ -751,7 +810,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { channel_id, &test_params.contract_input, check_prev_contract_close, - ); + ) + .await; if let TestPath::RenewedClose = path { close_established_channel( @@ -759,9 +819,11 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second, channel_id, &generate_blocks, - ); + ) + .await; } else if let TestPath::SettleCheat = path { - cheat_punish(first, second, channel_id, &generate_blocks, false); + cheat_punish(first, second, channel_id, &generate_blocks, false) + .await; } } TestPath::SettleRenewSettle => { @@ -775,7 +837,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { channel_id, &test_params.contract_input, false, - ); + ) + .await; settle_channel( first, @@ -785,7 +848,8 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { second_send, second_receive, channel_id, - ); + ) + .await; } _ => (), } @@ -794,119 +858,125 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { } } - alice_send.send(None).unwrap(); - bob_send.send(None).unwrap(); + alice_send.send(None).await.unwrap(); + bob_send.send(None).await.unwrap(); - alice_handle.join().unwrap(); - bob_handle.join().unwrap(); + alice_handle.await.unwrap(); + bob_handle.await.unwrap(); } -fn close_established_channel( +async fn close_established_channel( first: DlcParty, second: DlcParty, channel_id: ChannelId, generate_blocks: &F, ) where - F: Fn(u64), + F: Fn(u64) -> Pin + Send>>, { first .lock() - .unwrap() + .await .force_close_channel(&channel_id) + .await .expect("to be able to unilaterally close."); assert_channel_state!(first, channel_id, Signed, Closing); - let contract_id = get_established_channel_contract_id(&first, &channel_id); + let contract_id = get_established_channel_contract_id(&first, &channel_id).await; - periodic_check(first.clone()); + periodic_check(first.clone()).await; let wait = ddk_manager::manager::CET_NSEQUENCE; - generate_blocks(10); + generate_blocks(10).await; - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, Signed, Closing); - periodic_check(first.clone()); + periodic_check(first.clone()).await; // Should not have changed state before the CET is spendable. assert_channel_state!(first, channel_id, Signed, Closing); - generate_blocks(wait as u64 - 9); + generate_blocks(wait as u64 - 9).await; - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Closed); assert_contract_state!(first, contract_id, PreClosed); - generate_blocks(1); + generate_blocks(1).await; - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, CounterClosed); assert_contract_state!(second, contract_id, PreClosed); - generate_blocks(5); + generate_blocks(5).await; - periodic_check(first.clone()); - periodic_check(second.clone()); + periodic_check(first.clone()).await; + periodic_check(second.clone()).await; assert_contract_state!(first, contract_id, Closed); assert_contract_state!(second, contract_id, Closed); } -fn cheat_punish( +async fn cheat_punish( first: DlcParty, second: DlcParty, channel_id: ChannelId, generate_blocks: &F, established: bool, -) { - first.lock().unwrap().get_store().rollback(); +) where + F: Fn(u64) -> Pin + Send>>, +{ + first.lock().await.get_store().rollback(); if established { first .lock() - .unwrap() + .await .force_close_channel(&channel_id) + .await .expect("the cheater to be able to close on established"); } else { first .lock() - .unwrap() + .await .force_close_channel(&channel_id) + .await .expect("the cheater to be able to close on settled"); } - generate_blocks(2); + generate_blocks(2).await; - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, ClosedPunished); } -fn settle_channel( +async fn settle_channel( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, SettledOffered); @@ -914,46 +984,48 @@ fn settle_channel( let (settle_accept, _) = second .lock() - .unwrap() + .await .accept_settle_offer(&channel_id) .expect("to be able to accept a settlement offer"); second_send .send(Some(Message::SettleAccept(settle_accept))) + .await .unwrap(); // Process Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); // Process Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); // Process Finalize - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, Settled); assert_channel_state!(second, channel_id, Signed, Settled); } -fn settle_reject( +async fn settle_reject( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to reject a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, SettledOffered); @@ -961,128 +1033,126 @@ fn settle_reject( let (settle_reject, _) = second .lock() - .unwrap() + .await .reject_settle_offer(&channel_id) .expect("to be able to reject a settlement offer"); second_send .send(Some(Message::Reject(settle_reject))) + .await .unwrap(); - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, Established); assert_channel_state!(second, channel_id, Signed, Established); } -fn settle_race( +async fn settle_race( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); let (settle_offer_2, _) = second .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); second_send .send(Some(Message::SettleOffer(settle_offer_2))) + .await .unwrap(); // Process 2 offers + 2 rejects - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 1"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 2"); - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 3"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 4"); + first_receive.recv().await.expect("Error synchronizing 1"); + second_receive.recv().await.expect("Error synchronizing 2"); + first_receive.recv().await.expect("Error synchronizing 3"); + second_receive.recv().await.expect("Error synchronizing 4"); assert_channel_state!(first, channel_id, Signed, Established); assert_channel_state!(second, channel_id, Signed, Established); } -fn renew_channel( +async fn renew_channel( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, check_prev_contract_close: bool, ) { let prev_contract_id = if check_prev_contract_close { - Some(get_established_channel_contract_id(&first, &channel_id)) + Some(get_established_channel_contract_id(&first, &channel_id).await) } else { None }; let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::ACCEPT_COLLATERAL, contract_input) + .await .expect("to be able to renew channel contract"); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .expect("to be able to send the renew offer"); // Process Renew Offer - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, RenewOffered); assert_channel_state!(second, channel_id, Signed, RenewOffered); let (accept_renew, _) = second .lock() - .unwrap() + .await .accept_renew_offer(&channel_id) .expect("to be able to accept the renewal"); second_send .send(Some(Message::RenewAccept(accept_renew))) + .await .expect("to be able to send the accept renew"); // Process Renew Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, RenewConfirmed); // Process Renew Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); // Process Renew Finalize - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); // Process Renew Revoke - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); if let Some(prev_contract_id) = prev_contract_id { assert_contract_state!(first, prev_contract_id, Closed); assert_contract_state!(second, prev_contract_id, Closed); } - let new_contract_id = get_established_channel_contract_id(&first, &channel_id); + let new_contract_id = get_established_channel_contract_id(&first, &channel_id).await; assert_channel_state!(first, channel_id, Signed, Established); assert_contract_state!(first, new_contract_id, Confirmed); @@ -1090,62 +1160,66 @@ fn renew_channel( assert_contract_state!(second, new_contract_id, Confirmed); } -fn renew_reject( +async fn renew_reject( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, ) { let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::ACCEPT_COLLATERAL, contract_input) + .await .expect("to be able to renew channel contract"); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .expect("to be able to send the renew offer"); // Process Renew Offer - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, RenewOffered); assert_channel_state!(second, channel_id, Signed, RenewOffered); let (renew_reject, _) = second .lock() - .unwrap() + .await .reject_renew_offer(&channel_id) .expect("to be able to reject the renewal"); second_send .send(Some(Message::Reject(renew_reject))) + .await .expect("to be able to send the renew reject"); // Process Renew Reject - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, Settled); assert_channel_state!(second, channel_id, Signed, Settled); } -fn renew_race( +async fn renew_race( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, ) { let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::OFFER_COLLATERAL, contract_input) + .await .expect("to be able to renew channel contract"); let mut contract_input_2 = contract_input.clone(); @@ -1154,148 +1228,152 @@ fn renew_race( let (renew_offer_2, _) = second .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::OFFER_COLLATERAL, &contract_input_2) + .await .expect("to be able to renew channel contract"); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .expect("to be able to send the renew offer"); second_send .send(Some(Message::RenewOffer(renew_offer_2))) + .await .expect("to be able to send the renew offer"); // Process 2 offers + 2 rejects - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 1"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 2"); - first_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 3"); - second_receive - .recv_timeout(Duration::from_secs(2)) - .expect("Error synchronizing 4"); + first_receive.recv().await.expect("Error synchronizing 1"); + second_receive.recv().await.expect("Error synchronizing 2"); + first_receive.recv().await.expect("Error synchronizing 3"); + second_receive.recv().await.expect("Error synchronizing 4"); assert_channel_state!(first, channel_id, Signed, Settled); assert_channel_state!(second, channel_id, Signed, Settled); } -fn collaborative_close( +async fn collaborative_close( first: DlcParty, first_send: &Sender>, second: DlcParty, channel_id: ChannelId, - sync_receive: &Receiver<()>, + sync_receive: &mut Receiver<()>, generate_blocks: &F, -) { - let contract_id = get_established_channel_contract_id(&first, &channel_id); +) where + F: Fn(u64) -> Pin + Send>>, +{ + let contract_id = get_established_channel_contract_id(&first, &channel_id).await; let close_offer = first .lock() - .unwrap() + .await .offer_collaborative_close(&channel_id, 100000000) .expect("to be able to propose a collaborative close"); first_send .send(Some(Message::CollaborativeCloseOffer(close_offer))) + .await .expect("to be able to send collaborative close"); - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_channel_state!(first, channel_id, Signed, CollaborativeCloseOffered); assert_channel_state!(second, channel_id, Signed, CollaborativeCloseOffered); second .lock() - .unwrap() + .await .accept_collaborative_close(&channel_id) + .await .expect("to be able to accept a collaborative close"); assert_channel_state!(second, channel_id, CollaborativelyClosed); assert_contract_state!(second, contract_id, Closed); - generate_blocks(2); + generate_blocks(2).await; - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, CollaborativelyClosed); assert_contract_state!(first, contract_id, Closed); } -fn renew_timeout( +async fn renew_timeout( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, contract_input: &ContractInput, path: TestPath, generate_blocks: &F, -) { +) where + F: Fn(u64) -> Pin + Send>>, +{ { let (renew_offer, _) = first .lock() - .unwrap() + .await .renew_offer(&channel_id, test_utils::ACCEPT_COLLATERAL, contract_input) + .await .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::RenewOffer(renew_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); if let TestPath::RenewOfferTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Closed); } else { let (renew_accept, _) = second .lock() - .unwrap() + .await .accept_renew_offer(&channel_id) .expect("to be able to accept a settlement offer"); second_send .send(Some(Message::RenewAccept(renew_accept))) + .await .unwrap(); // Process Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); if let TestPath::RenewAcceptTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(second.clone()); + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, Closed); } else if let TestPath::RenewConfirmTimeout = path { // Process Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Closed); } else if let TestPath::RenewFinalizeTimeout = path { //Process confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); // Process Finalize - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(second.clone()); - generate_blocks(289); - periodic_check(second.clone()); + periodic_check(second.clone()).await; + generate_blocks(289).await; + periodic_check(second.clone()).await; assert_channel_state!(second, channel_id, Closed); } @@ -1303,69 +1381,71 @@ fn renew_timeout( } } -fn settle_timeout( +async fn settle_timeout( first: DlcParty, first_send: &Sender>, - first_receive: &Receiver<()>, + first_receive: &mut Receiver<()>, second: DlcParty, second_send: &Sender>, - second_receive: &Receiver<()>, + second_receive: &mut Receiver<()>, channel_id: ChannelId, path: TestPath, ) { let (settle_offer, _) = first .lock() - .unwrap() + .await .settle_offer(&channel_id, test_utils::ACCEPT_COLLATERAL) .expect("to be able to offer a settlement of the contract."); first_send .send(Some(Message::SettleOffer(settle_offer))) + .await .unwrap(); - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); if let TestPath::SettleOfferTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Signed, Closing); } else { let (settle_accept, _) = second .lock() - .unwrap() + .await .accept_settle_offer(&channel_id) .expect("to be able to accept a settlement offer"); second_send .send(Some(Message::SettleAccept(settle_accept))) + .await .unwrap(); // Process Accept - first_receive.recv().expect("Error synchronizing"); + first_receive.recv().await.expect("Error synchronizing"); if let TestPath::SettleAcceptTimeout = path { mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(second.clone()); + periodic_check(second.clone()).await; second .lock() - .unwrap() + .await .get_store() .get_channel(&channel_id) .unwrap(); assert_channel_state!(second, channel_id, Signed, Closing); } else if let TestPath::SettleConfirmTimeout = path { // Process Confirm - second_receive.recv().expect("Error synchronizing"); + second_receive.recv().await.expect("Error synchronizing"); mocks::mock_time::set_time( (EVENT_MATURITY as u64) + ddk_manager::manager::PEER_TIMEOUT + 2, ); - periodic_check(first.clone()); + periodic_check(first.clone()).await; assert_channel_state!(first, channel_id, Signed, Closing); } diff --git a/ddk-manager/tests/manager_execution_tests.rs b/ddk-manager/tests/manager_execution_tests.rs index 8ea4509b..41716bbe 100644 --- a/ddk-manager/tests/manager_execution_tests.rs +++ b/ddk-manager/tests/manager_execution_tests.rs @@ -26,15 +26,17 @@ use lightning::ln::wire::Type; use lightning::util::ser::Writeable; use secp256k1_zkp::rand::{thread_rng, RngCore}; use secp256k1_zkp::{ecdsa::Signature, EcdsaAdaptorSignature}; -use serde_json::{from_str, to_writer_pretty}; +use serde_json::from_str; use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; use std::sync::{ atomic::{AtomicBool, Ordering}, - mpsc::channel, - Arc, Mutex, + Arc, }; -use std::thread; - +use tokio::sync::mpsc::channel; +use tokio::sync::Mutex; +use tokio::time::sleep; #[derive(serde::Serialize, serde::Deserialize)] struct TestVectorPart { message: T, @@ -55,50 +57,64 @@ struct TestVector { sign_message: TestVectorPart, } -fn write_message(msg_name: &str, s: T) { +fn write_message(_msg_name: &str, s: T) { if std::env::var("GENERATE_TEST_VECTOR").is_ok() { let mut buf = Vec::new(); s.type_id().write(&mut buf).unwrap(); s.write(&mut buf).unwrap(); - let t = TestVectorPart { + let _t = TestVectorPart { message: s, serialized: buf, }; - to_writer_pretty( - &std::fs::File::create(format!("{}.json", msg_name)).unwrap(), - &t, - ) - .unwrap(); + // to_writer_pretty( + // &std::fs::File::create(format!("{}.json", msg_name)).unwrap(), + // &t, + // ) + // .unwrap(); } } -fn create_test_vector() { +async fn create_test_vector() { if std::env::var("GENERATE_TEST_VECTOR").is_ok() { - let test_vector = TestVector { - offer_message: from_str(&std::fs::read_to_string("offer_message.json").unwrap()) - .unwrap(), - accept_message: from_str(&std::fs::read_to_string("accept_message.json").unwrap()) - .unwrap(), - sign_message: from_str(&std::fs::read_to_string("sign_message.json").unwrap()).unwrap(), + let _test_vector = TestVector { + offer_message: from_str( + &tokio::fs::read_to_string("offer_message.json") + .await + .unwrap(), + ) + .unwrap(), + accept_message: from_str( + &tokio::fs::read_to_string("accept_message.json") + .await + .unwrap(), + ) + .unwrap(), + sign_message: from_str( + &tokio::fs::read_to_string("sign_message.json") + .await + .unwrap(), + ) + .unwrap(), }; - let file_name = std::env::var("TEST_VECTOR_OUTPUT_NAME") + let _file_name = std::env::var("TEST_VECTOR_OUTPUT_NAME") .unwrap_or_else(|_| "test_vector.json".to_string()); - to_writer_pretty(std::fs::File::create(file_name).unwrap(), &test_vector).unwrap(); + // to_writer_pretty(std::fs::File::create(file_name).unwrap(), &test_vector).unwrap(); } } macro_rules! periodic_check { ($d:expr, $id:expr, $p:ident) => { $d.lock() - .unwrap() + .await .periodic_check(true) + .await .expect("Periodic check error"); assert_contract_state!($d, $id, $p); }; } -fn numerical_common( +async fn numerical_common( nb_oracles: usize, threshold: usize, payout_function_pieces_cb: F, @@ -124,10 +140,11 @@ fn numerical_common( ), TestPath::Close, manual_close, - ); + ) + .await; } -fn numerical_polynomial_common( +async fn numerical_polynomial_common( nb_oracles: usize, threshold: usize, difference_params: Option, @@ -139,10 +156,11 @@ fn numerical_polynomial_common( get_polynomial_payout_curve_pieces, difference_params, manual_close, - ); + ) + .await; } -fn numerical_common_diff_nb_digits( +async fn numerical_common_diff_nb_digits( nb_oracles: usize, threshold: usize, difference_params: Option, @@ -171,7 +189,8 @@ fn numerical_common_diff_nb_digits( ), TestPath::Close, manual_close, - ); + ) + .await; } #[derive(Eq, PartialEq, Clone)] @@ -184,330 +203,344 @@ enum TestPath { BadSignRefundSignature, } -#[test] +#[tokio::test] #[ignore] -fn single_oracle_numerical_test() { - numerical_polynomial_common(1, 1, None, false); +async fn single_oracle_numerical_test() { + numerical_polynomial_common(1, 1, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn single_oracle_numerical_manual_test() { - numerical_polynomial_common(1, 1, None, true); +async fn single_oracle_numerical_manual_test() { + numerical_polynomial_common(1, 1, None, true).await; } -#[test] +#[tokio::test] #[ignore] -fn single_oracle_numerical_hyperbola_test() { - numerical_common(1, 1, get_hyperbola_payout_curve_pieces, None, false); +async fn single_oracle_numerical_hyperbola_test() { + numerical_common(1, 1, get_hyperbola_payout_curve_pieces, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_three_oracle_numerical_test() { - numerical_polynomial_common(3, 3, None, false); +async fn three_of_three_oracle_numerical_test() { + numerical_polynomial_common(3, 3, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_test() { - numerical_polynomial_common(5, 2, None, false); +async fn two_of_five_oracle_numerical_test() { + numerical_polynomial_common(5, 2, None, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_manual_test() { - numerical_polynomial_common(5, 2, None, true); +async fn two_of_five_oracle_numerical_manual_test() { + numerical_polynomial_common(5, 2, None, true).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_three_oracle_numerical_with_diff_test() { - numerical_polynomial_common(3, 3, Some(get_difference_params()), false); +async fn three_of_three_oracle_numerical_with_diff_test() { + numerical_polynomial_common(3, 3, Some(get_difference_params()), false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_test() { - numerical_polynomial_common(5, 2, Some(get_difference_params()), false); +async fn two_of_five_oracle_numerical_with_diff_test() { + numerical_polynomial_common(5, 2, Some(get_difference_params()), false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_five_oracle_numerical_with_diff_test() { - numerical_polynomial_common(5, 3, Some(get_difference_params()), false); +async fn three_of_five_oracle_numerical_with_diff_test() { + numerical_polynomial_common(5, 3, Some(get_difference_params()), false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_five_oracle_numerical_with_diff_manual_test() { - numerical_polynomial_common(5, 3, Some(get_difference_params()), true); +async fn three_of_five_oracle_numerical_with_diff_manual_test() { + numerical_polynomial_common(5, 3, Some(get_difference_params()), true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_test() { - manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, false); +async fn enum_single_oracle_test() { + manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, false).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_manual_test() { - manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, true); +async fn enum_single_oracle_manual_test() { + manager_execution_test(get_enum_test_params(1, 1, None), TestPath::Close, true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_3_test() { - manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, false); +async fn enum_3_of_3_test() { + manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, false).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_3_manual_test() { - manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, true); +async fn enum_3_of_3_manual_test() { + manager_execution_test(get_enum_test_params(3, 3, None), TestPath::Close, true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_5_test() { - manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, false); +async fn enum_3_of_5_test() { + manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, false).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_3_of_5_manual_test() { - manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, true); +async fn enum_3_of_5_manual_test() { + manager_execution_test(get_enum_test_params(5, 3, None), TestPath::Close, true).await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_3_of_5_test() { +async fn enum_and_numerical_with_diff_3_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, true, Some(get_difference_params())), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_3_of_5_manual_test() { +async fn enum_and_numerical_with_diff_3_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, true, Some(get_difference_params())), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_5_of_5_test() { +async fn enum_and_numerical_with_diff_5_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, true, Some(get_difference_params())), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_with_diff_5_of_5_manual_test() { +async fn enum_and_numerical_with_diff_5_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, true, Some(get_difference_params())), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_3_of_5_test() { +async fn enum_and_numerical_3_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, false, None), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_3_of_5_manual_test() { +async fn enum_and_numerical_3_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 3, false, None), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_5_of_5_test() { +async fn enum_and_numerical_5_of_5_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, false, None), TestPath::Close, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_and_numerical_5_of_5_manual_test() { +async fn enum_and_numerical_5_of_5_manual_test() { manager_execution_test( get_enum_and_numerical_test_params(5, 5, false, None), TestPath::Close, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_refund_test() { +async fn enum_single_oracle_refund_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::Refund, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_refund_manual_test() { +async fn enum_single_oracle_refund_manual_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::Refund, true, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_accept_cet_sig_test() { +async fn enum_single_oracle_bad_accept_cet_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadAcceptCetSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_accept_refund_sig_test() { +async fn enum_single_oracle_bad_accept_refund_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadAcceptRefundSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_sign_cet_sig_test() { +async fn enum_single_oracle_bad_sign_cet_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadSignCetSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn enum_single_oracle_bad_sign_refund_sig_test() { +async fn enum_single_oracle_bad_sign_refund_sig_test() { manager_execution_test( get_enum_test_params(1, 1, Some(get_enum_oracles(1, 0))), TestPath::BadSignRefundSignature, false, - ); + ) + .await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_diff_nb_digits_test() { - numerical_common_diff_nb_digits(2, 2, None, false, false); +async fn two_of_two_oracle_numerical_diff_nb_digits_test() { + numerical_common_diff_nb_digits(2, 2, None, false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_diff_nb_digits_manual_test() { - numerical_common_diff_nb_digits(2, 2, None, false, true); +async fn two_of_two_oracle_numerical_diff_nb_digits_manual_test() { + numerical_common_diff_nb_digits(2, 2, None, false, true).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_test() { - numerical_common_diff_nb_digits(5, 2, None, false, false); +async fn two_of_five_oracle_numerical_diff_nb_digits_test() { + numerical_common_diff_nb_digits(5, 2, None, false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_manual_test() { - numerical_common_diff_nb_digits(5, 2, None, false, true); +async fn two_of_five_oracle_numerical_diff_nb_digits_manual_test() { + numerical_common_diff_nb_digits(5, 2, None, false, true).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_test() { - numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), false, false); +async fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_test() { + numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn three_of_three_oracle_numerical_with_diff_diff_nb_digits_test() { - numerical_common_diff_nb_digits(3, 3, Some(get_difference_params()), false, false); +async fn three_of_three_oracle_numerical_with_diff_diff_nb_digits_test() { + numerical_common_diff_nb_digits(3, 3, Some(get_difference_params()), false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_test() { - numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), false, false); +async fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_test() { + numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), false, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), true, false); +async fn two_of_two_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(2, 2, Some(get_difference_params()), true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_three_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(3, 2, Some(get_difference_params()), true, false); +async fn two_of_three_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(3, 2, Some(get_difference_params()), true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, false); +async fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_manual_test() { - numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, true); +async fn two_of_five_oracle_numerical_with_diff_diff_nb_digits_max_value_manual_test() { + numerical_common_diff_nb_digits(5, 2, Some(get_difference_params()), true, true).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_two_oracle_numerical_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(2, 2, None, true, false); +async fn two_of_two_oracle_numerical_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(2, 2, None, true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_three_oracle_numerical_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(3, 2, None, true, false); +async fn two_of_three_oracle_numerical_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(3, 2, None, true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_max_value_test() { - numerical_common_diff_nb_digits(5, 2, None, true, false); +async fn two_of_five_oracle_numerical_diff_nb_digits_max_value_test() { + numerical_common_diff_nb_digits(5, 2, None, true, false).await; } -#[test] +#[tokio::test] #[ignore] -fn two_of_five_oracle_numerical_diff_nb_digits_max_value_manual_test() { - numerical_common_diff_nb_digits(5, 2, None, true, true); +async fn two_of_five_oracle_numerical_diff_nb_digits_max_value_manual_test() { + numerical_common_diff_nb_digits(5, 2, None, true, true).await; } fn alter_adaptor_sig(input: &mut CetAdaptorSignatures) { @@ -530,24 +563,21 @@ fn alter_refund_sig(refund_signature: &Signature) -> Signature { Signature::from_compact(©).unwrap() } -fn get_attestations(test_params: &TestParams) -> Vec<(usize, OracleAttestation)> { +async fn get_attestations(test_params: &TestParams) -> Vec<(usize, OracleAttestation)> { + let mut attestations = Vec::new(); for contract_info in test_params.contract_input.contract_infos.iter() { - let attestations: Vec<_> = contract_info - .oracles - .public_keys - .iter() - .enumerate() - .filter_map(|(i, pk)| { - let oracle = test_params - .oracles - .iter() - .find(|x| x.get_public_key() == *pk); - - oracle - .and_then(|o| o.get_attestation(&contract_info.oracles.event_id).ok()) - .map(|a| (i, a)) - }) - .collect(); + attestations.clear(); + for (i, pk) in contract_info.oracles.public_keys.iter().enumerate() { + let oracle = test_params + .oracles + .iter() + .find(|x| x.get_public_key() == *pk); + if let Some(o) = oracle { + if let Ok(attestation) = o.get_attestation(&contract_info.oracles.event_id).await { + attestations.push((i, attestation)); + } + } + } if attestations.len() >= contract_info.oracles.threshold as usize { return attestations; } @@ -556,14 +586,15 @@ fn get_attestations(test_params: &TestParams) -> Vec<(usize, OracleAttestation)> panic!("No attestations found"); } -fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: bool) { +async fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: bool) { env_logger::try_init().ok(); - let (alice_send, bob_receive) = channel::>(); - let (bob_send, alice_receive) = channel::>(); - let (sync_send, sync_receive) = channel::<()>(); + let (alice_send, mut bob_receive) = channel::>(100); + let (bob_send, mut alice_receive) = channel::>(100); + let (sync_send, mut sync_receive) = channel::<()>(100); let alice_sync_send = sync_send.clone(); let bob_sync_send = sync_send; let (_, _, sink_rpc) = init_clients(); + let sink = Arc::new(sink_rpc); let mut alice_oracles = HashMap::with_capacity(1); let mut bob_oracles = HashMap::with_capacity(1); @@ -599,55 +630,55 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: let alice_fund_address = alice_wallet.get_new_address().unwrap(); let bob_fund_address = bob_wallet.get_new_address().unwrap(); - sink_rpc - .send_to_address( - &alice_fund_address, - Amount::from_btc(2.0).unwrap(), - None, - None, - None, - None, - None, - None, - ) - .unwrap(); - - sink_rpc - .send_to_address( - &bob_fund_address, - Amount::from_btc(2.0).unwrap(), - None, - None, - None, - None, - None, - None, - ) - .unwrap(); - - let generate_blocks = |nb_blocks: u64| { - let prev_blockchain_height = electrs.get_blockchain_height().unwrap(); - - let sink_address = sink_rpc - .get_new_address(None, None) - .expect("RPC Error") - .assume_checked(); - sink_rpc - .generate_to_address(nb_blocks, &sink_address) - .expect("RPC Error"); - - // Wait for electrs to have processed the new blocks - let mut cur_blockchain_height = prev_blockchain_height; - while cur_blockchain_height < prev_blockchain_height + nb_blocks { - std::thread::sleep(std::time::Duration::from_millis(200)); - cur_blockchain_height = electrs.get_blockchain_height().unwrap(); - } + sink.send_to_address( + &alice_fund_address, + Amount::from_btc(2.0).unwrap(), + None, + None, + None, + None, + None, + None, + ) + .unwrap(); + + sink.send_to_address( + &bob_fund_address, + Amount::from_btc(2.0).unwrap(), + None, + None, + None, + None, + None, + None, + ) + .unwrap(); + + let generate_blocks = |nb_blocks: u64| -> Pin + Send>> { + let electrs_clone = electrs.clone(); + let sink_clone = sink.clone(); + Box::pin(async move { + let prev_blockchain_height = electrs_clone.get_blockchain_height().await.unwrap(); + let sink_address = sink_clone + .get_new_address(None, None) + .expect("RPC Error") + .assume_checked(); + sink_clone + .generate_to_address(nb_blocks, &sink_address) + .expect("RPC Error"); + // Wait for electrs to have processed the new blocks + let mut cur_blockchain_height = prev_blockchain_height; + while cur_blockchain_height < prev_blockchain_height + nb_blocks { + sleep(std::time::Duration::from_millis(200)).await; + cur_blockchain_height = electrs_clone.get_blockchain_height().await.unwrap(); + } + }) }; - generate_blocks(6); + generate_blocks(6).await; - refresh_wallet(&alice_wallet, 200000000); - refresh_wallet(&bob_wallet, 200000000); + refresh_wallet(&alice_wallet, 200000000).await; + refresh_wallet(&bob_wallet, 200000000).await; let alice_manager = Arc::new(Mutex::new( Manager::new( @@ -659,6 +690,7 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: Arc::clone(&mock_time), Arc::clone(&electrs), ) + .await .unwrap(), )); @@ -675,6 +707,7 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: Arc::clone(&mock_time), Arc::clone(&electrs), ) + .await .unwrap(), )); @@ -734,29 +767,34 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: let offer_msg = bob_manager_send .lock() - .unwrap() + .await .send_offer( &test_params.contract_input, "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" .parse() .unwrap(), ) + .await .expect("Send offer error"); write_message("offer_message", offer_msg.clone()); let temporary_contract_id = offer_msg.temporary_contract_id; - bob_send.send(Some(Message::Offer(offer_msg))).unwrap(); + bob_send + .send(Some(Message::Offer(offer_msg))) + .await + .unwrap(); assert_contract_state!(bob_manager_send, temporary_contract_id, Offered); - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(alice_manager_send, temporary_contract_id, Offered); let (contract_id, _, mut accept_msg) = alice_manager_send .lock() - .unwrap() + .await .accept_contract_offer(&temporary_contract_id) + .await .expect("Error accepting contract offer"); write_message("accept_message", accept_msg.clone()); @@ -775,33 +813,42 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: _ => {} }; bob_expect_error.store(true, Ordering::Relaxed); - alice_send.send(Some(Message::Accept(accept_msg))).unwrap(); - sync_receive.recv().expect("Error synchronizing"); + alice_send + .send(Some(Message::Accept(accept_msg))) + .await + .unwrap(); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(bob_manager_send, temporary_contract_id, FailedAccept); } TestPath::BadSignCetSignature | TestPath::BadSignRefundSignature => { alice_expect_error.store(true, Ordering::Relaxed); - alice_send.send(Some(Message::Accept(accept_msg))).unwrap(); + alice_send + .send(Some(Message::Accept(accept_msg))) + .await + .unwrap(); // Bob receives accept message - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); // Alice receives sign message - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(alice_manager_send, contract_id, FailedSign); } TestPath::Close | TestPath::Refund => { - alice_send.send(Some(Message::Accept(accept_msg))).unwrap(); - sync_receive.recv().expect("Error synchronizing"); + alice_send + .send(Some(Message::Accept(accept_msg))) + .await + .unwrap(); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(bob_manager_send, contract_id, Signed); // Should not change state and should not error periodic_check!(bob_manager_send, contract_id, Signed); - sync_receive.recv().expect("Error synchronizing"); + sync_receive.recv().await.expect("Error synchronizing"); assert_contract_state!(alice_manager_send, contract_id, Signed); - generate_blocks(6); + generate_blocks(6).await; periodic_check!(alice_manager_send, contract_id, Confirmed); periodic_check!(bob_manager_send, contract_id, Confirmed); @@ -831,15 +878,16 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: if manual_close { periodic_check!(first, contract_id, Confirmed); - let attestations = get_attestations(&test_params); + let attestations = get_attestations(&test_params).await; - let f = first.lock().unwrap(); + let f = first.lock().await; let contract = f .close_confirmed_contract(&contract_id, attestations) + .await .expect("Error closing contract"); if let Contract::PreClosed(contract) = contract { - let mut s = second.lock().unwrap(); + let mut s = second.lock().await; let second_contract = s.get_store().get_contract(&contract_id).unwrap().unwrap(); if let Contract::Confirmed(signed) = second_contract { @@ -861,7 +909,7 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: // mine blocks for the CET to be confirmed if let Some(b) = blocks { - generate_blocks(b as u64); + generate_blocks(b as u64).await; } // Randomly check with or without having the CET mined @@ -883,13 +931,13 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: ((EVENT_MATURITY + ddk_manager::manager::REFUND_DELAY) as u64) + 1, ); - generate_blocks(10); + generate_blocks(10).await; periodic_check!(first, contract_id, Refunded); // Randomly check with or without having the Refund mined. if thread_rng().next_u32() % 2 == 0 { - generate_blocks(1); + generate_blocks(1).await; } periodic_check!(second, contract_id, Refunded); @@ -899,11 +947,11 @@ fn manager_execution_test(test_params: TestParams, path: TestPath, manual_close: } } - alice_send.send(None).unwrap(); - bob_send.send(None).unwrap(); + alice_send.send(None).await.unwrap(); + bob_send.send(None).await.unwrap(); - alice_handle.join().unwrap(); - bob_handle.join().unwrap(); + alice_handle.await.unwrap(); + bob_handle.await.unwrap(); - create_test_vector(); + create_test_vector().await; } diff --git a/ddk-manager/tests/test_utils.rs b/ddk-manager/tests/test_utils.rs index 9f87b6b3..fc1418b0 100644 --- a/ddk-manager/tests/test_utils.rs +++ b/ddk-manager/tests/test_utils.rs @@ -46,39 +46,46 @@ pub const ROUNDING_MOD: u64 = 1; #[macro_export] macro_rules! receive_loop { ($receive:expr, $manager:expr, $send:expr, $expect_err:expr, $sync_send:expr, $rcv_callback: expr, $msg_callback: expr) => { - thread::spawn(move || loop { - match $receive.recv() { - Ok(Some(msg)) => match $manager.lock().unwrap().on_dlc_message( - &msg, - "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" - .parse() - .unwrap(), - ) { - Ok(opt) => { - if $expect_err.load(Ordering::Relaxed) != false { - panic!("Expected error not raised"); - } - match opt { - Some(msg) => { - let msg_opt = $rcv_callback(msg); - if let Some(msg) = msg_opt { - #[allow(clippy::redundant_closure_call)] - $msg_callback(&msg); - (&$send).send(Some(msg)).expect("Error sending"); + tokio::spawn(async move { + loop { + match $receive.recv().await { + Some(Some(msg)) => match $manager + .lock() + .await + .on_dlc_message( + &msg, + "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" + .parse() + .unwrap(), + ) + .await + { + Ok(opt) => { + if $expect_err.load(Ordering::Relaxed) != false { + panic!("Expected error not raised"); + } + match opt { + Some(msg) => { + let msg_opt = $rcv_callback(msg); + if let Some(msg) = msg_opt { + #[allow(clippy::redundant_closure_call)] + $msg_callback(&msg); + (&$send).send(Some(msg)).await.expect("Error sending"); + } } + None => {} } - None => {} } - } - Err(e) => { - if $expect_err.load(Ordering::Relaxed) != true { - panic!("Unexpected error {}", e); + Err(e) => { + if $expect_err.load(Ordering::Relaxed) != true { + panic!("Unexpected error {}", e); + } } - } - }, - Ok(None) | Err(_) => return, - }; - $sync_send.send(()).expect("Error syncing"); + }, + None | Some(None) => return, + }; + $sync_send.send(()).await.expect("Error syncing"); + } }) }; } @@ -104,7 +111,7 @@ macro_rules! assert_contract_state { ($d:expr, $id:expr, $p:ident) => { let res = $d .lock() - .unwrap() + .await .get_store() .get_contract(&$id) .expect("Could not retrieve contract"); @@ -146,7 +153,7 @@ macro_rules! write_channel { #[macro_export] macro_rules! assert_channel_state { ($d:expr, $id:expr, $p:ident $(, $s: ident)?) => {{ - assert_channel_state_unlocked!($d.lock().unwrap(), $id, $p $(, $s)?) + assert_channel_state_unlocked!($d.lock().await, $id, $p $(, $s)?) }}; } @@ -568,7 +575,7 @@ pub fn get_variable_oracle_numeric_infos(nb_digits: &[usize]) -> OracleNumericIn } } -pub fn refresh_wallet( +pub async fn refresh_wallet( wallet: &simple_wallet::SimpleWallet, expected_funds: u64, ) where @@ -581,7 +588,7 @@ pub fn refresh_wallet( panic!("Wallet refresh taking too long.") } std::thread::sleep(std::time::Duration::from_millis(200)); - wallet.refresh().unwrap(); + wallet.refresh().await.unwrap(); retry += 1; } } diff --git a/electrs-blockchain-provider/Cargo.toml b/electrs-blockchain-provider/Cargo.toml index af3535c5..5768d25e 100644 --- a/electrs-blockchain-provider/Cargo.toml +++ b/electrs-blockchain-provider/Cargo.toml @@ -6,12 +6,13 @@ version = "0.1.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = "0.1.83" bitcoin = {version = "0.32.2"} bitcoin-test-utils = {path = "../bitcoin-test-utils"} ddk-manager = {path = "../ddk-manager"} lightning = {version = "0.0.125"} lightning-block-sync = {version = "0.0.125"} -reqwest = {version = "0.11", features = ["blocking", "json"]} +reqwest = {version = "0.11", features = ["json"]} serde = {version = "*", features = ["derive"]} simple-wallet = {path = "../simple-wallet"} tokio = "1" diff --git a/electrs-blockchain-provider/src/lib.rs b/electrs-blockchain-provider/src/lib.rs index 9d9d631c..6b22020e 100644 --- a/electrs-blockchain-provider/src/lib.rs +++ b/electrs-blockchain-provider/src/lib.rs @@ -13,7 +13,7 @@ use bitcoin_test_utils::tx_to_string; use ddk_manager::{error::Error, Blockchain, Utxo}; use lightning::chain::chaininterface::{BroadcasterInterface, ConfirmationTarget, FeeEstimator}; use lightning_block_sync::{BlockData, BlockHeaderData, BlockSource, BlockSourceError}; -use reqwest::blocking::Response; +use reqwest::Response; use serde::Deserialize; use serde::Serialize; @@ -29,8 +29,7 @@ pub enum Target { pub struct ElectrsBlockchainProvider { host: String, - client: reqwest::blocking::Client, - async_client: reqwest::Client, + client: reqwest::Client, network: Network, fees: Arc>, } @@ -46,16 +45,16 @@ impl ElectrsBlockchainProvider { Self { host, network, - client: reqwest::blocking::Client::new(), - async_client: reqwest::Client::new(), + client: reqwest::Client::new(), fees, } } - fn get(&self, sub_url: &str) -> Result { + async fn get(&self, sub_url: &str) -> Result { self.client .get(format!("{}{}", self.host, sub_url)) .send() + .await .map_err(|x| { ddk_manager::error::Error::IOError(lightning::io::Error::new( lightning::io::ErrorKind::Other, @@ -65,14 +64,14 @@ impl ElectrsBlockchainProvider { } async fn get_async(&self, sub_url: &str) -> Result { - self.async_client + self.client .get(format!("{}{}", self.host, sub_url)) .send() .await } - fn get_text(&self, sub_url: &str) -> Result { - self.get(sub_url)?.text().map_err(|x| { + async fn get_text(&self, sub_url: &str) -> Result { + self.get(sub_url).await?.text().await.map_err(|x| { ddk_manager::error::Error::IOError(lightning::io::Error::new( lightning::io::ErrorKind::Other, x, @@ -80,41 +79,49 @@ impl ElectrsBlockchainProvider { }) } - fn get_u64(&self, sub_url: &str) -> Result { - self.get_text(sub_url)? + async fn get_u64(&self, sub_url: &str) -> Result { + self.get_text(sub_url) + .await? .parse() .map_err(|e: std::num::ParseIntError| Error::BlockchainError(e.to_string())) } - fn get_bytes(&self, sub_url: &str) -> Result, Error> { - let bytes = self.get(sub_url)?.bytes(); + async fn get_bytes(&self, sub_url: &str) -> Result, Error> { + let bytes = self.get(sub_url).await?.bytes().await; Ok(bytes .map_err(|e| Error::BlockchainError(e.to_string()))? .into_iter() .collect::>()) } - fn get_from_json(&self, sub_url: &str) -> Result + async fn get_from_json(&self, sub_url: &str) -> Result where T: serde::de::DeserializeOwned, { - self.get(sub_url)? + self.get(sub_url) + .await? .json::() + .await .map_err(|e| Error::BlockchainError(e.to_string())) } - pub fn get_outspends(&self, txid: &Txid) -> Result, Error> { - self.get_from_json(&format!("tx/{txid}/outspends")) + pub async fn get_outspends(&self, txid: &Txid) -> Result, Error> { + self.get_from_json(&format!("tx/{txid}/outspends")).await } } +#[async_trait::async_trait] impl Blockchain for ElectrsBlockchainProvider { - fn send_transaction(&self, transaction: &Transaction) -> Result<(), ddk_manager::error::Error> { + async fn send_transaction( + &self, + transaction: &Transaction, + ) -> Result<(), ddk_manager::error::Error> { let res = self .client .post(format!("{}tx", self.host)) .body(tx_to_string(transaction)) .send() + .await .map_err(|x| { ddk_manager::error::Error::IOError(lightning::io::Error::new( lightning::io::ErrorKind::Other, @@ -122,7 +129,7 @@ impl Blockchain for ElectrsBlockchainProvider { )) })?; if let Err(error) = res.error_for_status_ref() { - let body = res.text().unwrap_or_default(); + let body = res.text().await.unwrap_or_default(); return Err(ddk_manager::error::Error::InvalidParameters(format!( "Server returned error: {error} {body}" ))); @@ -134,31 +141,38 @@ impl Blockchain for ElectrsBlockchainProvider { Ok(self.network) } - fn get_blockchain_height(&self) -> Result { - self.get_u64("blocks/tip/height") + async fn get_blockchain_height(&self) -> Result { + self.get_u64("blocks/tip/height").await } - fn get_block_at_height(&self, height: u64) -> Result { - let hash_at_height = self.get_text(&format!("block-height/{height}"))?; - let raw_block = self.get_bytes(&format!("block/{hash_at_height}/raw"))?; + async fn get_block_at_height(&self, height: u64) -> Result { + let hash_at_height = self.get_text(&format!("block-height/{height}")).await?; + let raw_block = self + .get_bytes(&format!("block/{hash_at_height}/raw")) + .await?; // TODO: Bitcoin IO for all Block::consensus_decode(&mut bitcoin::io::Cursor::new(&*raw_block)) .map_err(|e| Error::BlockchainError(e.to_string())) } - fn get_transaction(&self, tx_id: &Txid) -> Result { - let raw_tx = self.get_bytes(&format!("tx/{tx_id}/raw"))?; + async fn get_transaction( + &self, + tx_id: &Txid, + ) -> Result { + let raw_tx = self.get_bytes(&format!("tx/{tx_id}/raw")).await?; Transaction::consensus_decode(&mut lightning::io::Cursor::new(&*raw_tx)) .map_err(|e| Error::BlockchainError(e.to_string())) } - fn get_transaction_confirmations( + async fn get_transaction_confirmations( &self, tx_id: &Txid, ) -> Result { - let tx_status = self.get_from_json::(&format!("tx/{tx_id}/status"))?; + let tx_status = self + .get_from_json::(&format!("tx/{tx_id}/status")) + .await?; if tx_status.confirmed { - let block_chain_height = self.get_blockchain_height()?; + let block_chain_height = self.get_blockchain_height().await?; if let Some(block_height) = tx_status.block_height { return Ok((block_chain_height - block_height + 1) as u32); } @@ -168,9 +182,12 @@ impl Blockchain for ElectrsBlockchainProvider { } } +#[async_trait::async_trait] impl simple_wallet::WalletBlockchainProvider for ElectrsBlockchainProvider { - fn get_utxos_for_address(&self, address: &bitcoin::Address) -> Result, Error> { - let utxos: Vec = self.get_from_json(&format!("address/{address}/utxo"))?; + async fn get_utxos_for_address(&self, address: &bitcoin::Address) -> Result, Error> { + let utxos: Vec = self + .get_from_json(&format!("address/{address}/utxo")) + .await?; utxos .into_iter() @@ -197,8 +214,10 @@ impl simple_wallet::WalletBlockchainProvider for ElectrsBlockchainProvider { .collect::, Error>>() } - fn is_output_spent(&self, txid: &Txid, vout: u32) -> Result { - let is_spent: SpentResp = self.get_from_json(&format!("tx/{txid}/outspend/{vout}"))?; + async fn is_output_spent(&self, txid: &Txid, vout: u32) -> Result { + let is_spent: SpentResp = self + .get_from_json(&format!("tx/{txid}/outspend/{vout}")) + .await?; Ok(is_spent.spent) } } @@ -327,25 +346,25 @@ impl BlockSource for ElectrsBlockchainProvider { impl BroadcasterInterface for ElectrsBlockchainProvider { fn broadcast_transactions(&self, txs: &[&Transaction]) { - let client = self.client.clone(); - let host = self.host.clone(); - let bodies = txs + let _client = self.client.clone(); + let _host = self.host.clone(); + let _bodies = txs .iter() .map(|tx| bitcoin_test_utils::tx_to_string(tx)) .collect::>(); - std::thread::spawn(move || { - for body in bodies { - match client.post(format!("{host}tx")).body(body).send() { - Err(_) => {} - Ok(res) => { - if res.error_for_status_ref().is_err() { - // let body = res.text().unwrap_or_default(); - // TODO(tibo): log - } - } - }; - } - }); + // std::thread::spawn(move || { + // for body in bodies { + // match client.post(format!("{host}tx")).body(body).send() { + // Err(_) => {} + // Ok(res) => { + // if res.error_for_status_ref().is_err() { + // // let body = res.text().unwrap_or_default(); + // // TODO(tibo): log + // } + // } + // }; + // } + // }); } } @@ -399,16 +418,16 @@ fn store_estimate_for_target( fn poll_for_fee_estimates(fees: Arc>, host: &str) { let host = host.to_owned(); - std::thread::spawn(move || loop { - if let Ok(res) = reqwest::blocking::get(format!("{host}fee-estimates")) { - if let Ok(fee_estimates) = res.json::() { + tokio::spawn(async move { + if let Ok(res) = reqwest::get(format!("{host}fee-estimates")).await { + if let Ok(fee_estimates) = res.json::().await { store_estimate_for_target(&fees, &fee_estimates, Target::Background); store_estimate_for_target(&fees, &fee_estimates, Target::HighPriority); store_estimate_for_target(&fees, &fee_estimates, Target::Normal); } } - std::thread::sleep(Duration::from_secs(60)); + tokio::time::sleep(Duration::from_secs(60)).await; }); } diff --git a/mocks/Cargo.toml b/mocks/Cargo.toml index 1a8e9f7d..28f40eeb 100644 --- a/mocks/Cargo.toml +++ b/mocks/Cargo.toml @@ -1,10 +1,11 @@ [package] authors = ["Crypto Garage"] -edition = "2018" +edition = "2021" name = "mocks" version = "0.1.0" [dependencies] +async-trait = "0.1.83" bitcoin = "0.32.2" ddk-dlc = {path = "../ddk-dlc"} ddk-manager = {path = "../ddk-manager"} diff --git a/mocks/src/mock_blockchain.rs b/mocks/src/mock_blockchain.rs index 46d42d3e..01fac213 100644 --- a/mocks/src/mock_blockchain.rs +++ b/mocks/src/mock_blockchain.rs @@ -23,21 +23,22 @@ impl Default for MockBlockchain { } } +#[async_trait::async_trait] impl Blockchain for MockBlockchain { - fn send_transaction(&self, transaction: &Transaction) -> Result<(), Error> { + async fn send_transaction(&self, transaction: &Transaction) -> Result<(), Error> { self.transactions.lock().unwrap().push(transaction.clone()); Ok(()) } fn get_network(&self) -> Result { Ok(bitcoin::Network::Regtest) } - fn get_blockchain_height(&self) -> Result { + async fn get_blockchain_height(&self) -> Result { Ok(10) } - fn get_block_at_height(&self, _height: u64) -> Result { + async fn get_block_at_height(&self, _height: u64) -> Result { unimplemented!(); } - fn get_transaction(&self, tx_id: &Txid) -> Result { + async fn get_transaction(&self, tx_id: &Txid) -> Result { Ok(self .transactions .lock() @@ -47,17 +48,18 @@ impl Blockchain for MockBlockchain { .unwrap() .clone()) } - fn get_transaction_confirmations(&self, _tx_id: &Txid) -> Result { + async fn get_transaction_confirmations(&self, _tx_id: &Txid) -> Result { Ok(6) } } +#[async_trait::async_trait] impl WalletBlockchainProvider for MockBlockchain { - fn get_utxos_for_address(&self, _address: &bitcoin::Address) -> Result, Error> { + async fn get_utxos_for_address(&self, _address: &bitcoin::Address) -> Result, Error> { unimplemented!() } - fn is_output_spent(&self, _txid: &Txid, _vout: u32) -> Result { + async fn is_output_spent(&self, _txid: &Txid, _vout: u32) -> Result { unimplemented!() } } diff --git a/mocks/src/mock_oracle_provider.rs b/mocks/src/mock_oracle_provider.rs index 678430c3..80dacafd 100644 --- a/mocks/src/mock_oracle_provider.rs +++ b/mocks/src/mock_oracle_provider.rs @@ -55,12 +55,13 @@ impl Default for MockOracle { } } +#[async_trait::async_trait] impl Oracle for MockOracle { fn get_public_key(&self) -> XOnlyPublicKey { XOnlyPublicKey::from_keypair(&self.key_pair).0 } - fn get_announcement(&self, event_id: &str) -> Result { + async fn get_announcement(&self, event_id: &str) -> Result { let res = self .announcements .get(event_id) @@ -68,7 +69,7 @@ impl Oracle for MockOracle { Ok(res.clone()) } - fn get_attestation(&self, event_id: &str) -> Result { + async fn get_attestation(&self, event_id: &str) -> Result { let res = self .attestations .get(event_id) diff --git a/mocks/src/mock_wallet.rs b/mocks/src/mock_wallet.rs index 00671ce0..a54084d8 100644 --- a/mocks/src/mock_wallet.rs +++ b/mocks/src/mock_wallet.rs @@ -14,7 +14,7 @@ pub struct MockWallet { } impl MockWallet { - pub fn new(blockchain: &Rc, utxo_values: &[u64]) -> Self { + pub async fn new(blockchain: &Rc, utxo_values: &[u64]) -> Self { let mut utxos = Vec::with_capacity(utxo_values.len()); for utxo_value in utxo_values { @@ -28,7 +28,7 @@ impl MockWallet { input: vec![], output: vec![tx_out.clone()], }; - blockchain.send_transaction(&tx).unwrap(); + blockchain.send_transaction(&tx).await.unwrap(); let utxo = Utxo { tx_out, outpoint: bitcoin::OutPoint { diff --git a/p2pd-oracle-client/Cargo.toml b/p2pd-oracle-client/Cargo.toml index bfa531e7..db811d52 100644 --- a/p2pd-oracle-client/Cargo.toml +++ b/p2pd-oracle-client/Cargo.toml @@ -6,14 +6,17 @@ license-file = "../LICENSE" name = "p2pd-oracle-client" repository = "https://github.com/p2pderivatives/rust-dlc/tree/master/p2pd-oracle-client" version = "0.1.0" +edition = "2021" [dependencies] +async-trait = "0.1.83" chrono = {version = "0.4.19", features = ["serde"]} ddk-manager = {path = "../ddk-manager"} ddk-messages = {path = "../ddk-messages", features = ["use-serde"]} -reqwest = {version = "0.11", features = ["blocking", "json"]} +reqwest = {version = "0.11", features = ["json"]} secp256k1-zkp = {version = "0.11.0" } serde = {version = "*", features = ["derive"]} [dev-dependencies] mockito = "0.31.0" +tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread", "test-util"] } diff --git a/p2pd-oracle-client/src/lib.rs b/p2pd-oracle-client/src/lib.rs index 7b82fb9c..b2e6f3b3 100644 --- a/p2pd-oracle-client/src/lib.rs +++ b/p2pd-oracle-client/src/lib.rs @@ -69,17 +69,19 @@ struct AttestationResponse { values: Vec, } -fn get(path: &str) -> Result +async fn get(path: &str) -> Result where T: serde::de::DeserializeOwned, { - reqwest::blocking::get(path) + reqwest::get(path) + .await .map_err(|x| { ddk_manager::error::Error::IOError( std::io::Error::new(std::io::ErrorKind::Other, x).into(), ) })? .json::() + .await .map_err(|e| ddk_manager::error::Error::OracleError(e.to_string())) } @@ -109,7 +111,7 @@ impl P2PDOracleClient { /// Try to create an instance of an oracle client connecting to the provided /// host. Returns an error if the host could not be reached. Panics if the /// oracle uses an incompatible format. - pub fn new(host: &str) -> Result { + pub async fn new(host: &str) -> Result { if host.is_empty() { return Err(DlcManagerError::InvalidParameters( "Invalid host".to_string(), @@ -121,7 +123,7 @@ impl P2PDOracleClient { host.to_string() }; let path = pubkey_path(&host); - let public_key = get::(&path)?.public_key; + let public_key = get::(&path).await?.public_key; Ok(P2PDOracleClient { host, public_key }) } } @@ -144,19 +146,23 @@ fn parse_event_id(event_id: &str) -> Result<(String, DateTime), DlcManagerE Ok((asset_id.to_string(), date_time)) } +#[async_trait::async_trait] impl Oracle for P2PDOracleClient { fn get_public_key(&self) -> XOnlyPublicKey { self.public_key } - fn get_announcement(&self, event_id: &str) -> Result { + async fn get_announcement( + &self, + event_id: &str, + ) -> Result { let (asset_id, date_time) = parse_event_id(event_id)?; let path = announcement_path(&self.host, &asset_id, &date_time); - let announcement = get(&path)?; + let announcement = get(&path).await?; Ok(announcement) } - fn get_attestation( + async fn get_attestation( &self, event_id: &str, ) -> Result { @@ -166,7 +172,7 @@ impl Oracle for P2PDOracleClient { event_id: _, signatures, values, - } = get::(&path)?; + } = get::(&path).await?; Ok(OracleAttestation { event_id: event_id.to_string(), @@ -202,8 +208,8 @@ mod tests { ).create() } - #[test] - fn get_public_key_test() { + #[tokio::test] + async fn get_public_key_test() { let url = &mockito::server_url(); let _m = pubkey_mock(); let expected_pk: XOnlyPublicKey = @@ -211,13 +217,15 @@ mod tests { .parse() .unwrap(); - let client = P2PDOracleClient::new(url).expect("Error creating client instance."); + let client = P2PDOracleClient::new(url) + .await + .expect("Error creating client instance."); assert_eq!(expected_pk, client.get_public_key()); } - #[test] - fn get_announcement_test() { + #[tokio::test] + async fn get_announcement_test() { let url = &mockito::server_url(); let _pubkey_mock = pubkey_mock(); let path: &str = &announcement_path( @@ -229,15 +237,18 @@ mod tests { ); let _m = mock("GET", path).with_body(r#"{"announcementSignature":"f83db0ca25e4c209b55156737b0c65470a9702fe9d1d19a129994786384289397895e403ff37710095a04a0841a95738e3e8bc35bdef6bce50bf34eeb182bd9b","oraclePublicKey":"10dc8cf51ae3ee1c7967ffb9c9633a5ab06206535d8e1319f005a01ba33bc05d","oracleEvent":{"oracleNonces":["aca32fc8dead13983c655638ef921f1d38ef2f5286e58b2a1dab32b6e086e208","89603f8179830590fdce45eb17ba8bdf74e295a4633b58b46c9ede8274774164","5f3fcdfbba9ec75cb0868e04ec1f97089b4153fb2076bd1e017048e9df633aa1","8436d00f7331491dc6512e560a1f2414be42e893992eccb495642eefc7c5bf37","0d2593764c9c27eba0be3ca6c71a2de4e49a5f4aa1ce1e2cc379be3939547501","414318491e96919e67583db7a47eb1f8b4f1194bcb5b5dcc4fd10492d89926e4","b9a5ded7295e0343f385e5abedfd9e5f4137de8f67de0afa9396f7e0f996ef79","badf0bfe230ed605161630d8e3a092d7448461042db38912bc6c6a0ab195ff71","6e4780213cd7ed9de1300146079b897cae89dec7800065f615974193f58aa6db","7b12b48ad95634ee4ca476dd57e634fddc328e10276e71d27e0ae626fad7d699","a8058604adf590a1c38f8be19aa44175eb2d1130eb4d7f39a34f89f0a3fbed27","ffc3208f60b585cdc778be1290b352c34c22652d5348a87885816bcf17a80116","cb34c13f80b49e729e863035f30e1f8ea7777618eedb6d666c3b1c85a5b8a637","5000991f4631c0bba5d026f02125fdbe77e019dde57d31ce7f23ae3601a18623","094433a2432b81bbb6d6b7d65dc3498e2a7c9de5f35672d67097d54d920eadd2","11dff6b40b0938e1943c7888633d88871c2a2a1c16f412b22b80ba7ed8af8788","d5957f1a199b4abbc06894479c722ad0c4f120f0d5afeb76d589127213e33170","80e09bb453e6a0a444ec3ba222a62ecd59540b9dd8280566a17bebdfdfbd7a9e","0fe775b79b2172cb961e7c1aa54d521360903680680aaa55ea8be0404ee3768c","bfcdbb2cbcffba41048149d4bcf2a41cd5fd0a713df6f48104ade3022c284575"],"eventMaturityEpoch":1653865200,"eventDescriptor":{"digitDecompositionEvent":{"base":2,"isSigned":false,"unit":"usd/btc","precision":0,"nbDigits":20}},"eventId":"btcusd1653865200"}}"#).create(); - let client = P2PDOracleClient::new(url).expect("Error creating client instance"); + let client = P2PDOracleClient::new(url) + .await + .expect("Error creating client instance"); client .get_announcement("btcusd1624943400") + .await .expect("Error getting announcement"); } - #[test] - fn get_attestation_test() { + #[tokio::test] + async fn get_attestation_test() { let url = &mockito::server_url(); let _pubkey_mock = pubkey_mock(); let path: &str = &attestation_path( @@ -250,10 +261,13 @@ mod tests { let _m = mock("GET", path).with_body(r#"{"eventId":"btcusd1653517020","signatures":["ee05b1211d5f974732b10107dd302da062be47cd18f061c5080a50743412f9fd590cad90cfea762472e6fe865c4223bd388c877b7881a27892e15843ff1ac360","59ab83597089b48f5b3c2fd07c11edffa6b1180bdb6d9e7d6924979292d9c53fe79396ceb0782d5941c284d1642377136c06b2d9c2b85bda5969a773a971b5b0","d1f8c31a83bb34433da5b9808bb3692dd212b9022b7bc8f269fc817e96a7195db18262e934bebd4e68a3f2c96550826a5530350662df4c86c004f5cf1121ca67","e5cec554c39c4dd544d70175128271eecad77c1e3eaa6994c657e257d5c1c9dcd19b041ea8030e75448245b7f91705ad914c32761671a6172f928904b439ea6b","a209116d20f0931113c0880e8cd22d3f003609a32322ff8df241ef16e7d4efd1a9b723f582a22073e21188635f09f41f270f3126014542861be14b62b09c0ecc","f1da0b482f08f545a92338392b71cec33d948a5e5732ee4d5c0a87bd6b6cc12feeb1498da7afd93ae48ec4ce581ee79c0e92f338d3777c2ef06578e4ec1a853c","d9ab68244a3b47cc8cbd5a972f2f5059fc6b9711dba8d4a7a23607a99b9655593bab3abc1d3b02402cd0809c3c7016c741742efb363227de2bcfdcf290a053b3","c1146c1767a947f77794d05a2f58e50af824e3c8d70adde883e58d2dc1ddb157323b0aaf8cfb5b076a12395756bdcda64ab5d4799e43c88a41993659e6d49471","0d29d9383c9ee41055e1cb40104c9ca75280162779c0162cb6bf9aca2b223aba17de4b3f0f29ae6b749f22ba467b7e9f05456e8abb3ec328f62b7a924c6d4828","2bcc54002ceb271a940f24bc6dd0562b99c2d76cfb8f145f42ac37bc34fd3e94adba1194c5be91932b818c5715c73f287e066e228d796a373c4aec67fd777070","a91f77e3435c577682ff744d6f7da66c865a42e8645276dedf2ed2b8bc4c80285dff4b553b2231592e0fa8b4f242acb6888519fe82c457cc5204e5d9d511303a","546409d6bcdcfd5bef39957c8b1b09f7805b08ec2311bc73cf6927ae11f3567ffe8428aa7faa661518e9c02a702212ab05e494aab84624c3dd1a710f8c4c369b","9d601ee8a3d28dcdfdd05581f1b24d6e5a576f0b5544eb7c9921cb87a23fdb293c1edca89b43b5b84c1e305fbe52facbe6b03575aed8f95b4faccc90e0eb45ef","636b8028e9cd6cba6be5b3c1789b62aecfc17e9c28d7a621cfad2c3cf751046528028e1dbd6cee050d5d570cf5a3d8986471d73e7edca4093e36fc8e1097fb65","57c6337b52dc7fd8f49b29105f168fc9b4cb88ed2ba5f0e9a80a21e20836f87f875c3fe92afb437dd5647630b54eda6ba1be76ba6df8b641eb2e8be8ff1182dc","9e8843e32f9de4cd6d5bb9e938fd014babe11bb1faf35fc411d754259bc374f34dd841ed91f6bb3f030bc55a4791cdc41471c33b3f05fd35b9d1768fd381f953","97da4963747ab5e50534b93274065cba4fd24e6b7a9d3310db2596af24f70961fb03535e2a5ae272f7ea14e86daafa57073631596fecf7ceadf4ae3e6941b69e","94a414569743f87f1462a503be8cff1f229096d190b8b1349519c612b74eea872d5d763570aaaa54fad0605a43d742203bce489deea5570750030191e293c253","4d7117b89aad73eca7b341749bd54ffdd459b9b8b4ff128344d09273f66a3d2c01d2c86b61f7642d6e81f488580b456685cd68660458cff83b8858a05c9a1f4d","b12153a393a4fddac3079c1878cb89afccfe0ac8f539743c0608049f445e49ac7c89e33fcf832cda8d7e8a4f4dae94a303170f16c697feed8b78015873bd5ffc"],"values":["0","0","0","0","0","1","1","1","0","1","0","0","0","0","1","1","1","0","1","0"]}"#).create(); - let client = P2PDOracleClient::new(url).expect("Error creating client instance"); + let client = P2PDOracleClient::new(url) + .await + .expect("Error creating client instance"); client .get_attestation("btcusd1624943400") + .await .expect("Error getting attestation"); } } diff --git a/sample/src/cli.rs b/sample/src/cli.rs index b51f93b9..72a22ae7 100644 --- a/sample/src/cli.rs +++ b/sample/src/cli.rs @@ -19,8 +19,9 @@ use std::io; use std::io::{BufRead, Write}; use std::net::{SocketAddr, ToSocketAddrs}; use std::str::SplitWhitespace; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Duration; +use tokio::sync::Mutex; #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] @@ -97,7 +98,7 @@ pub(crate) async fn poll_for_user_input( print!("> "); io::stdout().flush().unwrap(); // Without flushing, the `>` doesn't print for line in stdin.lock().lines() { - process_incoming_messages(&peer_manager, &dlc_manager, &dlc_message_handler); + process_incoming_messages(&peer_manager, &dlc_manager, &dlc_message_handler).await; let line = line.unwrap(); let mut words = line.split_whitespace(); if let Some(word) = words.next() { @@ -164,32 +165,30 @@ pub(crate) async fn poll_for_user_input( .expect("Error deserializing contract input."); let manager_clone = dlc_manager.clone(); let is_contract = o == "offercontract"; - let offer = tokio::task::spawn_blocking(move || { - if is_contract { - DlcMessage::Offer( - manager_clone - .lock() - .unwrap() - .send_offer(&contract_input, pubkey) - .expect("Error sending offer"), - ) - } else { - DlcMessage::OfferChannel( - manager_clone - .lock() - .unwrap() - .offer_channel(&contract_input, pubkey) - .expect("Error sending offer channel"), - ) - } - }) - .await - .unwrap(); + let offer = if is_contract { + DlcMessage::Offer( + manager_clone + .lock() + .await + .send_offer(&contract_input, pubkey) + .await + .expect("Error sending offer"), + ) + } else { + DlcMessage::OfferChannel( + manager_clone + .lock() + .await + .offer_channel(&contract_input, pubkey) + .await + .expect("Error sending offer channel"), + ) + }; dlc_message_handler.send_message(pubkey, offer); peer_manager.process_events(); } "listoffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for offer in locked_manager .get_store() .get_contract_offers() @@ -213,8 +212,9 @@ pub(crate) async fn poll_for_user_input( let (_, node_id, msg) = dlc_manager .lock() - .unwrap() + .await .accept_contract_offer(&contract_id) + .await .expect("Error accepting contract."); dlc_message_handler.send_message(node_id, DlcMessage::Accept(msg)); peer_manager.process_events(); @@ -222,62 +222,60 @@ pub(crate) async fn poll_for_user_input( "listcontracts" => { let manager_clone = dlc_manager.clone(); // Because the oracle client is currently blocking we need to use `spawn_blocking` here. - tokio::task::spawn_blocking(move || { - manager_clone - .lock() - .unwrap() - .periodic_check(true) - .expect("Error doing periodic check."); - let contracts = manager_clone - .lock() - .unwrap() - .get_store() - .get_contracts() - .expect("Error retrieving contract list."); - for contract in contracts { - let id = hex_str(&contract.get_id()); - match contract { - Contract::Offered(_) => { - println!("Offered contract: {}", id); - } - Contract::Accepted(_) => { - println!("Accepted contract: {}", id); - } - Contract::Confirmed(_) => { - println!("Confirmed contract: {}", id); - } - Contract::Signed(_) => { - println!("Signed contract: {}", id); - } - Contract::Closed(closed) => { - println!("Closed contract: {}", id); - if let Some(attestations) = closed.attestations { - println!( - "Outcomes: {:?}", - attestations - .iter() - .map(|x| x.outcomes.clone()) - .collect::>() - ); - } - println!("PnL: {} sats", closed.pnl) - } - Contract::Refunded(_) => { - println!("Refunded contract: {}", id); - } - Contract::FailedAccept(_) | Contract::FailedSign(_) => { - println!("Failed contract: {}", id); + + manager_clone + .lock() + .await + .periodic_check(true) + .await + .expect("Error doing periodic check."); + let contracts = manager_clone + .lock() + .await + .get_store() + .get_contracts() + .expect("Error retrieving contract list."); + for contract in contracts { + let id = hex_str(&contract.get_id()); + match contract { + Contract::Offered(_) => { + println!("Offered contract: {}", id); + } + Contract::Accepted(_) => { + println!("Accepted contract: {}", id); + } + Contract::Confirmed(_) => { + println!("Confirmed contract: {}", id); + } + Contract::Signed(_) => { + println!("Signed contract: {}", id); + } + Contract::Closed(closed) => { + println!("Closed contract: {}", id); + if let Some(attestations) = closed.attestations { + println!( + "Outcomes: {:?}", + attestations + .iter() + .map(|x| x.outcomes.clone()) + .collect::>() + ); } - Contract::Rejected(_) => println!("Rejected contract: {}", id), - Contract::PreClosed(_) => println!("Pre-closed contract: {}", id), + println!("PnL: {} sats", closed.pnl) + } + Contract::Refunded(_) => { + println!("Refunded contract: {}", id); + } + Contract::FailedAccept(_) | Contract::FailedSign(_) => { + println!("Failed contract: {}", id); } + Contract::Rejected(_) => println!("Rejected contract: {}", id), + Contract::PreClosed(_) => println!("Pre-closed contract: {}", id), } - }) - .await - .expect("Error listing contract info"); + } } "listchanneloffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for offer in locked_manager .get_store() .get_offered_channels() @@ -305,8 +303,9 @@ pub(crate) async fn poll_for_user_input( let (msg, _, _, node_id) = dlc_manager .lock() - .unwrap() + .await .accept_channel(&channel_id) + .await .expect("Error accepting channel."); dlc_message_handler.send_message(node_id, DlcMessage::AcceptChannel(msg)); peer_manager.process_events(); @@ -323,7 +322,7 @@ pub(crate) async fn poll_for_user_input( let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .settle_offer(&channel_id, counter_payout) .expect("Error getting settle offer message."); dlc_message_handler.send_message(node_id, DlcMessage::SettleOffer(msg)); @@ -333,7 +332,7 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .accept_settle_offer(&channel_id) .expect("Error accepting settle channel offer."); dlc_message_handler.send_message(node_id, DlcMessage::SettleAccept(msg)); @@ -343,14 +342,14 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .reject_settle_offer(&channel_id) .expect("Error rejecting settle channel offer."); dlc_message_handler.send_message(node_id, DlcMessage::Reject(msg)); peer_manager.process_events(); } "listsettlechanneloffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for channel in locked_manager .get_store() .get_signed_channels(Some(SignedChannelStateType::SettledReceived)) @@ -380,20 +379,17 @@ pub(crate) async fn poll_for_user_input( let contract_input: ContractInput = serde_json::from_str(&contract_input_str) .expect("Error deserializing contract input."); let manager_clone = dlc_manager.clone(); - let (renew_offer, node_id) = tokio::task::spawn_blocking(move || { - manager_clone - .lock() - .unwrap() - .renew_offer(&channel_id, counter_payout, &contract_input) - .expect("Error sending offer") - }) - .await - .unwrap(); + let (renew_offer, node_id) = manager_clone + .lock() + .await + .renew_offer(&channel_id, counter_payout, &contract_input) + .await + .expect("Error sending offer"); dlc_message_handler.send_message(node_id, DlcMessage::RenewOffer(renew_offer)); peer_manager.process_events(); } "listrenewchanneloffers" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for channel in locked_manager .get_store() .get_signed_channels(Some(SignedChannelStateType::RenewOffered)) @@ -426,7 +422,7 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .accept_renew_offer(&channel_id) .expect("Error accepting channel."); dlc_message_handler.send_message(node_id, DlcMessage::RenewAccept(msg)); @@ -436,14 +432,14 @@ pub(crate) async fn poll_for_user_input( let channel_id = read_id_or_continue!(words, l, "channel id"); let (msg, node_id) = dlc_manager .lock() - .unwrap() + .await .reject_renew_offer(&channel_id) .expect("Error rejecting settle channel offer."); dlc_message_handler.send_message(node_id, DlcMessage::Reject(msg)); peer_manager.process_events(); } "listsignedchannels" => { - let locked_manager = dlc_manager.lock().unwrap(); + let locked_manager = dlc_manager.lock().await; for channel in locked_manager .get_store() .get_signed_channels(None) @@ -589,7 +585,7 @@ pub(crate) fn parse_peer_info( Ok((pubkey.unwrap(), peer_addr.unwrap().unwrap())) } -fn process_incoming_messages( +async fn process_incoming_messages( peer_manager: &Arc, dlc_manager: &Arc>, dlc_message_handler: &Arc, @@ -601,8 +597,9 @@ fn process_incoming_messages( println!("Processing message from {}", node_id); let resp = dlc_manager .lock() - .unwrap() + .await .on_dlc_message(&message, node_id) + .await .expect("Error processing message"); if let Some(msg) = resp { println!("Sending message to {}", node_id); diff --git a/sample/src/main.rs b/sample/src/main.rs index e725e26b..545ceef0 100644 --- a/sample/src/main.rs +++ b/sample/src/main.rs @@ -18,8 +18,9 @@ use p2pd_oracle_client::P2PDOracleClient; use std::collections::hash_map::HashMap; use std::env; use std::fs; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::SystemTime; +use tokio::sync::Mutex; pub(crate) type PeerManager = LdkPeerManager< SocketDescriptor, @@ -72,11 +73,9 @@ async fn main() { // client uses reqwest in blocking mode to satisfy the non async oracle interface // so we need to use `spawn_blocking`. let oracle_host = config.oracle_config.host; - let oracle = tokio::task::spawn_blocking(move || { - P2PDOracleClient::new(&oracle_host).expect("Error creating oracle client") - }) - .await - .unwrap(); + let oracle = P2PDOracleClient::new(&oracle_host) + .await + .expect("Error creating oracle client"); let mut oracles = HashMap::new(); oracles.insert(oracle.get_public_key(), Box::new(oracle)); @@ -94,6 +93,7 @@ async fn main() { Arc::new(ddk_manager::SystemTimeProvider {}), bitcoind_provider.clone(), ) + .await .expect("Could not create manager."), )); diff --git a/simple-wallet/Cargo.toml b/simple-wallet/Cargo.toml index e09ccf05..686ed301 100644 --- a/simple-wallet/Cargo.toml +++ b/simple-wallet/Cargo.toml @@ -6,6 +6,7 @@ version = "0.1.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = "0.1.83" bdk_wallet = "1.0.0-beta.2" bitcoin = "0.32.2" ddk-dlc = {path = "../ddk-dlc"} diff --git a/simple-wallet/src/lib.rs b/simple-wallet/src/lib.rs index 9a59e331..2ac6f04e 100644 --- a/simple-wallet/src/lib.rs +++ b/simple-wallet/src/lib.rs @@ -19,10 +19,11 @@ use secp256k1_zkp::{rand::thread_rng, All, PublicKey, Secp256k1, SecretKey}; type Result = core::result::Result; +#[async_trait::async_trait] /// Trait providing blockchain information to the wallet. pub trait WalletBlockchainProvider: Blockchain + FeeEstimator { - fn get_utxos_for_address(&self, address: &Address) -> Result>; - fn is_output_spent(&self, txid: &Txid, vout: u32) -> Result; + async fn get_utxos_for_address(&self, address: &Address) -> Result>; + async fn is_output_spent(&self, txid: &Txid, vout: u32) -> Result; } /// Trait enabling the wallet to persist data. @@ -68,13 +69,14 @@ where } /// Refresh the wallet checking and updating the UTXO states. - pub fn refresh(&self) -> Result<()> { + pub async fn refresh(&self) -> Result<()> { let utxos: Vec = self.storage.get_utxos()?; for utxo in &utxos { let is_spent = self .blockchain - .is_output_spent(&utxo.outpoint.txid, utxo.outpoint.vout)?; + .is_output_spent(&utxo.outpoint.txid, utxo.outpoint.vout) + .await?; if is_spent { self.storage.delete_utxo(utxo)?; } @@ -83,7 +85,7 @@ where let addresses = self.storage.get_addresses()?; for address in &addresses { - let utxos = self.blockchain.get_utxos_for_address(address)?; + let utxos = self.blockchain.get_utxos_for_address(address).await?; for utxo in &utxos { if !self.storage.has_utxo(utxo)? { @@ -117,7 +119,7 @@ where /// Creates a transaction with all wallet UTXOs as inputs and a single output /// sending everything to the given address. - pub fn empty_to_address(&self, address: &Address) -> Result<()> { + pub async fn empty_to_address(&self, address: &Address) -> Result<()> { let utxos = self .storage .get_utxos() @@ -172,7 +174,7 @@ where .extract_tx() .expect("could not extract transaction from psbt"); - self.blockchain.send_transaction(&tx) + self.blockchain.send_transaction(&tx).await } }