From fd316946086779a5e25194e3c98c994df03c7413 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:18:08 +0100 Subject: [PATCH] chore(strings): use FunctionExecutor in tests --- tfhe/src/strings/mod.rs | 142 --- tfhe/src/strings/test_functions/mod.rs | 76 -- .../src/strings/test_functions/test_common.rs | 534 +++++------ .../src/strings/test_functions/test_concat.rs | 223 +++-- .../strings/test_functions/test_contains.rs | 225 ++--- .../test_functions/test_find_replace.rs | 543 ++++++----- tfhe/src/strings/test_functions/test_split.rs | 842 +++++++----------- .../test_functions/test_up_low_case.rs | 225 ++--- .../strings/test_functions/test_whitespace.rs | 254 +++--- 9 files changed, 1252 insertions(+), 1812 deletions(-) diff --git a/tfhe/src/strings/mod.rs b/tfhe/src/strings/mod.rs index 5ca854f09f..9e63d3f7cb 100644 --- a/tfhe/src/strings/mod.rs +++ b/tfhe/src/strings/mod.rs @@ -9,145 +9,3 @@ mod test_functions; // Used as the const argument for StaticUnsignedBigInt, specifying the max chars length of a // ClearString const N: usize = 32; - -#[cfg(test)] -pub(crate) use test::TestKeys; - -#[cfg(test)] -mod test { - use super::ciphertext::FheString; - use super::client_key::EncU16; - use crate::integer::keycache::KEY_CACHE; - use crate::integer::{ClientKey, ServerKey}; - use crate::shortint::parameters::{ - PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, - }; - use crate::shortint::PBSParameters; - - #[test] - fn test_all() { - for param in [ - PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, - ] { - test_all_impl( - param, - "a", - Some(1), - "a", - Some(1), - "a", - Some(1), - "a", - Some(1), - 0, - 0, - TestKind::Trivial, - ); - } - } - - #[allow(clippy::too_many_arguments)] - pub fn test_all_impl>( - params: P, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - to: &str, - to_pad: Option, - rhs: &str, - rhs_pad: Option, - n: u16, - max: u16, - test_kind: TestKind, - ) { - let keys = TestKeys::new(params, test_kind); - - keys.check_len_fhe_string_vs_rust_str(str, str_pad); - keys.check_is_empty_fhe_string_vs_rust_str(str, str_pad); - - keys.check_encrypt_decrypt_fhe_string_vs_rust_str(str, str_pad); - - keys.check_contains_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_ends_with_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_starts_with_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - - keys.check_find_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_rfind_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - - keys.check_strip_prefix_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_strip_suffix_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - - keys.check_eq_ignore_case_fhe_string_vs_rust_str(str, str_pad, rhs, rhs_pad); - keys.check_comp_fhe_string_vs_rust_str(str, str_pad, rhs, rhs_pad); - - keys.check_to_lowercase_fhe_string_vs_rust_str(str, str_pad); - keys.check_to_uppercase_fhe_string_vs_rust_str(str, str_pad); - - keys.check_concat_fhe_string_vs_rust_str(str, str_pad, rhs, rhs_pad); - keys.check_repeat_fhe_string_vs_rust_str(str, str_pad, n, max); - - keys.check_trim_end_fhe_string_vs_rust_str(str, str_pad); - keys.check_trim_start_fhe_string_vs_rust_str(str, str_pad); - keys.check_trim_fhe_string_vs_rust_str(str, str_pad); - keys.check_split_ascii_whitespace_fhe_string_vs_rust_str(str, str_pad); - - keys.check_split_once_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_rsplit_once_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - - keys.check_split_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_rsplit_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - - keys.check_split_terminator_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_rsplit_terminator_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - keys.check_split_inclusive_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad); - - keys.check_splitn_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad, n, max); - keys.check_rsplitn_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad, n, max); - - keys.check_replace_fhe_string_vs_rust_str(str, str_pad, pat, pat_pad, to, to_pad); - keys.check_replacen_fhe_string_vs_rust_str( - (str, str_pad), - (pat, pat_pad), - (to, to_pad), - n, - max, - ); - } - - pub(crate) struct TestKeys { - pub ck: ClientKey, - pub sk: ServerKey, - pub test_kind: TestKind, - } - - pub enum TestKind { - Trivial, - Encrypted, - } - - impl TestKeys { - pub fn new>(params: P, test_kind: TestKind) -> Self { - let (ck, sk) = KEY_CACHE.get_from_params(params, crate::integer::IntegerKeyKind::Radix); - - Self { ck, sk, test_kind } - } - - pub fn encrypt_string(&self, str: &str, padding: Option) -> FheString { - match self.test_kind { - TestKind::Trivial => FheString::new_trivial(&self.ck, str, padding), - TestKind::Encrypted => FheString::new(&self.ck, str, padding), - } - } - - pub fn encrypt_u16(&self, val: u16, max: Option) -> EncU16 { - match self.test_kind { - TestKind::Trivial => self.ck.trivial_encrypt_u16(val, max), - TestKind::Encrypted => self.ck.encrypt_u16(val, max), - } - } - } -} diff --git a/tfhe/src/strings/test_functions/mod.rs b/tfhe/src/strings/test_functions/mod.rs index ff83d21cad..88c52ff01c 100644 --- a/tfhe/src/strings/test_functions/mod.rs +++ b/tfhe/src/strings/test_functions/mod.rs @@ -5,79 +5,3 @@ mod test_find_replace; mod test_split; mod test_up_low_case; mod test_whitespace; - -use std::time::Duration; - -fn result_message(str: &str, expected: T, dec: T, dur: Duration) -where - T: std::fmt::Debug, -{ - println!( - "\x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{str:?}\x1b[0m\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{expected:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{dec:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{dur:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - ); -} - -fn result_message_pat(str: &str, pat: &str, expected: T, dec: T, dur: Duration) -where - T: std::fmt::Debug, -{ - println!( - "\x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{str:?}\x1b[0m\n\ - \x1b[1;32;1mPattern: \x1b[0m\x1b[0;33m{pat:?}\x1b[0m\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{expected:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{dec:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{dur:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - ); -} - -fn result_message_clear_pat(str: &str, pat: &str, expected: T, dec: T, dur: Duration) -where - T: std::fmt::Debug, -{ - println!( - "\x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{str:?}\x1b[0m\n\ - \x1b[1;32;1mPattern (clear): \x1b[0m\x1b[0;33m{pat:?}\x1b[0m\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{expected:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{dec:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{dur:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - ); -} - -fn result_message_rhs(str: &str, pat: &str, expected: T, dec: T, dur: Duration) -where - T: std::fmt::Debug, -{ - println!( - "\x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mLhs: \x1b[0m\x1b[0;33m{str:?}\x1b[0m\n\ - \x1b[1;32;1mRhs: \x1b[0m\x1b[0;33m{pat:?}\x1b[0m\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{expected:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{dec:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{dur:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - ); -} - -fn result_message_clear_rhs(str: &str, pat: &str, expected: T, dec: T, dur: Duration) -where - T: std::fmt::Debug, -{ - println!( - "\x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mLhs: \x1b[0m\x1b[0;33m{str:?}\x1b[0m\n\ - \x1b[1;32;1mRhs (clear): \x1b[0m\x1b[0;33m{pat:?}\x1b[0m\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{expected:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{dec:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{dur:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - ); -} diff --git a/tfhe/src/strings/test_functions/test_common.rs b/tfhe/src/strings/test_functions/test_common.rs index 8b7ede655d..7c82b2e885 100644 --- a/tfhe/src/strings/test_functions/test_common.rs +++ b/tfhe/src/strings/test_functions/test_common.rs @@ -1,128 +1,94 @@ -use crate::integer::{BooleanBlock, ServerKey}; +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::{CpuFunctionExecutor, NotTuple}; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; +use crate::shortint::PBSParameters; use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef}; use crate::strings::server_key::{FheStringIsEmpty, FheStringLen}; -use crate::strings::test::TestKind; -use crate::strings::test_functions::{ - result_message, result_message_clear_pat, result_message_clear_rhs, result_message_pat, - result_message_rhs, -}; -use crate::strings::TestKeys; -use std::time::{Duration, Instant}; +use std::sync::Arc; #[test] -fn test_len_is_empty_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn test_encrypt_decrypt_parameterized() { + test_encrypt_decrypt(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} + +fn test_encrypt_decrypt

(param: P) +where + P: Into, +{ + let (cks, _sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); for str in ["", "a", "abc"] { for pad in 0..3 { - keys.check_len_fhe_string_vs_rust_str(str, Some(pad)); - keys.check_is_empty_fhe_string_vs_rust_str(str, Some(pad)); + let enc_str = FheString::new(&cks, str, Some(pad)); + + let dec = cks.decrypt_ascii(&enc_str); + + assert_eq!(str, &dec); } } } #[test] -fn test_len_is_empty() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_len_fhe_string_vs_rust_str("", Some(1)); - keys.check_is_empty_fhe_string_vs_rust_str("", Some(1)); - - keys.check_len_fhe_string_vs_rust_str("abc", Some(1)); - keys.check_is_empty_fhe_string_vs_rust_str("abc", Some(1)); +fn string_is_empty_test_parameterized() { + string_is_empty_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } -#[test] -fn test_encrypt_decrypt_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +impl NotTuple for &FheString {} - for str in ["", "a", "abc"] { - for pad in 0..3 { - keys.check_encrypt_decrypt_fhe_string_vs_rust_str(str, Some(pad)); - } - } +#[allow(clippy::needless_pass_by_value)] +fn string_is_empty_test

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::is_empty); + string_is_empty_test_impl(param, executor); } -#[test] -fn test_encrypt_decrypt() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); +pub(crate) fn string_is_empty_test_impl(param: P, mut is_empty_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a FheString, FheStringIsEmpty>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + is_empty_executor.setup(&cks2, sks); + // trivial for str in ["", "a", "abc"] { for pad in 0..3 { - keys.check_encrypt_decrypt_fhe_string_vs_rust_str(str, Some(pad)); - } - } -} + let expected_result = str.is_empty(); -#[test] -fn test_strip_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); + let enc_str = FheString::new_trivial(&cks, str, Some(pad)); - for str_pad in 0..2 { - for pat_pad in 0..2 { - for pat in ["", "a", "abc"] { - for str in ["", "a", "abc", "b", "ab", "dddabc", "abceeee", "dddabceee"] { - keys.check_strip_prefix_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - keys.check_strip_suffix_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); + let result = is_empty_executor.execute(&enc_str); + + match result { + FheStringIsEmpty::NoPadding(result) => assert_eq!(result, expected_result), + FheStringIsEmpty::Padding(result) => { + assert_eq!(cks.decrypt_bool(&result), expected_result) } } } } -} + // encrypted + { + let pad = 1; -#[test] -fn test_strip() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - keys.check_strip_prefix_fhe_string_vs_rust_str("abc", Some(1), "a", Some(1)); - keys.check_strip_suffix_fhe_string_vs_rust_str("abc", Some(1), "c", Some(1)); - - keys.check_strip_prefix_fhe_string_vs_rust_str("abc", Some(1), "d", Some(1)); - keys.check_strip_suffix_fhe_string_vs_rust_str("abc", Some(1), "d", Some(1)); -} + for str in ["", "abc"] { + let expected_result = str.is_empty(); -const TEST_CASES_COMP: [&str; 5] = ["", "a", "aa", "ab", "abc"]; + let enc_str = FheString::new(&cks, str, Some(pad)); -#[test] -fn test_comparisons_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); + let result = is_empty_executor.execute(&enc_str); - for str_pad in 0..2 { - for rhs_pad in 0..2 { - for str in TEST_CASES_COMP { - for rhs in TEST_CASES_COMP { - keys.check_comp_fhe_string_vs_rust_str(str, Some(str_pad), rhs, Some(rhs_pad)); + match result { + FheStringIsEmpty::NoPadding(result) => assert_eq!(result, expected_result), + FheStringIsEmpty::Padding(result) => { + assert_eq!(cks.decrypt_bool(&result), expected_result) } } } @@ -130,221 +96,255 @@ fn test_comparisons_trivial() { } #[test] -fn test_comparisons() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_comp_fhe_string_vs_rust_str("a", Some(1), "a", Some(1)); +fn string_len_test_parameterized() { + string_len_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - keys.check_comp_fhe_string_vs_rust_str("a", Some(1), "b", Some(1)); +#[allow(clippy::needless_pass_by_value)] +fn string_len_test

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::len); + string_len_test_impl(param, executor); } -impl TestKeys { - pub fn check_len_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let expected = str.len(); +pub(crate) fn string_len_test_impl(param: P, mut len_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a FheString, FheStringLen>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); - let enc_str = self.encrypt_string(str, str_pad); + len_executor.setup(&cks2, sks); - let start = Instant::now(); - let result = self.sk.len(&enc_str); - let end = Instant::now(); + // trivial + for str in ["", "a", "abc"] { + for pad in 0..3 { + let expected_result = str.len(); - let dec = match result { - FheStringLen::NoPadding(clear_len) => clear_len, - FheStringLen::Padding(enc_len) => self.ck.decrypt_radix::(&enc_len) as usize, - }; + let enc_str = FheString::new_trivial(&cks, str, Some(pad)); - println!("\n\x1b[1mLen:\x1b[0m"); - result_message(str, expected, dec, end.duration_since(start)); + let result = len_executor.execute(&enc_str); - assert_eq!(dec, expected); + match result { + FheStringLen::NoPadding(result) => { + assert_eq!(result, expected_result) + } + FheStringLen::Padding(result) => { + assert_eq!(cks.decrypt_radix::(&result), expected_result as u16) + } + } + } } + // encrypted + { + let pad = 1; - pub fn check_is_empty_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let expected = str.is_empty(); + for str in ["", "abc"] { + let expected_result = str.len(); - let enc_str = self.encrypt_string(str, str_pad); + let enc_str = FheString::new(&cks, str, Some(pad)); - let start = Instant::now(); - let result = self.sk.is_empty(&enc_str); - let end = Instant::now(); + let result = len_executor.execute(&enc_str); - let dec = match result { - FheStringIsEmpty::NoPadding(clear_len) => clear_len, - FheStringIsEmpty::Padding(enc_len) => self.ck.decrypt_bool(&enc_len), - }; + match result { + FheStringLen::NoPadding(result) => { + assert_eq!(result, expected_result) + } + FheStringLen::Padding(result) => { + assert_eq!(cks.decrypt_radix::(&result), expected_result as u64) + } + } + } + } +} - println!("\n\x1b[1mIs_empty:\x1b[0m"); - result_message(str, expected, dec, end.duration_since(start)); +#[test] +fn string_strip_test_parameterized() { + string_strip_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - assert_eq!(dec, expected); +#[allow(clippy::needless_pass_by_value)] +fn string_strip_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str, &'a str) -> Option<&'a str>, + fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> (FheString, BooleanBlock), + ); 2] = [ + (|lhs, rhs| lhs.strip_prefix(rhs), ServerKey::strip_prefix), + (|lhs, rhs| lhs.strip_suffix(rhs), ServerKey::strip_suffix), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_strip_test_impl(param, executor, clear_op); } +} - pub fn check_encrypt_decrypt_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let enc_str = self.encrypt_string(str, str_pad); +pub(crate) fn string_strip_test_impl( + param: P, + mut strip_executor: T, + clear_function: for<'a> fn(&'a str, &'a str) -> Option<&'a str>, +) where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>), (FheString, BooleanBlock)>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); - let dec = self.ck.decrypt_ascii(&enc_str); + strip_executor.setup(&cks2, sks); - println!("\n\x1b[1mEncrypt/Decrypt:\x1b[0m"); - result_message(str, str, &dec, Duration::from_nanos(0)); + let assert_result = |expected_result: (&str, bool), result: (FheString, BooleanBlock)| { + assert_eq!(expected_result.1, cks.decrypt_bool(&result.1)); - assert_eq!(str, &dec); - } + assert_eq!(expected_result.0, cks.decrypt_ascii(&result.0)); + }; - pub fn check_strip_prefix_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.strip_prefix(pat); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - - let start = Instant::now(); - let (result, is_some) = self.sk.strip_prefix(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); - - let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.decrypt_bool(&is_some); - if !dec_is_some { - // When it's None, the FheString returned is the original str - assert_eq!(dec_result, str); - } + // trivial + for str_pad in 0..2 { + for pat_pad in 0..2 { + for pat in ["", "a", "abc"] { + for str in ["", "a", "abc", "b", "ab", "dddabc", "abceeee", "dddabceee"] { + let expected_result = + clear_function(str, pat).map_or((str, false), |str| (str, true)); + + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = + GenericPattern::Enc(FheString::new_trivial(&cks, pat, Some(pat_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(pat.to_string())); - let dec = dec_is_some.then_some(dec_result.as_str()); + for rhs in [enc_rhs, clear_rhs] { + let result = strip_executor.execute((&enc_lhs, rhs.as_ref())); + + assert_result(expected_result, result); + } + } + } + } + } + // encrypted + { + let str = "abc"; + let str_pad = 1; + let rhs_pad = 1; - println!("\n\x1b[1mStrip_prefix:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); + for rhs in ["a", "c", "d"] { + let expected_result = clear_function(str, rhs).map_or((str, false), |str| (str, true)); - assert_eq!(dec, expected); + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - let start = Instant::now(); - let (result, is_some) = self.sk.strip_prefix(&enc_str, clear_pat.as_ref()); - let end = Instant::now(); + for rhs in [enc_rhs, clear_rhs] { + let result = strip_executor.execute((&enc_lhs, rhs.as_ref())); - let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.decrypt_bool(&is_some); - if !dec_is_some { - // When it's None, the FheString returned is the original str - assert_eq!(dec_result, str); + assert_result(expected_result, result); + } } + } +} - let dec = dec_is_some.then_some(dec_result.as_str()); +const TEST_CASES_COMP: [&str; 5] = ["", "a", "aa", "ab", "abc"]; - println!("\n\x1b[1mStrip_prefix:\x1b[0m"); - result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); +#[test] +fn string_comp_test_parameterized() { + string_comp_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - assert_eq!(dec, expected); +#[allow(clippy::needless_pass_by_value)] +fn string_comp_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + fn(&str, &str) -> bool, + fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock, + ); 6] = [ + (|lhs, rhs| lhs == rhs, ServerKey::string_eq), + (|lhs, rhs| lhs != rhs, ServerKey::string_ne), + (|lhs, rhs| lhs >= rhs, ServerKey::string_ge), + (|lhs, rhs| lhs <= rhs, ServerKey::string_le), + (|lhs, rhs| lhs > rhs, ServerKey::string_gt), + (|lhs, rhs| lhs < rhs, ServerKey::string_lt), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_comp_test_impl(param, executor, clear_op); } +} - pub fn check_strip_suffix_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.strip_suffix(pat); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - - let start = Instant::now(); - let (result, is_some) = self.sk.strip_suffix(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); - - let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.decrypt_bool(&is_some); - if !dec_is_some { - // When it's None, the FheString returned is the original str - assert_eq!(dec_result, str); - } +pub(crate) fn string_comp_test_impl( + param: P, + mut comp_executor: T, + clear_function: fn(&str, &str) -> bool, +) where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>), BooleanBlock>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); - let dec = dec_is_some.then_some(dec_result.as_str()); + let assert_result = |expected_result, result: BooleanBlock| { + let dec_result = cks.decrypt_bool(&result); - println!("\n\x1b[1mStrip_suffix:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); + assert_eq!(dec_result, expected_result); + }; - assert_eq!(dec, expected); + comp_executor.setup(&cks2, sks); - let start = Instant::now(); - let (result, is_some) = self.sk.strip_suffix(&enc_str, clear_pat.as_ref()); - let end = Instant::now(); + // trivial + for str_pad in 0..2 { + for rhs_pad in 0..2 { + for str in TEST_CASES_COMP { + for rhs in TEST_CASES_COMP { + let expected_result = clear_function(str, rhs); + + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = + GenericPattern::Enc(FheString::new_trivial(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - let dec_result = self.ck.decrypt_ascii(&result); - let dec_is_some = self.ck.decrypt_bool(&is_some); - if !dec_is_some { - // When it's None, the FheString returned is the original str - assert_eq!(dec_result, str); + for rhs in [enc_rhs, clear_rhs] { + let result = comp_executor.execute((&enc_lhs, rhs.as_ref())); + + assert_result(expected_result, result); + } + } + } } + } + // encrypted + { + let str = "a"; + let str_pad = 1; + let rhs_pad = 1; - let dec = dec_is_some.then_some(dec_result.as_str()); + for rhs in ["a", "b"] { + let expected_result = clear_function(str, rhs); - println!("\n\x1b[1mStrip_suffix:\x1b[0m"); - result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - assert_eq!(dec, expected); - } + for rhs in [enc_rhs, clear_rhs] { + let result = comp_executor.execute((&enc_lhs, rhs.as_ref())); - pub fn check_comp_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - rhs: &str, - rhs_pad: Option, - ) { - let enc_lhs = self.encrypt_string(str, str_pad); - let enc_rhs = GenericPattern::Enc(self.encrypt_string(rhs, rhs_pad)); - let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - - #[allow(clippy::type_complexity)] - let ops: [( - bool, - fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock, - ); 6] = [ - (str == rhs, ServerKey::string_eq), - (str != rhs, ServerKey::string_ne), - (str >= rhs, ServerKey::string_ge), - (str <= rhs, ServerKey::string_le), - (str > rhs, ServerKey::string_gt), - (str < rhs, ServerKey::string_lt), - ]; - - for (expected_result, encrypted_op) in ops { - // Encrypted rhs - let start = Instant::now(); - let result = encrypted_op(&self.sk, &enc_lhs, enc_rhs.as_ref()); - let end = Instant::now(); - - let dec_result = self.ck.decrypt_bool(&result); - - println!("\n\x1b[1mEq:\x1b[0m"); - result_message_rhs( - str, - rhs, - expected_result, - dec_result, - end.duration_since(start), - ); - assert_eq!(dec_result, expected_result); - - // Clear rhs - let start = Instant::now(); - let result_eq = encrypted_op(&self.sk, &enc_lhs, clear_rhs.as_ref()); - let end = Instant::now(); - - let dec_eq = self.ck.decrypt_bool(&result_eq); - - println!("\n\x1b[1mEq:\x1b[0m"); - result_message_clear_rhs(str, rhs, expected_result, dec_eq, end.duration_since(start)); - assert_eq!(dec_eq, expected_result); + assert_result(expected_result, result); + } } } } diff --git a/tfhe/src/strings/test_functions/test_concat.rs b/tfhe/src/strings/test_functions/test_concat.rs index 20686f021d..6dc325b1b0 100644 --- a/tfhe/src/strings/test_functions/test_concat.rs +++ b/tfhe/src/strings/test_functions/test_concat.rs @@ -1,157 +1,140 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; -use crate::strings::ciphertext::UIntArg; -use crate::strings::test::TestKind; -use crate::strings::test_functions::result_message_rhs; -use crate::strings::TestKeys; -use std::time::Instant; +use crate::shortint::PBSParameters; +use crate::strings::ciphertext::{FheString, UIntArg}; +use std::sync::Arc; const TEST_CASES_CONCAT: [&str; 5] = ["", "a", "ab", "abc", "abcd"]; #[test] -fn test_concat_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn string_concat_test_parameterized() { + string_concat_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} + +#[allow(clippy::needless_pass_by_value)] +fn string_concat_test

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::concat); + string_concat_test_impl(param, executor); +} + +pub(crate) fn string_concat_test_impl(param: P, mut concat_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, &'a FheString), FheString>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + concat_executor.setup(&cks2, sks); + + // trivial for str_pad in 0..2 { for rhs_pad in 0..2 { for str in TEST_CASES_CONCAT { for rhs in TEST_CASES_CONCAT { - keys.check_concat_fhe_string_vs_rust_str( - str, - Some(str_pad), - rhs, - Some(rhs_pad), - ); + let expected_result = str.to_owned() + rhs; + + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = FheString::new_trivial(&cks, rhs, Some(rhs_pad)); + + let result = concat_executor.execute((&enc_lhs, &enc_rhs)); + + assert_eq!(expected_result, cks.decrypt_ascii(&result)); } } } } + // encrypted + { + let str = "a"; + let str_pad = 1; + let rhs = "b"; + let rhs_pad = 1; + + let expected_result = str.to_owned() + rhs; + + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = FheString::new(&cks, rhs, Some(rhs_pad)); + + let result = concat_executor.execute((&enc_lhs, &enc_rhs)); + + assert_eq!(expected_result, cks.decrypt_ascii(&result)); + } } #[test] -fn test_concat() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); +fn string_repeat_test_parameterized() { + string_repeat_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - keys.check_concat_fhe_string_vs_rust_str("a", Some(1), "b", Some(1)); +#[allow(clippy::needless_pass_by_value)] +fn string_repeat_test

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::repeat); + string_repeat_test_impl(param, executor); } -#[test] -fn test_repeat_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +pub(crate) fn string_repeat_test_impl(param: P, mut repeat_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, &'a UIntArg), FheString>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + repeat_executor.setup(&cks2, sks); + + // trivial for str_pad in 0..2 { for n in 0..3 { for str in TEST_CASES_CONCAT { for max in n..n + 2 { - keys.check_repeat_fhe_string_vs_rust_str(str, Some(str_pad), n, max); + let expected_result = str.repeat(n as usize); + + let enc_str = FheString::new_trivial(&cks, str, Some(str_pad)); + + let enc_n = UIntArg::Enc(cks.trivial_encrypt_u16(n, Some(max))); + + let clear_n = UIntArg::Clear(n); + + for n in [clear_n, enc_n] { + let result = repeat_executor.execute((&enc_str, &n)); + + assert_eq!(expected_result, cks.decrypt_ascii(&result)); + } } } } } -} + // encrypted + { + let str = "a"; + let str_pad = 1; + let n = 1; + let max = 2; -#[test] -fn test_repeat() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_repeat_fhe_string_vs_rust_str("a", Some(1), 1, 2); -} - -impl TestKeys { - pub fn check_concat_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - rhs: &str, - rhs_pad: Option, - ) { - let expected = str.to_owned() + rhs; + let expected_result = str.repeat(n as usize); - let enc_lhs = self.encrypt_string(str, str_pad); - let enc_rhs = self.encrypt_string(rhs, rhs_pad); + let enc_str = FheString::new(&cks, str, Some(str_pad)); - let start = Instant::now(); - let result = self.sk.concat(&enc_lhs, &enc_rhs); - let end = Instant::now(); + let enc_n = UIntArg::Enc(cks.encrypt_u16(n, Some(max))); - let dec = self.ck.decrypt_ascii(&result); + let clear_n = UIntArg::Clear(n); - println!("\n\x1b[1mConcat (+):\x1b[0m"); - result_message_rhs(str, rhs, &expected, &dec, end.duration_since(start)); + for n in [clear_n, enc_n] { + let result = repeat_executor.execute((&enc_str, &n)); - assert_eq!(dec, expected); - } - - pub fn check_repeat_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - n: u16, - max: u16, - ) { - let expected = str.repeat(n as usize); - - let enc_str = self.encrypt_string(str, str_pad); - - // Clear n - let start = Instant::now(); - let result = self.sk.repeat(&enc_str, &UIntArg::Clear(n)); - let end = Instant::now(); - - let dec = self.ck.decrypt_ascii(&result); - - println!( - "\n\x1b[1mRepeat:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (clear): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - n, - expected, - dec, - end.duration_since(start), - ); - assert_eq!(dec, expected); - - // Encrypted n - let enc_n = self.encrypt_u16(n, Some(max)); - - let start = Instant::now(); - let result = self.sk.repeat(&enc_str, &UIntArg::Enc(enc_n)); - let end = Instant::now(); - - let dec = self.ck.decrypt_ascii(&result); - - println!( - "\n\x1b[1mRepeat:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (encrypted): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - n, - expected, - dec, - end.duration_since(start), - ); - assert_eq!(dec, expected); + assert_eq!(expected_result, cks.decrypt_ascii(&result)); + } } } diff --git a/tfhe/src/strings/test_functions/test_contains.rs b/tfhe/src/strings/test_functions/test_contains.rs index 7df66def91..091e4d1b7d 100644 --- a/tfhe/src/strings/test_functions/test_contains.rs +++ b/tfhe/src/strings/test_functions/test_contains.rs @@ -1,168 +1,93 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; -use crate::strings::ciphertext::{ClearString, GenericPattern}; -use crate::strings::test::TestKind; -use crate::strings::test_functions::{result_message_clear_pat, result_message_pat}; -use crate::strings::TestKeys; -use std::time::Instant; +use crate::shortint::PBSParameters; +use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef}; +use std::sync::Arc; #[test] -fn test_contains_start_end_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn string_contains_test_parameterized() { + string_contains_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} + +#[allow(clippy::needless_pass_by_value)] +fn string_contains_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str, &'a str) -> bool, + fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> BooleanBlock, + ); 3] = [ + (|lhs, rhs| lhs.contains(rhs), ServerKey::contains), + (|lhs, rhs| lhs.starts_with(rhs), ServerKey::starts_with), + (|lhs, rhs| lhs.ends_with(rhs), ServerKey::ends_with), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_contains_test_impl(param, executor, clear_op); + } +} +pub(crate) fn string_contains_test_impl( + param: P, + mut contains_executor: T, + clear_function: for<'a> fn(&'a str, &'a str) -> bool, +) where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>), BooleanBlock>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + contains_executor.setup(&cks2, sks); + + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { for str in ["", "a", "abc", "b", "ab", "dddabc", "abceeee", "dddabceee"] { for pat in ["", "a", "abc"] { - keys.check_contains_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - keys.check_starts_with_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - keys.check_ends_with_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - } - } - } - } -} - -#[test] -fn test_contains_start_end() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_contains_fhe_string_vs_rust_str("ab", Some(1), "a", Some(1)); - keys.check_contains_fhe_string_vs_rust_str("ab", Some(1), "c", Some(1)); - - keys.check_starts_with_fhe_string_vs_rust_str("ab", Some(1), "a", Some(1)); - keys.check_starts_with_fhe_string_vs_rust_str("ab", Some(1), "c", Some(1)); - - keys.check_ends_with_fhe_string_vs_rust_str("ab", Some(1), "b", Some(1)); - keys.check_ends_with_fhe_string_vs_rust_str("ab", Some(1), "c", Some(1)); -} - -impl TestKeys { - pub fn check_contains_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.contains(pat); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - - let start = Instant::now(); - let result = self.sk.contains(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); - - let dec = self.ck.decrypt_bool(&result); - - println!("\n\x1b[1mContains:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); - - assert_eq!(dec, expected); - - let start = Instant::now(); - let result = self.sk.contains(&enc_str, clear_pat.as_ref()); - let end = Instant::now(); + let expected_result = clear_function(str, pat); - let dec = self.ck.decrypt_bool(&result); + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = + GenericPattern::Enc(FheString::new_trivial(&cks, pat, Some(pat_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(pat.to_string())); - println!("\n\x1b[1mContains:\x1b[0m"); - result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); + for rhs in [enc_rhs, clear_rhs] { + let result = contains_executor.execute((&enc_lhs, rhs.as_ref())); - assert_eq!(dec, expected); - } - - pub fn check_ends_with_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.ends_with(pat); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - - let start = Instant::now(); - let result = self.sk.ends_with(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); - - let dec = self.ck.decrypt_bool(&result); - - println!("\n\x1b[1mEnds_with:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); - - assert_eq!(dec, expected); - - let start = Instant::now(); - let result = self.sk.ends_with(&enc_str, clear_pat.as_ref()); - let end = Instant::now(); - - let dec = self.ck.decrypt_bool(&result); - - println!("\n\x1b[1mEnds_with:\x1b[0m"); - result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); - - assert_eq!(dec, expected); + assert_eq!(expected_result, cks.decrypt_bool(&result)); + } + } + } + } } + // encrypted + { + let str = "ab"; + let str_pad = 1; + let rhs_pad = 1; - pub fn check_starts_with_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.starts_with(pat); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - - let start = Instant::now(); - let result = self.sk.starts_with(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); - - let dec = self.ck.decrypt_bool(&result); + for rhs in ["a", "b", "c"] { + let expected_result = clear_function(str, rhs); - println!("\n\x1b[1mStarts_with:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - assert_eq!(dec, expected); + for rhs in [enc_rhs, clear_rhs] { + let result = contains_executor.execute((&enc_lhs, rhs.as_ref())); - let start = Instant::now(); - let result = self.sk.starts_with(&enc_str, clear_pat.as_ref()); - let end = Instant::now(); - - let dec = self.ck.decrypt_bool(&result); - - println!("\n\x1b[1mStarts_with:\x1b[0m"); - result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); - - assert_eq!(dec, expected); + assert_eq!(expected_result, cks.decrypt_bool(&result)); + } + } } } diff --git a/tfhe/src/strings/test_functions/test_find_replace.rs b/tfhe/src/strings/test_functions/test_find_replace.rs index 22e929a9ef..63d3dfd333 100644 --- a/tfhe/src/strings/test_functions/test_find_replace.rs +++ b/tfhe/src/strings/test_functions/test_find_replace.rs @@ -1,78 +1,167 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; -use crate::strings::ciphertext::{ClearString, GenericPattern, UIntArg}; -use crate::strings::test::TestKind; -use crate::strings::test_functions::{result_message_clear_pat, result_message_pat}; -use crate::strings::TestKeys; -use std::time::Instant; +use crate::shortint::PBSParameters; +use crate::strings::ciphertext::{ + ClearString, FheString, GenericPattern, GenericPatternRef, UIntArg, +}; +use std::sync::Arc; const TEST_CASES_FIND: [&str; 8] = ["", "a", "abc", "b", "ab", "dabc", "abce", "dabce"]; const PATTERN_FIND: [&str; 5] = ["", "a", "b", "ab", "abc"]; #[test] -fn test_find_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn string_find_test_parameterized() { + string_find_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} +#[allow(clippy::needless_pass_by_value)] +fn string_find_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str, &'a str) -> Option, + fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> (RadixCiphertext, BooleanBlock), + ); 2] = [ + (|lhs, rhs| lhs.find(rhs), ServerKey::find), + (|lhs, rhs| lhs.rfind(rhs), ServerKey::rfind), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_find_test_impl(param, executor, clear_op); + } +} + +pub(crate) fn string_find_test_impl( + param: P, + mut find_executor: T, + clear_function: for<'a> fn(&'a str, &'a str) -> Option, +) where + P: Into, + T: for<'a> FunctionExecutor< + (&'a FheString, GenericPatternRef<'a>), + (RadixCiphertext, BooleanBlock), + >, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + find_executor.setup(&cks2, sks); + + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { for str in TEST_CASES_FIND { for pat in PATTERN_FIND { - keys.check_find_fhe_string_vs_rust_str(str, Some(str_pad), pat, Some(pat_pad)); - keys.check_rfind_fhe_string_vs_rust_str(str, Some(str_pad), pat, Some(pat_pad)); + let expected_result = clear_function(str, pat); + + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = + GenericPattern::Enc(FheString::new_trivial(&cks, pat, Some(pat_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(pat.to_string())); + + for rhs in [enc_rhs, clear_rhs] { + let (index, is_some) = find_executor.execute((&enc_lhs, rhs.as_ref())); + + let dec_index = cks.decrypt_radix::(&index); + let dec_is_some = cks.decrypt_bool(&is_some); + + let dec = dec_is_some.then_some(dec_index as usize); + + assert_eq!(dec, expected_result); + } } } } } -} + // encrypted + { + let str = "aba"; + let str_pad = 1; + let rhs_pad = 1; -#[test] -fn test_find() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); + for rhs in ["a", "c"] { + let expected_result = clear_function(str, rhs); + + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); + + for rhs in [enc_rhs, clear_rhs] { + let (index, is_some) = find_executor.execute((&enc_lhs, rhs.as_ref())); - keys.check_find_fhe_string_vs_rust_str("aba", Some(1), "a", Some(1)); - keys.check_find_fhe_string_vs_rust_str("aba", Some(1), "c", Some(1)); + let dec_index = cks.decrypt_radix::(&index); + let dec_is_some = cks.decrypt_bool(&is_some); - keys.check_rfind_fhe_string_vs_rust_str("aba", Some(1), "a", Some(1)); - keys.check_rfind_fhe_string_vs_rust_str("aba", Some(1), "c", Some(1)); + let dec = dec_is_some.then_some(dec_index as usize); + + assert_eq!(dec, expected_result); + } + } + } } #[test] -fn test_replace_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn string_replace_test_parameterized() { + string_replace_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} + +#[allow(clippy::needless_pass_by_value)] +fn string_replace_test

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::replace); + string_replace_test_impl(param, executor); +} + +pub(crate) fn string_replace_test_impl(param: P, mut replace_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>, &'a FheString), FheString>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + replace_executor.setup(&cks2, sks); + // trivial for str_pad in 0..2 { for from_pad in 0..2 { for to_pad in 0..2 { for str in TEST_CASES_FIND { for from in PATTERN_FIND { for to in ["", " ", "a", "abc"] { - keys.check_replace_fhe_string_vs_rust_str( - str, - Some(str_pad), + let expected_result = str.replace(from, to); + + let enc_str = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_from = GenericPattern::Enc(FheString::new_trivial( + &cks, from, Some(from_pad), - to, - Some(to_pad), - ); - for n in 0..=2 { - for max in n..n + 2 { - keys.check_replacen_fhe_string_vs_rust_str( - (str, Some(str_pad)), - (from, Some(from_pad)), - (to, Some(to_pad)), - n, - max, - ); - } + )); + let clear_from = + GenericPattern::Clear(ClearString::new(from.to_string())); + + let enc_to = FheString::new_trivial(&cks, to, Some(to_pad)); + + for from in [enc_from, clear_from] { + let result = + replace_executor.execute((&enc_str, from.as_ref(), &enc_to)); + + let dec_result = cks.decrypt_ascii(&result); + + assert_eq!(dec_result, expected_result); } } } @@ -80,265 +169,145 @@ fn test_replace_trivial() { } } } -} - -#[test] -fn test_replace() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_replace_fhe_string_vs_rust_str("ab", Some(1), "a", Some(1), "d", Some(1)); - keys.check_replace_fhe_string_vs_rust_str("ab", Some(1), "c", Some(1), "d", Some(1)); - - keys.check_replacen_fhe_string_vs_rust_str( - ("ab", Some(1)), - ("a", Some(1)), - ("d", Some(1)), - 1, - 2, - ); - keys.check_replacen_fhe_string_vs_rust_str( - ("ab", Some(1)), - ("c", Some(1)), - ("d", Some(1)), - 1, - 2, - ); -} - -impl TestKeys { - pub fn check_find_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.find(pat); + // encrypted + { + let str = "ab"; + let str_pad = 1; + let from_pad = 1; + let to = "d"; + let to_pad = 1; - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); + for from in ["a", "c"] { + let expected_result = str.replace(from, to); - let start = Instant::now(); - let (index, is_some) = self.sk.find(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); + let enc_str = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_from = GenericPattern::Enc(FheString::new_trivial(&cks, from, Some(from_pad))); + let clear_from = GenericPattern::Clear(ClearString::new(from.to_string())); - let dec_index = self.ck.decrypt_radix::(&index); - let dec_is_some = self.ck.decrypt_bool(&is_some); + let enc_to = FheString::new_trivial(&cks, to, Some(to_pad)); - let dec = dec_is_some.then_some(dec_index as usize); + for from in [enc_from, clear_from] { + let result = replace_executor.execute((&enc_str, from.as_ref(), &enc_to)); - println!("\n\x1b[1mFind:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); + let dec_result = cks.decrypt_ascii(&result); - assert_eq!(dec, expected); - - let start = Instant::now(); - let (index, is_some) = self.sk.find(&enc_str, clear_pat.as_ref()); - let end = Instant::now(); - - let dec_index = self.ck.decrypt_radix::(&index); - let dec_is_some = self.ck.decrypt_bool(&is_some); - - let dec = dec_is_some.then_some(dec_index as usize); - - println!("\n\x1b[1mFind:\x1b[0m"); - result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); - - assert_eq!(dec, expected); + assert_eq!(dec_result, expected_result); + } + } } +} - pub fn check_rfind_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.rfind(pat); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - - let start = Instant::now(); - let (index, is_some) = self.sk.rfind(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); - - let dec_index = self.ck.decrypt_radix::(&index); - let dec_is_some = self.ck.decrypt_bool(&is_some); +#[test] +fn string_replacen_test_parameterized() { + string_replacen_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - let dec = dec_is_some.then_some(dec_index as usize); +#[allow(clippy::needless_pass_by_value)] +fn string_replacen_test

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::replacen); + string_replacen_test_impl(param, executor); +} - println!("\n\x1b[1mRfind:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); +pub(crate) fn string_replacen_test_impl(param: P, mut replacen_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor< + ( + &'a FheString, + GenericPatternRef<'a>, + &'a FheString, + &'a UIntArg, + ), + FheString, + >, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + replacen_executor.setup(&cks2, sks); + + // trivial + for str_pad in 0..2 { + for from_pad in 0..2 { + for to_pad in 0..2 { + for str in TEST_CASES_FIND { + for from in PATTERN_FIND { + for to in ["", " ", "a", "abc"] { + for n in 0..=2 { + for max in n..n + 2 { + let expected_result = str.replacen(from, to, n as usize); + + let enc_str = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_from = GenericPattern::Enc(FheString::new_trivial( + &cks, + from, + Some(from_pad), + )); + let clear_from = + GenericPattern::Clear(ClearString::new(from.to_string())); + + let enc_to = FheString::new_trivial(&cks, to, Some(to_pad)); + + let clear_n = UIntArg::Clear(n); + let enc_n = UIntArg::Enc(cks.trivial_encrypt_u16(n, Some(max))); + + for from in [enc_from, clear_from] { + for n in [&clear_n, &enc_n] { + let result = replacen_executor.execute(( + &enc_str, + from.as_ref(), + &enc_to, + n, + )); + + let dec_result = cks.decrypt_ascii(&result); + + assert_eq!(dec_result, expected_result); + } + } + } + } + } + } + } + } + } + } + // encrypted + { + let str = "ab"; + let str_pad = 1; + let from_pad = 1; + let to = "d"; + let to_pad = 1; + let n = 1; + let max = 2; - assert_eq!(dec, expected); + for from in ["a", "c"] { + let expected_result = str.replacen(from, to, n as usize); - let start = Instant::now(); - let (index, is_some) = self.sk.rfind(&enc_str, clear_pat.as_ref()); - let end = Instant::now(); + let enc_str = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_from = GenericPattern::Enc(FheString::new_trivial(&cks, from, Some(from_pad))); + let clear_from = GenericPattern::Clear(ClearString::new(from.to_string())); - let dec_index = self.ck.decrypt_radix::(&index); - let dec_is_some = self.ck.decrypt_bool(&is_some); + let enc_to = FheString::new_trivial(&cks, to, Some(to_pad)); - let dec = dec_is_some.then_some(dec_index as usize); + let clear_n = UIntArg::Clear(n); + let enc_n = UIntArg::Enc(cks.encrypt_u16(n, Some(max))); - println!("\n\x1b[1mRfind:\x1b[0m"); - result_message_clear_pat(str, pat, expected, dec, end.duration_since(start)); + for from in [enc_from, clear_from] { + for n in [&clear_n, &enc_n] { + let result = replacen_executor.execute((&enc_str, from.as_ref(), &enc_to, n)); - assert_eq!(dec, expected); - } - pub fn check_replace_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - to: &str, - to_pad: Option, - ) { - let expected = str.replace(pat, to); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - let enc_to = self.encrypt_string(to, to_pad); - - let start = Instant::now(); - let result = self.sk.replace(&enc_str, enc_pat.as_ref(), &enc_to); - let end = Instant::now(); - - let dec = self.ck.decrypt_ascii(&result); - - println!( - "\n\x1b[1mReplace:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mFrom: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTo: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - to, - expected, - dec, - end.duration_since(start), - ); - - assert_eq!(dec, expected); - - let start = Instant::now(); - let result = self.sk.replace(&enc_str, clear_pat.as_ref(), &enc_to); - let end = Instant::now(); - - let dec = self.ck.decrypt_ascii(&result); - - println!( - "\n\x1b[1mReplace:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mFrom (clear): \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTo: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - to, - expected, - dec, - end.duration_since(start), - ); - - assert_eq!(dec, expected); - } + let dec_result = cks.decrypt_ascii(&result); - pub fn check_replacen_fhe_string_vs_rust_str( - &self, - str: (&str, Option), - pat: (&str, Option), - to: (&str, Option), - n: u16, - max: u16, - ) { - let (str, str_pad) = (str.0, str.1); - let (pat, pat_pad) = (pat.0, pat.1); - let (to, to_pad) = (to.0, to.1); - - let expected = str.replacen(pat, to, n as usize); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - let clear_pat = GenericPattern::Clear(ClearString::new(pat.to_string())); - let enc_to = self.encrypt_string(to, to_pad); - - let clear_n = UIntArg::Clear(n); - let enc_n = UIntArg::Enc(self.encrypt_u16(n, Some(max))); - - let start = Instant::now(); - let result = self - .sk - .replacen(&enc_str, enc_pat.as_ref(), &enc_to, &clear_n); - let end = Instant::now(); - - let dec = self.ck.decrypt_ascii(&result); - - println!( - "\n\x1b[1mReplacen:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mFrom: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTo: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (clear): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - to, - n, - expected, - dec, - end.duration_since(start), - ); - assert_eq!(dec, expected); - - let start = Instant::now(); - let result = self - .sk - .replacen(&enc_str, clear_pat.as_ref(), &enc_to, &enc_n); - let end = Instant::now(); - - let dec = self.ck.decrypt_ascii(&result); - - println!( - "\n\x1b[1mReplacen:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mFrom (clear): \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTo: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (encrypted): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - to, - n, - expected, - dec, - end.duration_since(start), - ); - assert_eq!(dec, expected); + assert_eq!(dec_result, expected_result); + } + } + } } } diff --git a/tfhe/src/strings/test_functions/test_split.rs b/tfhe/src/strings/test_functions/test_split.rs index 683500c40d..9e7d945a8c 100644 --- a/tfhe/src/strings/test_functions/test_split.rs +++ b/tfhe/src/strings/test_functions/test_split.rs @@ -1,10 +1,15 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; -use crate::strings::ciphertext::{GenericPattern, UIntArg}; +use crate::shortint::PBSParameters; +use crate::strings::ciphertext::{ + ClearString, FheString, GenericPattern, GenericPatternRef, UIntArg, +}; use crate::strings::server_key::FheStringIterator; -use crate::strings::test::TestKind; -use crate::strings::test_functions::result_message_pat; -use crate::strings::TestKeys; -use std::time::Instant; +use std::iter::once; +use std::sync::Arc; const TEST_CASES_SPLIT: [(&str, &str); 21] = [ ("", ""), @@ -31,590 +36,361 @@ const TEST_CASES_SPLIT: [(&str, &str); 21] = [ ]; #[test] -fn test_split_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn string_split_once_test_parameterized() { + string_split_once_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} +#[allow(clippy::needless_pass_by_value)] +fn string_split_once_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str, &'a str) -> Option<(&'a str, &'a str)>, + fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> (FheString, FheString, BooleanBlock), + ); 2] = [ + ( + |lhs: &str, rhs: &str| lhs.split_once(rhs), + |a, b, c| ServerKey::split_once(a, b, c), + ), + ( + |lhs: &str, rhs: &str| lhs.rsplit_once(rhs), + |a, b, c| ServerKey::rsplit_once(a, b, c), + ), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_split_once_test_impl(param, executor, clear_op); + } +} + +pub(crate) fn string_split_once_test_impl( + param: P, + mut split_once_executor: T, + clear_function: for<'a> fn(&'a str, &'a str) -> Option<(&'a str, &'a str)>, +) where + P: Into, + T: for<'a> FunctionExecutor< + (&'a FheString, GenericPatternRef<'a>), + (FheString, FheString, BooleanBlock), + >, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + split_once_executor.setup(&cks2, sks); + + // trivial for str_pad in 0..2 { for pat_pad in 0..2 { for (str, pat) in TEST_CASES_SPLIT { - keys.check_split_once_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - keys.check_rsplit_once_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - keys.check_split_fhe_string_vs_rust_str(str, Some(str_pad), pat, Some(pat_pad)); - keys.check_rsplit_fhe_string_vs_rust_str(str, Some(str_pad), pat, Some(pat_pad)); + let expected = clear_function(str, pat); - for n in 0..3 { - for max in n..n + 2 { - keys.check_splitn_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - n, - max, - ); - keys.check_rsplitn_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - n, - max, - ); - } - } + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new_trivial(&cks, pat, Some(pat_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(pat.to_string())); - keys.check_split_terminator_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - keys.check_rsplit_terminator_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - keys.check_split_inclusive_fhe_string_vs_rust_str( - str, - Some(str_pad), - pat, - Some(pat_pad), - ); - } - } - } -} + for rhs in [enc_rhs, clear_rhs] { + let (split1, split2, is_some) = + split_once_executor.execute((&enc_lhs, rhs.as_ref())); -#[test] -fn test_split() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_split_once_fhe_string_vs_rust_str("", Some(1), "", Some(1)); - keys.check_rsplit_once_fhe_string_vs_rust_str("", Some(1), "", Some(1)); - keys.check_split_fhe_string_vs_rust_str("", Some(1), "", Some(1)); - keys.check_rsplit_fhe_string_vs_rust_str("", Some(1), "", Some(1)); - - keys.check_splitn_fhe_string_vs_rust_str("", Some(1), "", Some(1), 1, 2); - keys.check_rsplitn_fhe_string_vs_rust_str("", Some(1), "", Some(1), 1, 2); - - keys.check_split_terminator_fhe_string_vs_rust_str("", Some(1), "", Some(1)); - keys.check_rsplit_terminator_fhe_string_vs_rust_str("", Some(1), "", Some(1)); - keys.check_split_inclusive_fhe_string_vs_rust_str("", Some(1), "", Some(1)); -} + let dec_split1 = cks.decrypt_ascii(&split1); -impl TestKeys { - pub fn check_split_once_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.split_once(pat); + let dec_split2 = cks.decrypt_ascii(&split2); - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); + let dec_is_some = cks.decrypt_bool(&is_some); - let start = Instant::now(); - let (lhs, rhs, is_some) = self.sk.split_once(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); + let dec = dec_is_some.then_some((dec_split1.as_str(), dec_split2.as_str())); - let dec_lhs = self.ck.decrypt_ascii(&lhs); - let dec_rhs = self.ck.decrypt_ascii(&rhs); - let dec_is_some = self.ck.decrypt_bool(&is_some); + assert_eq!(expected, dec) + } + } + } + } + // encrypted + { + let str = "aba"; + let str_pad = 1; + let rhs_pad = 1; - let dec = dec_is_some.then_some((dec_lhs.as_str(), dec_rhs.as_str())); + for rhs in ["a", "c"] { + let expected = clear_function(str, rhs); - println!("\n\x1b[1mSplit_once:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - assert_eq!(dec, expected); - } + for rhs in [enc_rhs, clear_rhs] { + let (split1, split2, is_some) = + split_once_executor.execute((&enc_lhs, rhs.as_ref())); - pub fn check_rsplit_once_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let expected = str.rsplit_once(pat); + let dec_split1 = cks.decrypt_ascii(&split1); - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); + let dec_split2 = cks.decrypt_ascii(&split2); - let start = Instant::now(); - let (lhs, rhs, is_some) = self.sk.rsplit_once(&enc_str, enc_pat.as_ref()); - let end = Instant::now(); + let dec_is_some = cks.decrypt_bool(&is_some); - let dec_lhs = self.ck.decrypt_ascii(&lhs); - let dec_rhs = self.ck.decrypt_ascii(&rhs); - let dec_is_some = self.ck.decrypt_bool(&is_some); + let dec = dec_is_some.then_some((dec_split1.as_str(), dec_split2.as_str())); - let dec = dec_is_some.then_some((dec_lhs.as_str(), dec_rhs.as_str())); + assert_eq!(expected, dec) + } + } + } +} - println!("\n\x1b[1mRsplit_once:\x1b[0m"); - result_message_pat(str, pat, expected, dec, end.duration_since(start)); +#[test] +fn string_split_test_parameterized() { + string_split_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - assert_eq!(dec, expected); +#[allow(clippy::needless_pass_by_value)] +fn string_split_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str, &'a str) -> Box + 'a>, + fn(&ServerKey, &FheString, GenericPatternRef<'_>) -> Box, + ); 5] = [ + ( + |lhs: &str, rhs: &str| Box::new(lhs.split(rhs)), + |a, b, c| Box::new(ServerKey::split(a, b, c)), + ), + ( + |lhs: &str, rhs: &str| Box::new(lhs.rsplit(rhs)), + |a, b, c| Box::new(ServerKey::rsplit(a, b, c)), + ), + ( + |lhs: &str, rhs: &str| Box::new(lhs.split_terminator(rhs)), + |a, b, c| Box::new(ServerKey::split_terminator(a, b, c)), + ), + ( + |lhs: &str, rhs: &str| Box::new(lhs.rsplit_terminator(rhs)), + |a, b, c| Box::new(ServerKey::rsplit_terminator(a, b, c)), + ), + ( + |lhs: &str, rhs: &str| Box::new(lhs.split_inclusive(rhs)), + |a, b, c| Box::new(ServerKey::split_inclusive(a, b, c)), + ), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_split_test_impl(param, executor, clear_op); } +} - pub fn check_split_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let mut expected: Vec<_> = str.split(pat).map(Some).collect(); - expected.push(None); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self.sk.split(&enc_str, enc_pat.as_ref()); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) - } - let end = Instant::now(); +pub(crate) fn string_split_test_impl( + param: P, + mut split_executor: T, + clear_function: for<'a> fn(&'a str, &'a str) -> Box + 'a>, +) where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>), Box>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + split_executor.setup(&cks2, sks.clone()); + + // trivial + for str_pad in 0..2 { + for pat_pad in 0..2 { + for (str, pat) in TEST_CASES_SPLIT { + let expected: Vec<_> = clear_function(str, pat) + .map(Some) + .chain(once(None)) + .collect(); - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new_trivial(&cks, pat, Some(pat_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(pat.to_string())); - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); + for rhs in [enc_rhs, clear_rhs] { + let mut iterator = split_executor.execute((&enc_lhs, rhs.as_ref())); - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); + for expected in &expected { + let (split, is_some) = iterator.next(&sks); - println!("\n\x1b[1mSplit:\x1b[0m"); - result_message_pat(str, pat, &expected, &dec_as_str, end.duration_since(start)); + let dec_split = cks.decrypt_ascii(&split); + let dec_is_some = cks.decrypt_bool(&is_some); - assert_eq!(dec_as_str, expected); - } + let dec = dec_is_some.then_some(dec_split); - pub fn check_rsplit_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let mut expected: Vec<_> = str.rsplit(pat).map(Some).collect(); - expected.push(None); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self.sk.rsplit(&enc_str, enc_pat.as_ref()); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) + assert_eq!(expected, &dec.as_deref()) + } + } + } } - let end = Instant::now(); + } + // encrypted + { + let str = "aba"; + let str_pad = 1; + let rhs_pad = 1; - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); + for rhs in ["a", "c"] { + let expected: Vec<_> = clear_function(str, rhs) + .map(Some) + .chain(once(None)) + .collect(); - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); + for rhs in [enc_rhs, clear_rhs] { + let mut iterator = split_executor.execute((&enc_lhs, rhs.as_ref())); - println!("\n\x1b[1mRsplit:\x1b[0m"); - result_message_pat(str, pat, &expected, &dec_as_str, end.duration_since(start)); + for expected in &expected { + let (split, is_some) = iterator.next(&sks); - assert_eq!(dec_as_str, expected); - } + let dec_split = cks.decrypt_ascii(&split); + let dec_is_some = cks.decrypt_bool(&is_some); - pub fn check_split_terminator_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let mut expected: Vec<_> = str.split_terminator(pat).map(Some).collect(); - expected.push(None); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self.sk.split_terminator(&enc_str, enc_pat.as_ref()); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) - } - let end = Instant::now(); + let dec = dec_is_some.then_some(dec_split); - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); + assert_eq!(expected, &dec.as_deref()); + } + } + } + } +} - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); +#[test] +fn string_splitn_test_parameterized() { + string_splitn_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); +#[allow(clippy::needless_pass_by_value)] +fn string_splitn_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str, &'a str, u16) -> Box + 'a>, + fn(&ServerKey, &FheString, GenericPatternRef<'_>, UIntArg) -> Box, + ); 2] = [ + ( + |lhs: &str, rhs: &str, n: u16| Box::new(lhs.splitn(n as usize, rhs)), + |a, b, c, d| Box::new(ServerKey::splitn(a, b, c, d)), + ), + ( + |lhs: &str, rhs: &str, n: u16| Box::new(lhs.rsplitn(n as usize, rhs)), + |a, b, c, d| Box::new(ServerKey::rsplitn(a, b, c, d)), + ), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_splitn_test_impl(param, executor, clear_op); + } +} - println!("\n\x1b[1mSplit_terminator:\x1b[0m"); - result_message_pat(str, pat, &expected, &dec_as_str, end.duration_since(start)); +pub(crate) fn string_splitn_test_impl( + param: P, + mut splitn_executor: T, + clear_function: for<'a> fn(&'a str, &'a str, u16) -> Box + 'a>, +) where + P: Into, + T: for<'a> FunctionExecutor< + (&'a FheString, GenericPatternRef<'a>, UIntArg), + Box, + >, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + splitn_executor.setup(&cks2, sks.clone()); + + // trivial + for str_pad in 0..2 { + for pat_pad in 0..2 { + for (str, pat) in TEST_CASES_SPLIT { + for n in 0..3 { + for max in n..n + 2 { + let expected: Vec<_> = clear_function(str, pat, n) + .map(Some) + .chain(once(None)) + .collect(); - assert_eq!(dec_as_str, expected); - } + let enc_lhs = FheString::new_trivial(&cks, str, Some(str_pad)); + let enc_rhs = + GenericPattern::Enc(FheString::new_trivial(&cks, pat, Some(pat_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(pat.to_string())); - pub fn check_rsplit_terminator_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let mut expected: Vec<_> = str.rsplit_terminator(pat).map(Some).collect(); - expected.push(None); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self.sk.rsplit_terminator(&enc_str, enc_pat.as_ref()); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) - } - let end = Instant::now(); + let clear_n = UIntArg::Clear(n); + let enc_n = UIntArg::Enc(cks.trivial_encrypt_u16(n, Some(max))); - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); + for rhs in [enc_rhs, clear_rhs] { + for n in [clear_n.clone(), enc_n.clone()] { + let mut iterator = + splitn_executor.execute((&enc_lhs, rhs.as_ref(), n)); - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); + for expected in &expected { + let (split, is_some) = iterator.next(&sks); - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); + let dec_split = cks.decrypt_ascii(&split); + let dec_is_some = cks.decrypt_bool(&is_some); - println!("\n\x1b[1mRsplit_terminator:\x1b[0m"); - result_message_pat(str, pat, &expected, &dec_as_str, end.duration_since(start)); + let dec = dec_is_some.then_some(dec_split); - assert_eq!(dec_as_str, expected); + assert_eq!(expected, &dec.as_deref()) + } + } + } + } + } + } + } } + // encrypted + { + let str = "aba"; + let str_pad = 1; + let rhs_pad = 1; + let n = 1; + let max = 2; - pub fn check_split_inclusive_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - ) { - let mut expected: Vec<_> = str.split_inclusive(pat).map(Some).collect(); - expected.push(None); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self.sk.split_inclusive(&enc_str, enc_pat.as_ref()); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) - } - let end = Instant::now(); + for rhs in ["a", "c"] { + let expected: Vec<_> = clear_function(str, rhs, n) + .map(Some) + .chain(once(None)) + .collect(); - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); + let enc_lhs = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); + let enc_n = UIntArg::Enc(cks.encrypt_u16(n, Some(max))); - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); + for rhs in [enc_rhs, clear_rhs] { + let mut iterator = splitn_executor.execute((&enc_lhs, rhs.as_ref(), enc_n.clone())); - println!("\n\x1b[1mSplit_inclusive:\x1b[0m"); - result_message_pat(str, pat, &expected, &dec_as_str, end.duration_since(start)); + for expected in &expected { + let (split, is_some) = iterator.next(&sks); - assert_eq!(dec_as_str, expected); - } + let dec_split = cks.decrypt_ascii(&split); + let dec_is_some = cks.decrypt_bool(&is_some); - pub fn check_splitn_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - n: u16, - max: u16, - ) { - let mut expected: Vec<_> = str.splitn(n as usize, pat).map(Some).collect(); - expected.push(None); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self - .sk - .splitn(&enc_str, enc_pat.as_ref(), UIntArg::Clear(n)); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) - } - let end = Instant::now(); - - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); - - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); - - println!( - "\n\x1b[1mSplitn:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mPattern: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (clear): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - n, - expected, - dec, - end.duration_since(start), - ); - - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); - - assert_eq!(dec_as_str, expected); - - let enc_n = self.encrypt_u16(n, Some(max)); - results.clear(); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self - .sk - .splitn(&enc_str, enc_pat.as_ref(), UIntArg::Enc(enc_n)); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) - } - let end = Instant::now(); - - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); - - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); - - println!( - "\n\x1b[1mSplitn:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mPattern: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (encrypted): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - n, - expected, - dec, - end.duration_since(start), - ); - - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); - - assert_eq!(dec_as_str, expected); - } + let dec = dec_is_some.then_some(dec_split); - pub fn check_rsplitn_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - pat: &str, - pat_pad: Option, - n: u16, - max: u16, - ) { - let mut expected: Vec<_> = str.rsplitn(n as usize, pat).map(Some).collect(); - expected.push(None); - - let enc_str = self.encrypt_string(str, str_pad); - let enc_pat = GenericPattern::Enc(self.encrypt_string(pat, pat_pad)); - - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self - .sk - .rsplitn(&enc_str, enc_pat.as_ref(), UIntArg::Clear(n)); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) - } - let end = Instant::now(); - - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); - - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); - - println!( - "\n\x1b[1mRsplitn:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mPattern: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (clear): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - n, - expected, - dec, - end.duration_since(start), - ); - - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); - - assert_eq!(dec_as_str, expected); - - let enc_n = self.encrypt_u16(n, Some(max)); - results.clear(); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = self - .sk - .rsplitn(&enc_str, enc_pat.as_ref(), UIntArg::Enc(enc_n)); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) + assert_eq!(expected, &dec.as_deref()); + } + } } - let end = Instant::now(); - - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); - - dec_is_some.then_some(self.ck.decrypt_ascii(result)) - }) - .collect(); - - println!( - "\n\x1b[1mRsplitn:\x1b[0m\n\ - \x1b[1;32m--------------------------------\x1b[0m\n\ - \x1b[1;32;1mString: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mPattern: \x1b[0m\x1b[0;33m{:?}\x1b[0m\n\ - \x1b[1;32;1mTimes (encrypted): \x1b[0m{}\n\ - \x1b[1;32;1mClear API Result: \x1b[0m{:?}\n\ - \x1b[1;32;1mT-fhe API Result: \x1b[0m{:?}\n\ - \x1b[1;34mExecution Time: \x1b[0m{:?}\n\ - \x1b[1;32m--------------------------------\x1b[0m", - str, - pat, - n, - expected, - dec, - end.duration_since(start), - ); - - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); - - assert_eq!(dec_as_str, expected); } } diff --git a/tfhe/src/strings/test_functions/test_up_low_case.rs b/tfhe/src/strings/test_functions/test_up_low_case.rs index 6ea71cf5bc..bf886a30b4 100644 --- a/tfhe/src/strings/test_functions/test_up_low_case.rs +++ b/tfhe/src/strings/test_functions/test_up_low_case.rs @@ -1,11 +1,11 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; -use crate::strings::ciphertext::{ClearString, GenericPattern}; -use crate::strings::test::TestKind; -use crate::strings::test_functions::{ - result_message, result_message_clear_rhs, result_message_rhs, -}; -use crate::strings::TestKeys; -use std::time::Instant; +use crate::shortint::PBSParameters; +use crate::strings::ciphertext::{ClearString, FheString, GenericPattern, GenericPatternRef}; +use std::sync::Arc; const UP_LOW_CASE: [&str; 21] = [ "", // @@ -14,140 +14,143 @@ const UP_LOW_CASE: [&str; 21] = [ "[", "\\", "]", "^", "_", "`", // chars between 'Z' and 'a' "a", "z", // "{", // just after 'z' - "a ", " a", "A", "A ", " A", "aA", " aA", "aA ", + "a ", " a", "A ", " A", "aA", " aA", "aA ", "a A", ]; #[test] -fn test_to_lower_upper_case_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn string_to_lower_upper_case_test_parameterized() { + string_to_lower_upper_case_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} +#[allow(clippy::needless_pass_by_value)] +fn string_to_lower_upper_case_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str) -> String, + fn(&ServerKey, &FheString) -> FheString, + ); 2] = [ + (|lhs| lhs.to_lowercase(), ServerKey::to_lowercase), + (|lhs| lhs.to_uppercase(), ServerKey::to_uppercase), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_to_lower_upper_case_test_impl(param, executor, clear_op); + } +} + +pub(crate) fn string_to_lower_upper_case_test_impl( + param: P, + mut to_lower_upper_case_executor: T, + clear_function: for<'a> fn(&'a str) -> String, +) where + P: Into, + T: for<'a> FunctionExecutor<&'a FheString, FheString>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + to_lower_upper_case_executor.setup(&cks2, sks); + + // trivial for str_pad in 0..2 { for str in UP_LOW_CASE { - keys.check_to_lowercase_fhe_string_vs_rust_str(str, Some(str_pad)); - keys.check_to_uppercase_fhe_string_vs_rust_str(str, Some(str_pad)); + let expected_result = clear_function(str); + + let enc_str = FheString::new(&cks, str, Some(str_pad)); + + let result = to_lower_upper_case_executor.execute(&enc_str); + + assert_eq!(expected_result, cks.decrypt_ascii(&result)); } } -} - -#[test] -fn test_to_lower_upper_case() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); + // encrypted + { + let str_pad = 1; - keys.check_to_lowercase_fhe_string_vs_rust_str("ab", Some(1)); - keys.check_to_lowercase_fhe_string_vs_rust_str("AB", Some(1)); + for str in ["ab", "AB"] { + let expected_result = clear_function(str); - keys.check_to_uppercase_fhe_string_vs_rust_str("AB", Some(1)); - keys.check_to_uppercase_fhe_string_vs_rust_str("ab", Some(1)); -} + let enc_str = FheString::new(&cks, str, Some(str_pad)); -#[test] -fn test_eq_ignore_case_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); + let result = to_lower_upper_case_executor.execute(&enc_str); - for str_pad in 0..2 { - for rhs_pad in 0..2 { - for str in UP_LOW_CASE { - for rhs in UP_LOW_CASE { - keys.check_eq_ignore_case_fhe_string_vs_rust_str( - str, - Some(str_pad), - rhs, - Some(rhs_pad), - ); - } - } + assert_eq!(expected_result, cks.decrypt_ascii(&result)); } } } #[test] -fn test_eq_ignore_case() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_eq_ignore_case_fhe_string_vs_rust_str("aB", Some(1), "Ab", Some(1)); - keys.check_eq_ignore_case_fhe_string_vs_rust_str("aB", Some(1), "Ac", Some(1)); +fn string_eq_ignore_case_test_parameterized() { + string_eq_ignore_case_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); } -impl TestKeys { - pub fn check_eq_ignore_case_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - rhs: &str, - rhs_pad: Option, - ) { - let expected = str.eq_ignore_ascii_case(rhs); - - let enc_lhs = self.encrypt_string(str, str_pad); - let enc_rhs = GenericPattern::Enc(self.encrypt_string(rhs, rhs_pad)); - let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - - let start = Instant::now(); - let result = self.sk.eq_ignore_case(&enc_lhs, enc_rhs.as_ref()); - let end = Instant::now(); - - let dec = self.ck.decrypt_bool(&result); - - println!("\n\x1b[1mEq_ignore_case:\x1b[0m"); - result_message_rhs(str, rhs, expected, dec, end.duration_since(start)); - - assert_eq!(dec, expected); - - let start = Instant::now(); - let result = self.sk.eq_ignore_case(&enc_lhs, clear_rhs.as_ref()); - let end = Instant::now(); - - let dec = self.ck.decrypt_bool(&result); - - println!("\n\x1b[1mEq_ignore_case:\x1b[0m"); - result_message_clear_rhs(str, rhs, expected, dec, end.duration_since(start)); +#[allow(clippy::needless_pass_by_value)] +fn string_eq_ignore_case_test

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::eq_ignore_case); + string_eq_ignore_case_test_impl(param, executor); +} - assert_eq!(dec, expected); - } +pub(crate) fn string_eq_ignore_case_test_impl(param: P, mut eq_ignore_case_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<(&'a FheString, GenericPatternRef<'a>), BooleanBlock>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); - pub fn check_to_lowercase_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let expected = str.to_lowercase(); + eq_ignore_case_executor.setup(&cks2, sks); - let enc_str = self.encrypt_string(str, str_pad); + // trivial + for str_pad in 0..2 { + for rhs_pad in 0..2 { + for str in UP_LOW_CASE { + for rhs in UP_LOW_CASE { + let expected_result = str.eq_ignore_ascii_case(rhs); - let start = Instant::now(); - let result = self.sk.to_lowercase(&enc_str); - let end = Instant::now(); + let enc_str = FheString::new(&cks, str, Some(str_pad)); - let dec = self.ck.decrypt_ascii(&result); + let enc_rhs = + GenericPattern::Enc(FheString::new_trivial(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - println!("\n\x1b[1mTo_lowercase:\x1b[0m"); - result_message(str, &expected, &dec, end.duration_since(start)); + for rhs in [enc_rhs, clear_rhs] { + let result = eq_ignore_case_executor.execute((&enc_str, rhs.as_ref())); - assert_eq!(dec, expected); + assert_eq!(expected_result, cks.decrypt_bool(&result)); + } + } + } + } } + // encrypted + { + let str = "aB"; + let str_pad = 1; + let rhs_pad = 1; - pub fn check_to_uppercase_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let expected = str.to_uppercase(); + for rhs in ["Ab", "Ac"] { + let expected_result = str.eq_ignore_ascii_case(rhs); - let enc_str = self.encrypt_string(str, str_pad); + let enc_str = FheString::new(&cks, str, Some(str_pad)); + let enc_rhs = GenericPattern::Enc(FheString::new_trivial(&cks, rhs, Some(rhs_pad))); + let clear_rhs = GenericPattern::Clear(ClearString::new(rhs.to_string())); - let start = Instant::now(); - let result = self.sk.to_uppercase(&enc_str); - let end = Instant::now(); + for rhs in [enc_rhs, clear_rhs] { + let result = eq_ignore_case_executor.execute((&enc_str, rhs.as_ref())); - let dec = self.ck.decrypt_ascii(&result); - - println!("\n\x1b[1mTo_upperrcase:\x1b[0m"); - result_message(str, &expected, &dec, end.duration_since(start)); - - assert_eq!(dec, expected); + assert_eq!(expected_result, cks.decrypt_bool(&result)); + } + } } } diff --git a/tfhe/src/strings/test_functions/test_whitespace.rs b/tfhe/src/strings/test_functions/test_whitespace.rs index 370e9e2079..8ea470b7e4 100644 --- a/tfhe/src/strings/test_functions/test_whitespace.rs +++ b/tfhe/src/strings/test_functions/test_whitespace.rs @@ -1,18 +1,59 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey}; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; +use crate::shortint::PBSParameters; +use crate::strings::ciphertext::FheString; use crate::strings::server_key::{split_ascii_whitespace, FheStringIterator}; -use crate::strings::test::TestKind; -use crate::strings::test_functions::result_message; -use crate::strings::TestKeys; -use std::time::Instant; +use std::iter::once; +use std::sync::Arc; + const WHITESPACES: [&str; 5] = [" ", "\n", "\t", "\r", "\u{000C}"]; #[test] -fn test_trim_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +fn string_trim_test_parameterized() { + string_trim_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} + +#[allow(clippy::needless_pass_by_value)] +fn string_trim_test

(param: P) +where + P: Into, +{ + #[allow(clippy::type_complexity)] + let ops: [( + for<'a> fn(&'a str) -> &'a str, + fn(&ServerKey, &FheString) -> FheString, + ); 3] = [ + (|lhs| lhs.trim(), ServerKey::trim), + (|lhs| lhs.trim_start(), ServerKey::trim_start), + (|lhs| lhs.trim_end(), ServerKey::trim_end), + ]; + + let param = param.into(); + + for (clear_op, encrypted_op) in ops { + let executor = CpuFunctionExecutor::new(&encrypted_op); + string_trim_test_impl(param, executor, clear_op); + } +} +pub(crate) fn string_trim_test_impl( + param: P, + mut trim_executor: T, + clear_function: for<'a> fn(&'a str) -> &'a str, +) where + P: Into, + T: for<'a> FunctionExecutor<&'a FheString, FheString>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + trim_executor.setup(&cks2, sks); + + // trivial for str_pad in 0..2 { for ws in WHITESPACES { for core in ["", "a", "a a"] { @@ -23,39 +64,63 @@ fn test_trim_trivial() { format!("{core}{ws}"), format!("{ws}{core}{ws}"), ] { - keys.check_trim_fhe_string_vs_rust_str(&str, Some(str_pad)); - keys.check_trim_start_fhe_string_vs_rust_str(&str, Some(str_pad)); - keys.check_trim_end_fhe_string_vs_rust_str(&str, Some(str_pad)); + let expected_result = clear_function(&str); + + let enc_str = FheString::new(&cks, &str, Some(str_pad)); + + let result = trim_executor.execute(&enc_str); + + assert_eq!(expected_result, &cks.decrypt_ascii(&result)); } } } } + // encrypted + { + let str_pad = 1; + + for str in [" a ", "abc"] { + let expected_result = clear_function(str); + + let enc_str = FheString::new(&cks, str, Some(str_pad)); + + let result = trim_executor.execute(&enc_str); + + assert_eq!(expected_result, &cks.decrypt_ascii(&result)); + } + } } #[test] -fn test_trim() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); +fn string_split_whitespace_test_parameterized() { + string_split_whitespace_test(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); +} - keys.check_trim_fhe_string_vs_rust_str(" a ", Some(1)); - keys.check_trim_fhe_string_vs_rust_str("abc", Some(1)); +#[allow(clippy::needless_pass_by_value)] +fn string_split_whitespace_test

(param: P) +where + P: Into, +{ + let fhe_func: fn(&ServerKey, &FheString) -> Box = + |_sk, str| Box::new(split_ascii_whitespace(str)); - keys.check_trim_start_fhe_string_vs_rust_str(" a ", Some(1)); - keys.check_trim_start_fhe_string_vs_rust_str("abc", Some(1)); + let executor = CpuFunctionExecutor::new(&fhe_func); - keys.check_trim_end_fhe_string_vs_rust_str(" a ", Some(1)); - keys.check_trim_end_fhe_string_vs_rust_str("abc", Some(1)); + string_split_whitespace_test_impl(param, executor); } -#[test] -fn test_split_ascii_whitespace_trivial() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Trivial, - ); +pub(crate) fn string_split_whitespace_test_impl(param: P, mut split_whitespace_executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a FheString, Box>, +{ + let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let sks = Arc::new(sks); + let cks2 = RadixClientKey::from((cks.clone(), 0)); + + split_whitespace_executor.setup(&cks2, sks.clone()); + // trivial for str_pad in 0..2 { for ws in WHITESPACES { #[allow(clippy::useless_format)] @@ -74,118 +139,55 @@ fn test_split_ascii_whitespace_trivial() { format!("{ws}a{ws}a"), format!("a{ws}a{ws}a"), ] { - keys.check_split_ascii_whitespace_fhe_string_vs_rust_str(&str, Some(str_pad)); - } - } - } -} + let expected: Vec<_> = str + .split_ascii_whitespace() + .map(Some) + .chain(once(None)) + .collect(); -#[test] -fn test_split_ascii_whitespace() { - let keys = TestKeys::new( - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, - TestKind::Encrypted, - ); - - keys.check_split_ascii_whitespace_fhe_string_vs_rust_str("a b", Some(1)); - keys.check_split_ascii_whitespace_fhe_string_vs_rust_str("abc", Some(1)); -} - -impl TestKeys { - pub fn check_trim_end_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let expected = str.trim_end(); - - let enc_str = self.encrypt_string(str, str_pad); + let enc_str = FheString::new(&cks, &str, Some(str_pad)); - let start = Instant::now(); - let result = self.sk.trim_end(&enc_str); - let end = Instant::now(); + let mut iterator = split_whitespace_executor.execute(&enc_str); - let dec = self.ck.decrypt_ascii(&result); - - println!("\n\x1b[1mTrim_end:\x1b[0m"); - result_message(str, expected, &dec, end.duration_since(start)); - - assert_eq!(dec, expected); - } + for expected in &expected { + let (split, is_some) = iterator.next(&sks); - pub fn check_trim_start_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let expected = str.trim_start(); + let dec_split = cks.decrypt_ascii(&split); + let dec_is_some = cks.decrypt_bool(&is_some); - let enc_str = self.encrypt_string(str, str_pad); + let dec = dec_is_some.then_some(dec_split); - let start = Instant::now(); - let result = self.sk.trim_start(&enc_str); - let end = Instant::now(); - - let dec = self.ck.decrypt_ascii(&result); - - println!("\n\x1b[1mTrim_start:\x1b[0m"); - result_message(str, expected, &dec, end.duration_since(start)); - - assert_eq!(dec, expected); + assert_eq!(expected, &dec.as_deref()) + } + } + } } - pub fn check_trim_fhe_string_vs_rust_str(&self, str: &str, str_pad: Option) { - let expected = str.trim(); - - let enc_str = self.encrypt_string(str, str_pad); + // encrypted + { + let str_pad = 1; - let start = Instant::now(); - let result = self.sk.trim(&enc_str); - let end = Instant::now(); + for str in ["a b", "abc"] { + let expected: Vec<_> = str + .split_ascii_whitespace() + .map(Some) + .chain(once(None)) + .collect(); - let dec = self.ck.decrypt_ascii(&result); + let enc_str = FheString::new(&cks, str, Some(str_pad)); - println!("\n\x1b[1mTrim:\x1b[0m"); - result_message(str, expected, &dec, end.duration_since(start)); + let mut iterator = split_whitespace_executor.execute(&enc_str); - assert_eq!(dec, expected); - } + for expected in &expected { + let (split, is_some) = iterator.next(&sks); - pub fn check_split_ascii_whitespace_fhe_string_vs_rust_str( - &self, - str: &str, - str_pad: Option, - ) { - let mut expected: Vec<_> = str.split_ascii_whitespace().map(Some).collect(); - expected.push(None); + let dec_split = cks.decrypt_ascii(&split); + let dec_is_some = cks.decrypt_bool(&is_some); - let enc_str = self.encrypt_string(str, str_pad); + let dec = dec_is_some.then_some(dec_split); - let mut results = Vec::with_capacity(expected.len()); - - // Call next enough times - let start = Instant::now(); - let mut split_iter = split_ascii_whitespace(&enc_str); - for _ in 0..expected.len() { - results.push(split_iter.next(&self.sk)) + assert_eq!(expected, &dec.as_deref()) + } } - let end = Instant::now(); - - // Collect the decrypted results properly - let dec: Vec<_> = results - .iter() - .map(|(result, is_some)| { - let dec_is_some = self.ck.decrypt_bool(is_some); - let dec_result = self.ck.decrypt_ascii(result); - if !dec_is_some { - // When it's None, the FheString returned is always empty - assert_eq!(dec_result, ""); - } - - dec_is_some.then_some(dec_result) - }) - .collect(); - - let dec_as_str: Vec<_> = dec - .iter() - .map(|option| option.as_ref().map(|s| s.as_str())) - .collect(); - - println!("\n\x1b[1mSplit_ascii_whitespace:\x1b[0m"); - result_message(str, &expected, &dec_as_str, end.duration_since(start)); - - assert_eq!(dec_as_str, expected); } }