From 0f94f39a0f7d1a8bf67272aaf408b7021eb0a474 Mon Sep 17 00:00:00 2001 From: yancy Date: Tue, 13 Feb 2024 08:48:16 +0100 Subject: [PATCH] Return Iterator instead of Vector for SRD --- benches/coin_selection.rs | 2 +- src/branch_and_bound.rs | 231 ++++++++++++++------------------------ src/lib.rs | 15 ++- src/single_random_draw.rs | 97 +++++++++------- 4 files changed, 153 insertions(+), 192 deletions(-) diff --git a/benches/coin_selection.rs b/benches/coin_selection.rs index 40e3c30..7a5b741 100644 --- a/benches/coin_selection.rs +++ b/benches/coin_selection.rs @@ -33,7 +33,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { black_box(cost_of_change), black_box(FeeRate::ZERO), black_box(FeeRate::ZERO), - black_box(&mut utxo_pool), + black_box(&utxo_pool), ) .unwrap() .collect(); diff --git a/src/branch_and_bound.rs b/src/branch_and_bound.rs index db8edcd..d562f24 100644 --- a/src/branch_and_bound.rs +++ b/src/branch_and_bound.rs @@ -134,7 +134,7 @@ pub fn select_coins_bnb( cost_of_change: Amount, fee_rate: FeeRate, long_term_fee_rate: FeeRate, - weighted_utxos: &mut [WeightedUtxo] + weighted_utxos: &[WeightedUtxo], ) -> Option> { // Total_Tries in Core: // https://github.com/bitcoin/bitcoin/blob/1d9da8da309d1dbf9aef15eb8dc43b4a2dc3d309/src/wallet/coinselection.cpp#L74 @@ -395,17 +395,12 @@ mod tests { #[test] fn select_coins_bnb_one() { let target = Amount::from_str("1 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 1); assert_eq!(list[0].utxo.value, Amount::from_str("1 cBTC").unwrap()); @@ -414,17 +409,12 @@ mod tests { #[test] fn select_coins_bnb_two() { let target = Amount::from_str("2 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 1); assert_eq!(list[0].utxo.value, Amount::from_str("2 cBTC").unwrap()); @@ -434,17 +424,12 @@ mod tests { fn select_coins_bnb_three() { let target = Amount::from_str("3 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 2); assert_eq!(list[0].utxo.value, Amount::from_str("2 cBTC").unwrap()); @@ -454,17 +439,12 @@ mod tests { #[test] fn select_coins_bnb_four() { let target = Amount::from_str("4 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 2); assert_eq!(list[0].utxo.value, Amount::from_str("3 cBTC").unwrap()); @@ -474,17 +454,12 @@ mod tests { #[test] fn select_coins_bnb_five() { let target = Amount::from_str("5 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 2); assert_eq!(list[0].utxo.value, Amount::from_str("3 cBTC").unwrap()); @@ -494,17 +469,12 @@ mod tests { #[test] fn select_coins_bnb_six() { let target = Amount::from_str("6 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 3); assert_eq!(list[0].utxo.value, Amount::from_str("3 cBTC").unwrap()); @@ -515,17 +485,12 @@ mod tests { #[test] fn select_coins_bnb_seven() { let target = Amount::from_str("7 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 3); assert_eq!(list[0].utxo.value, Amount::from_str("4 cBTC").unwrap()); @@ -536,17 +501,12 @@ mod tests { #[test] fn select_coins_bnb_eight() { let target = Amount::from_str("8 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 3); assert_eq!(list[0].utxo.value, Amount::from_str("4 cBTC").unwrap()); @@ -557,17 +517,12 @@ mod tests { #[test] fn select_coins_bnb_nine() { let target = Amount::from_str("9 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 3); assert_eq!(list[0].utxo.value, Amount::from_str("4 cBTC").unwrap()); @@ -578,17 +533,12 @@ mod tests { #[test] fn select_coins_bnb_ten() { let target = Amount::from_str("10 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - - let list: Vec<_> = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ) - .unwrap() - .collect(); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + + let list: Vec<_> = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos) + .unwrap() + .collect(); assert_eq!(list.len(), 4); assert_eq!(list[0].utxo.value, Amount::from_str("4 cBTC").unwrap()); @@ -613,18 +563,17 @@ mod tests { }, }]; - let mut wu = weighted_utxos.clone(); + let wu = weighted_utxos.clone(); let list: Vec<_> = - select_coins_bnb(target, cost_of_change, FeeRate::ZERO, FeeRate::ZERO, &mut wu) + select_coins_bnb(target, cost_of_change, FeeRate::ZERO, FeeRate::ZERO, &wu) .unwrap() .collect(); assert_eq!(list.len(), 1); assert_eq!(list[0].utxo.value, Amount::from_str("1.5 cBTC").unwrap()); - let index_list = - select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &mut wu); + let index_list = select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &wu); assert!(index_list.is_none()); } @@ -644,8 +593,8 @@ mod tests { }, }]; - let mut wu = weighted_utxos.clone(); - let index_list = select_coins_bnb(target, Amount::ZERO, fee_rate, fee_rate, &mut wu); + let wu = weighted_utxos.clone(); + let index_list = select_coins_bnb(target, Amount::ZERO, fee_rate, fee_rate, &wu); assert!(index_list.is_none()); } @@ -680,10 +629,9 @@ mod tests { }, ]; - let mut wu = weighted_utxos.clone(); - let list: Vec<_> = select_coins_bnb(target, cost_of_change, fee_rate, fee_rate, &mut wu) - .unwrap() - .collect(); + let wu = weighted_utxos.clone(); + let list: Vec<_> = + select_coins_bnb(target, cost_of_change, fee_rate, fee_rate, &wu).unwrap().collect(); assert_eq!(list.len(), 1); assert_eq!(list[0].utxo.value, Amount::from_str("1.5 cBTC").unwrap()); } @@ -691,14 +639,9 @@ mod tests { #[test] fn select_coins_bnb_target_greater_than_value() { let target = Amount::from_str("11 cBTC").unwrap(); - let mut weighted_utxos = create_weighted_utxos(Amount::ZERO); - let list = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ); + let weighted_utxos = create_weighted_utxos(Amount::ZERO); + let list = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos); assert!(list.is_none()); } @@ -706,7 +649,7 @@ mod tests { fn select_coins_bnb_consume_more_inputs_when_cheap() { let target = Amount::from_str("6 cBTC").unwrap(); let fee = Amount::from_str("2 sats").unwrap(); - let mut weighted_utxos = create_weighted_utxos(fee); + let weighted_utxos = create_weighted_utxos(fee); let fee_rate = FeeRate::from_sat_per_kwu(10); let lt_fee_rate = FeeRate::from_sat_per_kwu(20); @@ -714,7 +657,7 @@ mod tests { // the possible combinations are 2,4 or 1,2,3 // fees are cheap, so use 1,2,3 let list: Vec<_> = - select_coins_bnb(target, Amount::ZERO, fee_rate, lt_fee_rate, &mut weighted_utxos) + select_coins_bnb(target, Amount::ZERO, fee_rate, lt_fee_rate, &weighted_utxos) .unwrap() .collect(); @@ -728,7 +671,7 @@ mod tests { fn select_coins_bnb_consume_less_inputs_when_expensive() { let target = Amount::from_str("6 cBTC").unwrap(); let fee = Amount::from_str("4 sats").unwrap(); - let mut weighted_utxos = create_weighted_utxos(fee); + let weighted_utxos = create_weighted_utxos(fee); let fee_rate = FeeRate::from_sat_per_kwu(20); let lt_fee_rate = FeeRate::from_sat_per_kwu(10); @@ -736,7 +679,7 @@ mod tests { // the possible combinations are 2,4 or 1,2,3 // fees are expensive, so use 2,4 let list: Vec<_> = - select_coins_bnb(target, Amount::ZERO, fee_rate, lt_fee_rate, &mut weighted_utxos) + select_coins_bnb(target, Amount::ZERO, fee_rate, lt_fee_rate, &weighted_utxos) .unwrap() .collect(); @@ -750,7 +693,7 @@ mod tests { let target = Amount::from_str("1 cBTC").unwrap(); let satisfaction_weight = Weight::from_wu(204); let value = SignedAmount::MAX.to_unsigned().unwrap(); - let mut weighted_utxos = vec![ + let weighted_utxos = vec![ WeightedUtxo { satisfaction_weight, utxo: TxOut { value, script_pubkey: ScriptBuf::new() }, @@ -760,13 +703,8 @@ mod tests { utxo: TxOut { value, script_pubkey: ScriptBuf::new() }, }, ]; - let list = select_coins_bnb( - target, - Amount::ZERO, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ); + let list = + select_coins_bnb(target, Amount::ZERO, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos); assert!(list.is_none()); } @@ -778,18 +716,13 @@ mod tests { let cost_of_change = Amount::MAX; let satisfaction_weight = Weight::from_wu(204); - let mut weighted_utxos = vec![WeightedUtxo { + let weighted_utxos = vec![WeightedUtxo { satisfaction_weight, utxo: TxOut { value: target, script_pubkey: ScriptBuf::new() }, }]; - let list = select_coins_bnb( - target, - cost_of_change, - FeeRate::ZERO, - FeeRate::ZERO, - &mut weighted_utxos, - ); + let list = + select_coins_bnb(target, cost_of_change, FeeRate::ZERO, FeeRate::ZERO, &weighted_utxos); assert!(list.is_none()); } } diff --git a/src/lib.rs b/src/lib.rs index bfad26b..9e4afd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,7 +83,16 @@ pub fn select_coins( cost_of_change: Amount, fee_rate: FeeRate, long_term_fee_rate: FeeRate, - weighted_utxos: &mut [WeightedUtxo], -) -> Option> { - select_coins_bnb(target, cost_of_change, fee_rate, long_term_fee_rate, weighted_utxos) + weighted_utxos: &[WeightedUtxo], +) -> Option> { + { + let bnb = + select_coins_bnb(target, cost_of_change, fee_rate, long_term_fee_rate, weighted_utxos); + + if bnb.is_some() { + bnb + } else { + select_coins_srd(target, fee_rate, weighted_utxos, &mut thread_rng()) + } + } } diff --git a/src/single_random_draw.rs b/src/single_random_draw.rs index 6347ba7..2d5bdb6 100644 --- a/src/single_random_draw.rs +++ b/src/single_random_draw.rs @@ -28,20 +28,22 @@ use rand::seq::SliceRandom; /// /// * `fee_rate` - ratio of transaction amount per size. /// /// * `weighted_utxos` - Weighted UTXOs from which to sum the target amount. /// /// * `rng` - used primarily by tests to make the selection deterministic. -pub fn select_coins_srd( +pub fn select_coins_srd<'a, R: rand::Rng + ?Sized>( target: Amount, fee_rate: FeeRate, - weighted_utxos: &mut [WeightedUtxo], + weighted_utxos: &'a [WeightedUtxo], rng: &mut R, -) -> Option> { - let mut result: Vec = Vec::new(); +) -> Option> { + let mut result: Vec<_> = weighted_utxos.iter().collect(); + let mut origin = result.to_owned(); + origin.shuffle(rng); - weighted_utxos.shuffle(rng); + result.clear(); let threshold = target + CHANGE_LOWER; let mut value = Amount::ZERO; - for w_utxo in weighted_utxos { + for w_utxo in origin { let utxo_value = w_utxo.utxo.value; let effective_value = effective_value(fee_rate, w_utxo.satisfaction_weight, utxo_value)?; @@ -50,14 +52,14 @@ pub fn select_coins_srd( Err(_) => continue, }; - result.push(w_utxo.clone()); + result.push(w_utxo); if value >= threshold { - return Some(result); + return Some(result.into_iter()); } } - Some(Vec::new()) + None } #[cfg(test)] @@ -113,34 +115,43 @@ mod tests { #[test] fn select_coins_srd_with_solution() { let target: Amount = Amount::from_str("1.5 cBTC").unwrap(); - let mut weighted_utxos: Vec = create_weighted_utxos(); + let weighted_utxos: Vec = create_weighted_utxos(); - let result = select_coins_srd(target, FEE_RATE, &mut weighted_utxos, &mut get_rng()) - .expect("unexpected error"); + let result: Vec<&WeightedUtxo> = + select_coins_srd(target, FEE_RATE, &weighted_utxos, &mut get_rng()) + .expect("unexpected error") + .collect(); - assert_eq!(vec![weighted_utxos[0].clone()], result); + let expected_result = Amount::from_str("2 cBTC").unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(expected_result, result[0].utxo.value); } #[test] fn select_coins_srd_no_solution() { let target: Amount = Amount::from_str("4 cBTC").unwrap(); - let mut weighted_utxos: Vec = create_weighted_utxos(); + let weighted_utxos: Vec = create_weighted_utxos(); - let result = select_coins_srd(target, FEE_RATE, &mut weighted_utxos, &mut get_rng()) - .expect("unexpected error"); - - assert!(result.is_empty()); + let result = select_coins_srd(target, FEE_RATE, &weighted_utxos, &mut get_rng()); + assert!(result.is_none()) } #[test] fn select_coins_srd_all_solution() { let target: Amount = Amount::from_str("2.5 cBTC").unwrap(); - let mut weighted_utxos: Vec = create_weighted_utxos(); + let weighted_utxos: Vec = create_weighted_utxos(); + + let result: Vec<&WeightedUtxo> = + select_coins_srd(target, FeeRate::ZERO, &weighted_utxos, &mut get_rng()) + .expect("unexpected error") + .collect(); - let result = select_coins_srd(target, FeeRate::ZERO, &mut weighted_utxos, &mut get_rng()) - .expect("unexpected error"); + let expected_second_element = Amount::from_str("1 cBTC").unwrap(); + let expected_first_element = Amount::from_str("2 cBTC").unwrap(); - assert_eq!(weighted_utxos.clone(), result); + assert_eq!(result.len(), 2); + assert_eq!(result[0].utxo.value, expected_first_element); + assert_eq!(result[1].utxo.value, expected_second_element); } #[test] @@ -157,32 +168,35 @@ mod tests { }); let mut rng = get_rng(); - let result = select_coins_srd(target, FEE_RATE, &mut weighted_utxos, &mut rng) - .expect("unexpected error"); + let result: Vec<_> = select_coins_srd(target, FEE_RATE, &weighted_utxos, &mut rng) + .expect("unexpected error") + .collect(); - let mut expected_utxos = create_weighted_utxos(); - expected_utxos.shuffle(&mut rng); - assert_eq!(result, expected_utxos); + let expected_second_element = Amount::from_str("1 cBTC").unwrap(); + let expected_first_element = Amount::from_str("2 cBTC").unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result[0].utxo.value, expected_first_element); + assert_eq!(result[1].utxo.value, expected_second_element); } #[test] fn select_coins_srd_fee_rate_error() { let target: Amount = Amount::from_str("2 cBTC").unwrap(); - let mut weighted_utxos: Vec = create_weighted_utxos(); + let weighted_utxos: Vec = create_weighted_utxos(); - let result = select_coins_srd(target, FeeRate::MAX, &mut weighted_utxos, &mut get_rng()); + let result = select_coins_srd(target, FeeRate::MAX, &weighted_utxos, &mut get_rng()); assert!(result.is_none()); } #[test] fn select_coins_srd_change_output_too_small() { let target: Amount = Amount::from_str("3 cBTC").unwrap(); - let mut weighted_utxos: Vec = create_weighted_utxos(); + let weighted_utxos: Vec = create_weighted_utxos(); - let result = select_coins_srd(target, FEE_RATE, &mut weighted_utxos, &mut get_rng()) - .expect("unexpected error"); + let result = select_coins_srd(target, FEE_RATE, &weighted_utxos, &mut get_rng()); - assert!(result.is_empty()); + assert!(result.is_none()); } #[test] @@ -195,19 +209,24 @@ mod tests { // fee = 15 sats, since // 40 sat/kwu * (204 + BASE_WEIGHT) = 15 sats let fee_rate: FeeRate = FeeRate::from_sat_per_kwu(40); - let mut weighted_utxos: Vec = create_weighted_utxos(); + let weighted_utxos: Vec = create_weighted_utxos(); - let result = select_coins_srd(target, fee_rate, &mut weighted_utxos, &mut get_rng()) - .expect("unexpected error"); + let result: Vec<_> = select_coins_srd(target, fee_rate, &weighted_utxos, &mut get_rng()) + .expect("unexpected error") + .collect(); + let expected_second_element = Amount::from_str("1 cBTC").unwrap(); + let expected_first_element = Amount::from_str("2 cBTC").unwrap(); - assert_eq!(weighted_utxos.clone(), result); + assert_eq!(result.len(), 2); + assert_eq!(result[0].utxo.value, expected_first_element); + assert_eq!(result[1].utxo.value, expected_second_element); } #[test] fn select_coins_srd_addition_overflow() { let target: Amount = Amount::from_str("2 cBTC").unwrap(); - let mut weighted_utxos: Vec = vec![WeightedUtxo { + let weighted_utxos: Vec = vec![WeightedUtxo { satisfaction_weight: Weight::MAX, utxo: TxOut { value: Amount::from_str("1 cBTC").unwrap(), @@ -215,7 +234,7 @@ mod tests { }, }]; - let result = select_coins_srd(target, FEE_RATE, &mut weighted_utxos, &mut get_rng()); + let result = select_coins_srd(target, FEE_RATE, &weighted_utxos, &mut get_rng()); assert!(result.is_none()); } }