diff --git a/Cargo.toml b/Cargo.toml index 930e05c..a324572 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,10 +3,8 @@ name = "fhe_string" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] -tfhe = { version = "0.4.1", features = [ "boolean", "shortint", "integer" ] } +tfhe = { version = "0.5.0", features = [ "boolean", "shortint", "integer" ] } serde = { version = "1.0", features = ["derive"] } rayon = "1.8" env_logger = "0.10.0" diff --git a/README.md b/README.md index ab2078c..b66234d 100644 --- a/README.md +++ b/README.md @@ -75,4 +75,4 @@ This project has been developed for the [Zama Bounty Program](https://github.com ## License -See [LICENSE] file. \ No newline at end of file +See [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/examples/cmd/main.rs b/examples/cmd/main.rs index 458aeea..18f123d 100644 --- a/examples/cmd/main.rs +++ b/examples/cmd/main.rs @@ -2,7 +2,10 @@ use std::{any::Any, fmt::Debug, ops::Add, time::Instant}; use clap::Parser; use fhe_string::{generate_keys_with_params, ClientKey, FheOption, FheString, ServerKey}; -use tfhe::{integer::RadixCiphertext, shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS}; +use tfhe::{ + integer::{BooleanBlock, RadixCiphertext}, + shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS, +}; /// Run string operations in the encrypted domain. #[derive(Parser, Debug)] @@ -575,32 +578,22 @@ struct TestCase { fhe: fn(input: &TestCaseInput) -> Box, } -fn decrypt_bool(k: &ClientKey, b: &RadixCiphertext) -> bool { - let x = k.decrypt::(b); - int_to_bool(x) +fn decrypt_bool(k: &ClientKey, b: &BooleanBlock) -> bool { + k.decrypt_bool(b) } fn decrypt_option_string_pair( k: &ClientKey, opt: &FheOption<(FheString, FheString)>, ) -> Option<(String, String)> { - let is_some = k.decrypt::(&opt.is_some); + let is_some = k.decrypt_bool(&opt.is_some); match is_some { - 0 => None, - 1 => { + false => None, + true => { let val0 = opt.val.0.decrypt(k); let val1 = opt.val.1.decrypt(k); Some((val0, val1)) } - _ => panic!("expected 0 or 1, got {}", is_some), - } -} - -fn int_to_bool(x: u64) -> bool { - match x { - 0 => false, - 1 => true, - _ => panic!("expected 0 or 1, got {}", x), } } diff --git a/src/ciphertext/compare.rs b/src/ciphertext/compare.rs index db1840f..8fc06a3 100644 --- a/src/ciphertext/compare.rs +++ b/src/ciphertext/compare.rs @@ -3,26 +3,23 @@ use std::cmp; use rayon::{join, prelude::*}; -use tfhe::integer::{IntegerCiphertext, RadixCiphertext}; +use tfhe::integer::{BooleanBlock, IntegerCiphertext, RadixCiphertext}; use crate::server_key::ServerKey; -use super::{ - logic::{binary_and, binary_and_vec, binary_not, binary_or}, - FheString, -}; +use super::{logic::all, FheString}; impl FheString { /// Returns whether `self` is empty. The result is an encryption of 1 if /// this is the case and an encryption of 0 otherwise. - pub fn is_empty(&self, k: &ServerKey) -> RadixCiphertext { + pub fn is_empty(&self, k: &ServerKey) -> BooleanBlock { let term = k.create_value(Self::TERMINATOR); k.k.eq_parallelized(&self.0[0].0, &term) } /// Returns `self == s`. The result is an encryption of 1 if this is the /// case and an encryption of 0 otherwise. - pub fn eq(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn eq(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { // Compare overlapping part. let l = cmp::min(self.max_len(), s.max_len()); let a = self.substr_clear(k, 0, l); @@ -33,41 +30,38 @@ impl FheString { // Convert strings to radix integers and rely on optimized comparison. let radix_a = a.to_long_radix(); let radix_b = b.to_long_radix(); - let eq = k.k.eq_parallelized(&radix_a, &radix_b); - - // Trim exceeding radix blocks to ensure compatibility. - k.k.trim_radix_blocks_msb(&eq, eq.blocks().len() - k.num_blocks) + k.k.eq_parallelized(&radix_a, &radix_b) }, || { // Ensure that overhang is empty. match self.max_len().cmp(&s.max_len()) { cmp::Ordering::Greater => self.substr_clear(k, l, self.max_len()).is_empty(k), cmp::Ordering::Less => s.substr_clear(k, l, s.max_len()).is_empty(k), - cmp::Ordering::Equal => k.create_one(), + cmp::Ordering::Equal => k.k.create_trivial_boolean_block(true), } }, ); - binary_and(k, &overlap_eq, &overhang_empty) + k.k.boolean_bitand(&overlap_eq, &overhang_empty) } /// Returns `self != s`. The result is an encryption of 1 if this is the /// case and an encryption of 0 otherwise. - pub fn ne(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn ne(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { let eq = self.eq(k, s); - binary_not(k, &eq) + k.k.boolean_bitnot(&eq) } /// Returns `self <= s`. The result is an encryption of 1 if this is the /// case and an encryption of 0 otherwise. - pub fn le(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn le(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { let s_lt_self = s.lt(k, self); - binary_not(k, &s_lt_self) + k.k.boolean_bitnot(&s_lt_self) } /// Returns `self < s`. The result is an encryption of 1 if this is the case /// and an encryption of 0 otherwise. - pub fn lt(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn lt(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { // Pad to same length. let l = cmp::max(self.max_len(), s.max_len()); let a = self.pad(k, l); @@ -89,37 +83,37 @@ impl FheString { }, ); - let mut is_lt = k.create_zero(); - let mut is_eq = k.create_one(); + let mut is_lt = k.k.create_trivial_boolean_block(false); + let mut is_eq = k.k.create_trivial_boolean_block(true); // is_lt = is_lt || ai < bi a_lt_b.iter().zip(&a_eq_b).for_each(|(ai_lt_bi, ai_eq_bi)| { // is_lt = is_lt || ai < bi && is_eq - let ai_lt_bi_and_eq = binary_and(k, ai_lt_bi, &is_eq); - is_lt = binary_or(k, &is_lt, &ai_lt_bi_and_eq); + let ai_lt_bi_and_eq = k.k.boolean_bitand(ai_lt_bi, &is_eq); + is_lt = k.k.boolean_bitor(&is_lt, &ai_lt_bi_and_eq); // is_eq = is_eq && ai == bi - is_eq = binary_and(k, &is_eq, ai_eq_bi); + is_eq = k.k.boolean_bitand(&is_eq, ai_eq_bi); }); is_lt } /// Returns `self >= s`. The result is an encryption of 1 if this is the /// case and an encryption of 0 otherwise. - pub fn ge(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn ge(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { s.le(k, self) } /// Returns `self > s`. The result is an encryption of 1 if this is the /// case and an encryption of 0 otherwise. - pub fn gt(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn gt(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { s.lt(k, self) } /// Returns whether `self` and `s` are equal when ignoring case. The result /// is an encryption of 1 if this is the case and an encryption of 0 /// otherwise. - pub fn eq_ignore_ascii_case(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn eq_ignore_ascii_case(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { // Pad to same length. let l = cmp::max(self.max_len(), s.max_len()); let a = self.pad(k, l); @@ -135,11 +129,11 @@ impl FheString { }) .collect(); - binary_and_vec(k, &v) + all(k, &v) } /// Returns whether `self[i..i+s.len]` and `s` are equal. - pub fn substr_eq(&self, k: &ServerKey, i: usize, s: &FheString) -> RadixCiphertext { + pub fn substr_eq(&self, k: &ServerKey, i: usize, s: &FheString) -> BooleanBlock { // Extract substring. let a = self.substr_clear(k, i, self.max_len()); let b = s; @@ -152,7 +146,7 @@ impl FheString { .map(|(ai, bi)| { let eq = k.k.eq_parallelized(&ai.0, &bi.0); let is_term = k.k.scalar_eq_parallelized(&bi.0, Self::TERMINATOR); - k.k.bitor_parallelized(&eq, &is_term) + k.k.boolean_bitor(&eq, &is_term) }) .collect::>() }, @@ -170,7 +164,7 @@ impl FheString { } // Check if all v[i] == 1. - binary_and_vec(k, &v) + all(k, &v) } /// Returns `self[start..end]`. If `start >= self.len`, returns the empty diff --git a/src/ciphertext/convert.rs b/src/ciphertext/convert.rs index d7bc52c..0be0996 100644 --- a/src/ciphertext/convert.rs +++ b/src/ciphertext/convert.rs @@ -1,37 +1,37 @@ //! Functionality for string conversion. use rayon::prelude::*; -use tfhe::integer::RadixCiphertext; +use tfhe::integer::BooleanBlock; use crate::server_key::ServerKey; -use super::{logic::binary_and, FheAsciiChar, FheString, Uint}; +use super::{FheAsciiChar, FheString, Uint}; impl FheAsciiChar { - const CASE_DIFF: Uint = 32; + const CASE_DIFF: u8 = 32; /// Returns whether `self` is uppercase. - pub fn is_uppercase(&self, k: &ServerKey) -> RadixCiphertext { + pub fn is_uppercase(&self, k: &ServerKey) -> BooleanBlock { // (65 <= c <= 90) let c_geq_65 = k.k.scalar_ge_parallelized(&self.0, 65 as Uint); let c_leq_90 = k.k.scalar_le_parallelized(&self.0, 90 as Uint); - binary_and(k, &c_geq_65, &c_leq_90) + k.k.boolean_bitand(&c_geq_65, &c_leq_90) } /// Returns whether `self` is lowercase. - pub fn is_lowercase(&self, k: &ServerKey) -> RadixCiphertext { + pub fn is_lowercase(&self, k: &ServerKey) -> BooleanBlock { // (97 <= c <= 122) - let c_geq_97 = k.k.scalar_ge_parallelized(&self.0, 97 as Uint); - let c_leq_122 = k.k.scalar_le_parallelized(&self.0, 122 as Uint); - binary_and(k, &c_geq_97, &c_leq_122) + let c_geq_97 = k.k.scalar_ge_parallelized(&self.0, 97 as u8); + let c_leq_122 = k.k.scalar_le_parallelized(&self.0, 122 as u8); + k.k.boolean_bitand(&c_geq_97, &c_leq_122) } /// Returns the lowercase representation of `self`. pub fn to_lowercase(&self, k: &ServerKey) -> FheAsciiChar { // c + (c.uppercase ? 32 : 0) let ucase = self.is_uppercase(k); - let ucase_mul_32 = k.k.scalar_mul_parallelized(&ucase, Self::CASE_DIFF); - let lcase = k.k.add_parallelized(&self.0, &ucase_mul_32); + let self_add_32 = k.k.scalar_add_parallelized(&self.0, Self::CASE_DIFF as u8); + let lcase = k.k.if_then_else_parallelized(&ucase, &self_add_32, &self.0); FheAsciiChar(lcase) } @@ -39,8 +39,8 @@ impl FheAsciiChar { pub fn to_uppercase(&self, k: &ServerKey) -> FheAsciiChar { // c - (c.lowercase ? 32 : 0) let lcase = self.is_lowercase(k); - let lcase_mul_32 = k.k.scalar_mul_parallelized(&lcase, Self::CASE_DIFF); - let ucase = k.k.sub_parallelized(&self.0, &lcase_mul_32); + let self_sub_32 = k.k.scalar_sub_parallelized(&self.0, Self::CASE_DIFF); + let ucase = k.k.if_then_else_parallelized(&lcase, &self_sub_32, &self.0); FheAsciiChar(ucase) } } diff --git a/src/ciphertext/insert.rs b/src/ciphertext/insert.rs index d8175e9..c20ce59 100644 --- a/src/ciphertext/insert.rs +++ b/src/ciphertext/insert.rs @@ -5,7 +5,7 @@ use tfhe::integer::RadixCiphertext; use rayon::{join, prelude::*}; use crate::{ - ciphertext::{binary_if_then_else, FheAsciiChar, Uint}, + ciphertext::{FheAsciiChar, Uint}, server_key::ServerKey, }; @@ -33,8 +33,7 @@ impl FheString { let i_lt_n_mul_self_len = k.k.lt_parallelized(&i_radix, &n_mul_self_len); let i_mod_self_len = k.k.rem_parallelized(&i_radix, &self_len); let self_i_mod_self_len = self.char_at(k, &i_mod_self_len); - let vi = binary_if_then_else( - k, + let vi = k.k.if_then_else_parallelized( &i_lt_n_mul_self_len, &self_i_mod_self_len.0, &k.create_zero(), @@ -100,8 +99,7 @@ impl FheString { let c2 = (0..l) .into_par_iter() .map(|i| { - binary_if_then_else( - k, + k.k.if_then_else_parallelized( &i_lt_index_add_blen[i], &b_at_i_sub_index[i].0, &a_at_i_sub_blen[i].0, @@ -121,7 +119,7 @@ impl FheString { let c1 = &a.0[i % a.0.len()].0; // c = c0 ? c1 : c2 - let c = binary_if_then_else(k, &c0, c1, &c2[i]); + let c = k.k.if_then_else_parallelized(&c0, c1, &c2[i]); FheAsciiChar(c) }) .collect::>(); diff --git a/src/ciphertext/logic.rs b/src/ciphertext/logic.rs index 78025ee..d82a90e 100644 --- a/src/ciphertext/logic.rs +++ b/src/ciphertext/logic.rs @@ -1,55 +1,69 @@ //! Functionality for logical operations. -use tfhe::integer::RadixCiphertext; +use tfhe::integer::{ + block_decomposition::DecomposableInto, server_key::ScalarMultiplier, BooleanBlock, + IntegerCiphertext, RadixCiphertext, +}; use crate::server_key::ServerKey; use super::Uint; -// Returns `not a`, assuming `a` is an encryption of a binary value. -pub fn binary_not(k: &ServerKey, a: &RadixCiphertext) -> RadixCiphertext { - k.k.scalar_bitxor_parallelized(a, 1) +// Returns `a ? b : 0`. +pub fn if_then_else_zero(k: &ServerKey, a: &BooleanBlock, b: &RadixCiphertext) -> RadixCiphertext { + let a_radix = a.clone().into_radix(b.blocks().len(), &k.k); + k.k.mul_parallelized(&a_radix, &b) } -// Returns `a or b`, assuming `a` and `b` are encryptions of binary values. -pub fn binary_or(k: &ServerKey, a: &RadixCiphertext, b: &RadixCiphertext) -> RadixCiphertext { - k.k.bitor_parallelized(a, b) +// Returns `a ? b : 0`, where `b` is a scalar. +pub fn scalar_if_then_else_zero( + k: &ServerKey, + a: &BooleanBlock, + b: Scalar, +) -> RadixCiphertext +where + Scalar: ScalarMultiplier + DecomposableInto, +{ + let a_radix = a.clone().into_radix(k.num_blocks, &k.k); + k.k.scalar_mul_parallelized(&a_radix, b) } -// Returns `a and b`, assuming `a` and `b` are encryptions of binary values. -pub fn binary_and(k: &ServerKey, a: &RadixCiphertext, b: &RadixCiphertext) -> RadixCiphertext { - k.k.bitand_parallelized(a, b) +// Returns `a ? b : c`. +pub fn if_then_else_bool( + k: &ServerKey, + a: &BooleanBlock, + b: &BooleanBlock, + c: &BooleanBlock, +) -> BooleanBlock { + let a_and_b = k.k.boolean_bitand(a, b); + let not_a = k.k.boolean_bitnot(a); + let not_a_and_c = k.k.boolean_bitand(¬_a, c); + k.k.boolean_bitor(&a_and_b, ¬_a_and_c) } -/// Returns 1 if all elements of `v` are equal to 1, or `v.len == 0`. Otherwise -/// returns `0`. -/// -/// Expects that all elements of `v` are binary. -pub fn binary_and_vec(k: &ServerKey, v: &[RadixCiphertext]) -> RadixCiphertext { - let sum = k.k.unchecked_sum_ciphertexts_slice_parallelized(v); +// Returns true if any of the elements of `v` is true. +pub fn any(k: &ServerKey, v: &[BooleanBlock]) -> BooleanBlock { + let v: Vec = v + .iter() + .map(|vi| vi.clone().into_radix(k.num_blocks, &k.k)) + .collect(); + let sum = k.k.unchecked_sum_ciphertexts_vec_parallelized(v); match sum { - None => k.create_one(), - Some(sum) => k.k.scalar_eq_parallelized(&sum, v.len() as Uint), + None => k.k.create_trivial_boolean_block(false), + Some(sum) => k.k.scalar_gt_parallelized(&sum, 0 as Uint), } } -/// Returns 1 if an element of `v` is equal to 1. Otherwise returns `0`. -/// -/// Expects that all elements of `v` are binary. -pub fn binary_or_vec(k: &ServerKey, v: &[RadixCiphertext]) -> RadixCiphertext { - let sum = k.k.unchecked_sum_ciphertexts_slice_parallelized(v); +// Returns true if all of the elements of `v` are true. +pub fn all(k: &ServerKey, v: &[BooleanBlock]) -> BooleanBlock { + let v: Vec = v + .iter() + .map(|vi| vi.clone().into_radix(k.num_blocks, &k.k)) + .collect(); + let l = v.len(); + let sum = k.k.unchecked_sum_ciphertexts_vec_parallelized(v); match sum { - None => k.create_zero(), - Some(sum) => k.k.scalar_gt_parallelized(&sum, 0 as Uint), + None => k.k.create_trivial_boolean_block(true), + Some(sum) => k.k.scalar_eq_parallelized(&sum, l as Uint), } } - -// Returns `a ? b : c`, assuming `a` is an encryption of a binary value. -pub fn binary_if_then_else( - k: &ServerKey, - a: &RadixCiphertext, - b: &RadixCiphertext, - c: &RadixCiphertext, -) -> RadixCiphertext { - k.k.if_then_else_parallelized(a, b, c) -} diff --git a/src/ciphertext/mod.rs b/src/ciphertext/mod.rs index 736cfb6..6273524 100644 --- a/src/ciphertext/mod.rs +++ b/src/ciphertext/mod.rs @@ -3,12 +3,14 @@ use std::error::Error; use crate::{ - ciphertext::logic::{binary_and, binary_if_then_else, binary_not, binary_or}, + ciphertext::logic::if_then_else_zero, client_key::{ClientKey, Key}, server_key::ServerKey, }; use rayon::prelude::*; -use tfhe::integer::RadixCiphertext; +use tfhe::integer::{BooleanBlock, RadixCiphertext}; + +use self::logic::{any, scalar_if_then_else_zero}; mod compare; mod convert; @@ -138,9 +140,10 @@ impl FheString { let self_i = &pair[1]; let self_isub1_neq_0 = k.k.scalar_ne_parallelized(&self_isub1.0, Self::TERMINATOR); let self_i_eq_0 = k.k.scalar_eq_parallelized(&self_i.0, Self::TERMINATOR); - let b = binary_and(k, &self_isub1_neq_0, &self_i_eq_0); + let b = k.k.boolean_bitand(&self_isub1_neq_0, &self_i_eq_0); let i = i_sub_1 + 1; - k.k.scalar_mul_parallelized(&b, i as Uint) + let i_radix = k.create_value(i as Uint); + if_then_else_zero(k, &b, &i_radix) }) .collect::>(); @@ -166,7 +169,7 @@ impl FheString { // a[i] = i < index ? a[i] : 0 let i_lt_index = k.k.scalar_gt_parallelized(index, i as Uint); - let ai = binary_if_then_else(k, &i_lt_index, &ai.0, &term); + let ai = k.k.if_then_else_parallelized(&i_lt_index, &ai.0, &term); FheAsciiChar(ai) }) .collect(); @@ -204,12 +207,7 @@ impl FheString { let i_add_index = k.k.scalar_add_parallelized(start, i as Uint); let i_add_index_lt_end = k.k.lt_parallelized(&i_add_index, end); let self_i_add_index = self.char_at(k, &i_add_index); - let ai = binary_if_then_else( - k, - &i_add_index_lt_end, - &self_i_add_index.0, - &k.create_zero(), - ); + let ai = if_then_else_zero(k, &i_add_index_lt_end, &self_i_add_index.0); FheAsciiChar(ai) }) .collect(); @@ -230,12 +228,12 @@ impl FheString { // i == j ? a[j] : 0 let i_eq_j = k.k.scalar_eq_parallelized(i, j as Uint); - k.k.mul_parallelized(&i_eq_j, &aj.0) + if_then_else_zero(k, &i_eq_j, &aj.0) }) .collect::>(); let ai = - k.k.unchecked_sum_ciphertexts_slice_parallelized(&v) + k.k.unchecked_sum_ciphertexts_vec_parallelized(v) .unwrap_or(k.create_zero()); FheAsciiChar(ai) } @@ -270,7 +268,7 @@ impl FheString { } /// Given `v` and `Enc(i)`, return `v[i]`. Returns `0` if `i` is out of bounds. -pub fn element_at(k: &ServerKey, v: &[RadixCiphertext], i: &RadixCiphertext) -> RadixCiphertext { +pub fn element_at_bool(k: &ServerKey, v: &[BooleanBlock], i: &RadixCiphertext) -> BooleanBlock { // ai = i == 0 ? a[0] : 0 + ... + i == n ? a[n] : 0 let v = v .par_iter() @@ -281,12 +279,11 @@ pub fn element_at(k: &ServerKey, v: &[RadixCiphertext], i: &RadixCiphertext) -> // i == j ? a[j] : 0 let i_eq_j = k.k.scalar_eq_parallelized(i, j as Uint); - k.k.mul_parallelized(&i_eq_j, aj) + k.k.boolean_bitand(&i_eq_j, aj) }) .collect::>(); - k.k.unchecked_sum_ciphertexts_slice_parallelized(&v) - .unwrap_or(k.create_zero()) + any(k, &v) } /// Searches `v` for the first index `i` with `p(v[i]) == 1`. @@ -295,7 +292,7 @@ pub fn element_at(k: &ServerKey, v: &[RadixCiphertext], i: &RadixCiphertext) -> pub fn index_of_unchecked( k: &ServerKey, v: &[T], - p: fn(&ServerKey, &T) -> RadixCiphertext, + p: fn(&ServerKey, &T) -> BooleanBlock, ) -> FheOption { index_of_unchecked_with_options(k, v, p, false) } @@ -306,7 +303,7 @@ pub fn index_of_unchecked( pub fn rindex_of_unchecked( k: &ServerKey, v: &[T], - p: fn(&ServerKey, &T) -> RadixCiphertext, + p: fn(&ServerKey, &T) -> BooleanBlock, ) -> FheOption { index_of_unchecked_with_options(k, v, p, true) } @@ -318,12 +315,11 @@ pub fn rindex_of_unchecked( fn index_of_unchecked_with_options( k: &ServerKey, v: &[T], - p: fn(&ServerKey, &T) -> RadixCiphertext, + p: fn(&ServerKey, &T) -> BooleanBlock, reverse: bool, ) -> FheOption { - let zero = k.create_zero(); - let mut b = zero.clone(); // Pattern contained. - let mut index = zero.clone(); // Pattern index. + let mut b = k.k.create_trivial_boolean_block(false); // Pattern contained. + let mut index = k.create_zero(); // Pattern index. let items: Vec<_> = if reverse { v.iter().enumerate().rev().collect() @@ -336,7 +332,7 @@ fn index_of_unchecked_with_options( .par_iter() .map(|(i, x)| { let pi = p(k, x); - let pi_mul_i = k.k.scalar_mul_parallelized(&pi, *i as Uint); + let pi_mul_i = scalar_if_then_else_zero(k, &pi, *i as Uint); (i, pi, pi_mul_i) }) .collect(); @@ -346,10 +342,10 @@ fn index_of_unchecked_with_options( log::trace!("index_of_opt_unchecked: at index {i}"); // index = b ? index : (pi ? i : 0) - index = binary_if_then_else(k, &b, &index, &pi_mul_i); + index = k.k.if_then_else_parallelized(&b, &index, &pi_mul_i); // b = b || pi - b = binary_or(k, &b, &pi); + b = k.k.boolean_bitor(&b, &pi); }); FheOption { @@ -361,35 +357,33 @@ fn index_of_unchecked_with_options( /// FheOption represents an encrypted option type. pub struct FheOption { /// Whether this option decrypts to `Some` or `None`. - pub is_some: RadixCiphertext, + pub is_some: BooleanBlock, /// The optional value. pub val: T, } impl FheOption { pub fn decrypt(&self, k: &ClientKey) -> Option { - let is_some = k.decrypt::(&self.is_some); + let is_some = k.0.decrypt_bool(&self.is_some); match is_some { - 0 => None, - 1 => { + true => { let val = k.decrypt::(&self.val); Some(val) } - _ => panic!("expected 0 or 1, got {}", is_some), + false => None, } } } impl FheOption { pub fn decrypt(&self, k: &ClientKey) -> Option { - let is_some = k.decrypt::(&self.is_some); + let is_some = k.0.decrypt_bool(&self.is_some); match is_some { - 0 => None, - 1 => { + true => { let val = self.val.decrypt(k); Some(val) } - _ => panic!("expected 0 or 1, got {}", is_some), + false => None, } } } diff --git a/src/ciphertext/replace.rs b/src/ciphertext/replace.rs index d62e1e5..9cc7dbd 100644 --- a/src/ciphertext/replace.rs +++ b/src/ciphertext/replace.rs @@ -1,9 +1,13 @@ //! Functionality for string replacement. -use tfhe::integer::RadixCiphertext; +use tfhe::integer::{IntegerCiphertext, RadixCiphertext}; use crate::{ - ciphertext::{binary_and, binary_if_then_else, element_at, Uint}, + ciphertext::{ + element_at_bool, + logic::{if_then_else_bool, if_then_else_zero}, + Uint, + }, server_key::ServerKey, }; @@ -55,11 +59,10 @@ impl FheString { v[i] = in_match ? s[j] : self[c] j += 1 */ - let mut in_match = k.create_zero(); + let mut in_match = k.k.create_trivial_boolean_block(false); let mut j = k.create_zero(); let mut n = k.create_zero(); let mut v = Vec::::new(); - let zero = k.create_zero(); (0..l).for_each(|i| { log::trace!("replace_nopt: at index {i}"); @@ -68,25 +71,29 @@ impl FheString { let c = k.k.scalar_add_parallelized(&n_mul_lendiff, i as Uint); let j_lt_slen = k.k.lt_parallelized(&j, &s_len); - let match_and_jltslen = binary_and(k, &in_match, &j_lt_slen); + let match_and_jltslen = k.k.boolean_bitand(&in_match, &j_lt_slen); - let found_c = element_at(k, &found, &c); + let found_c = element_at_bool(k, &found, &c); let foundc_and_n_lt_nmax = match n_max { Some(n_max) => { let n_lt_nmax = k.k.lt_parallelized(&n, n_max); - binary_and(k, &found_c, &n_lt_nmax) + k.k.boolean_bitand(&found_c, &n_lt_nmax) } None => found_c, }; - let n_add_found_c = k.k.add_parallelized(&n, &foundc_and_n_lt_nmax); + let foundc_and_n_lt_nmax_radix = foundc_and_n_lt_nmax + .clone() + .into_radix(n.blocks().len(), &k.k); + let n_add_found_c = k.k.add_parallelized(&n, &foundc_and_n_lt_nmax_radix); - in_match = binary_if_then_else(k, &match_and_jltslen, &in_match, &foundc_and_n_lt_nmax); - j = binary_if_then_else(k, &match_and_jltslen, &j, &zero); - n = binary_if_then_else(k, &match_and_jltslen, &n, &n_add_found_c); + in_match = if_then_else_bool(k, &match_and_jltslen, &in_match, &foundc_and_n_lt_nmax); + j = if_then_else_zero(k, &match_and_jltslen, &j); + n = + k.k.if_then_else_parallelized(&match_and_jltslen, &n, &n_add_found_c); let sj = s.char_at(k, &j).0; let self_c = self.char_at(k, &c).0; - let vi = binary_if_then_else(k, &in_match, &sj, &self_c); + let vi = k.k.if_then_else_parallelized(&in_match, &sj, &self_c); v.push(FheAsciiChar(vi)); j = k.k.scalar_add_parallelized(&j, 1 as Uint); diff --git a/src/ciphertext/search.rs b/src/ciphertext/search.rs index 0d4f9ed..b6efbe9 100644 --- a/src/ciphertext/search.rs +++ b/src/ciphertext/search.rs @@ -1,24 +1,22 @@ //! Functionality for string search. use rayon::{join, prelude::*}; -use tfhe::integer::RadixCiphertext; +use tfhe::integer::{BooleanBlock, RadixCiphertext}; -use crate::{ - ciphertext::{binary_if_then_else, Uint}, - server_key::ServerKey, -}; +use crate::{ciphertext::Uint, server_key::ServerKey}; use super::{ - binary_and, index_of_unchecked, logic::binary_or_vec, rindex_of_unchecked, FheAsciiChar, - FheOption, FheString, + index_of_unchecked, + logic::{any, if_then_else_bool, if_then_else_zero}, + rindex_of_unchecked, FheAsciiChar, FheOption, FheString, }; impl FheString { /// Returns whether `self` contains the string `s`. The result is an /// encryption of 1 if this is the case and an encryption of 0 otherwise. - pub fn contains(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn contains(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { let found = self.find_all(k, s); - binary_or_vec(k, &found) + any(k, &found) } /// Returns the index of the first occurrence of `s`, if existent. @@ -32,7 +30,7 @@ impl FheString { /// Returns a vector v of length self.max_len where the i-th entry is an /// encryption of 1 if the substring of self starting from i matches s, and /// an encryption of 0 otherwise. - pub(super) fn find_all(&self, k: &ServerKey, s: &FheString) -> Vec { + pub(super) fn find_all(&self, k: &ServerKey, s: &FheString) -> Vec { (0..self.0.len() - 1) .into_par_iter() .map(|i| { @@ -43,13 +41,11 @@ impl FheString { } /// Returns a vector v of length self.max_len where `v[i] = p(self[i])`. - /// - /// `p` is expected to return an encryption of either 0 or 1. pub(super) fn find_all_pred_unchecked( &self, k: &ServerKey, - p: fn(&ServerKey, &FheAsciiChar) -> RadixCiphertext, - ) -> Vec { + p: fn(&ServerKey, &FheAsciiChar) -> BooleanBlock, + ) -> Vec { self.0.par_iter().map(|c| p(k, c)).collect::>() } @@ -58,7 +54,7 @@ impl FheString { pub(super) fn find_all_next_pred_unchecked( &self, k: &ServerKey, - p: fn(&ServerKey, &FheAsciiChar) -> RadixCiphertext, + p: fn(&ServerKey, &FheAsciiChar) -> BooleanBlock, ) -> Vec> { self.0 .par_iter() @@ -73,7 +69,7 @@ impl FheString { &self, k: &ServerKey, s: &FheString, - ) -> Vec { + ) -> Vec { let matches = self.find_all(k, s); let s_len = s.len(k); @@ -90,22 +86,23 @@ impl FheString { j += 1 in_match = in_match && j < s.len */ - let mut in_match = k.create_zero(); + let mut in_match = k.k.create_trivial_boolean_block(false); let mut j = k.create_zero(); matches .iter() .map(|mi| { // (matches[i], in_match, j) = in_match ? (0, in_match, j) : (matches[i], matches[i], 0) - let mi_out = binary_if_then_else(k, &in_match, &k.create_zero(), mi); - j = binary_if_then_else(k, &in_match, &j, &k.create_zero()); - in_match = binary_if_then_else(k, &in_match, &in_match, mi); + let mi_out = + if_then_else_bool(k, &in_match, &k.k.create_trivial_boolean_block(false), mi); + j = if_then_else_zero(k, &in_match, &j); + in_match = if_then_else_bool(k, &in_match, &in_match, mi); // j += 1 k.k.scalar_add_assign_parallelized(&mut j, 1 as Uint); // in_match = in_match && j < s.len let j_lt_slen = k.k.lt_parallelized(&j, &s_len); - in_match = binary_and(k, &in_match, &j_lt_slen); + in_match = k.k.boolean_bitand(&in_match, &j_lt_slen); mi_out }) @@ -118,7 +115,7 @@ impl FheString { &self, k: &ServerKey, s: &FheString, - ) -> Vec { + ) -> Vec { let matches = self.find_all(k, s); let s_len = s.len(k); @@ -139,11 +136,12 @@ impl FheString { .map(|mi| { // m[i] = j < s.len ? 0 : m[i] let j_lt_slen = k.k.lt_parallelized(&j, &s_len); - let mi = binary_if_then_else(k, &j_lt_slen, &zero, mi); + let mi = + if_then_else_bool(k, &j_lt_slen, &k.k.create_trivial_boolean_block(false), mi); // j = j >= s.len && m[i] ? 0 : j - let j_lt_slen_and_mi = binary_and(k, &j_lt_slen, &mi); - j = binary_if_then_else(k, &j_lt_slen_and_mi, &zero, &j); + let j_lt_slen_and_mi = k.k.boolean_bitand(&j_lt_slen, &mi); + j = k.k.if_then_else_parallelized(&j_lt_slen_and_mi, &zero, &j); // j += 1 k.k.scalar_add_assign_parallelized(&mut j, 1 as Uint); @@ -163,8 +161,15 @@ impl FheString { // If empty pattern, return length. Otherwise return last index. let empty = s.is_empty(k); FheOption { - is_some: binary_if_then_else(k, &empty, &k.create_one(), &last.is_some), - val: binary_if_then_else(k, &empty, &self.len(k), &last.val), + is_some: if_then_else_bool( + k, + &empty, + &k.k.create_trivial_boolean_block(true), + &last.is_some, + ), + val: k + .k + .if_then_else_parallelized(&empty, &self.len(k), &last.val), } } @@ -175,7 +180,7 @@ impl FheString { &self, k: &ServerKey, i: usize, - p: fn(&ServerKey, &FheAsciiChar) -> RadixCiphertext, + p: fn(&ServerKey, &FheAsciiChar) -> BooleanBlock, ) -> FheOption { // Search substring. let subvec = &self.0.get(i..).unwrap_or_default(); @@ -189,37 +194,32 @@ impl FheString { } /// Searches `self` for the first index `i` with `p(self[i]) == 1`. - /// - /// Expects that `p` returns an encryption of either 0 or 1. pub fn find_pred_unchecked( &self, k: &ServerKey, - p: fn(&ServerKey, &FheAsciiChar) -> RadixCiphertext, + p: fn(&ServerKey, &FheAsciiChar) -> BooleanBlock, ) -> FheOption { index_of_unchecked(k, &self.0, p) } /// Searches `self` for the first index `i` with `p(self[i]) == 1` in /// reverse direction. - /// - /// Expects that `p` returns an encryption of either 0 or 1. pub fn rfind_pred_unchecked( &self, k: &ServerKey, - p: fn(&ServerKey, &FheAsciiChar) -> RadixCiphertext, + p: fn(&ServerKey, &FheAsciiChar) -> BooleanBlock, ) -> FheOption { rindex_of_unchecked(k, &self.0, p) } - /// Returns whether `self` starts with the string `s`. The result is an - /// encryption of 1 if this is the case and an encryption of 0 otherwise. - pub fn starts_with(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + /// Returns whether `self` starts with the string `s`. + pub fn starts_with(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { self.substr_eq(k, 0, s) } /// Returns whether `self` ends with the string `s`. The result is an /// encryption of 1 if this is the case and an encryption of 0 otherwise. - pub fn ends_with(&self, k: &ServerKey, s: &FheString) -> RadixCiphertext { + pub fn ends_with(&self, k: &ServerKey, s: &FheString) -> BooleanBlock { let opti = self.rfind(k, s); // is_end = self.len == i + s.len @@ -228,6 +228,6 @@ impl FheString { let is_end = k.k.eq_parallelized(&self_len, &i_add_s_len); // ends_with = contained && is_end - binary_and(k, &opti.is_some, &is_end) + k.k.boolean_bitand(&opti.is_some, &is_end) } } diff --git a/src/ciphertext/split.rs b/src/ciphertext/split.rs index f8e1ac0..681fbf8 100644 --- a/src/ciphertext/split.rs +++ b/src/ciphertext/split.rs @@ -1,21 +1,11 @@ //! Functionality for string splitting. use rayon::prelude::*; -use tfhe::integer::RadixCiphertext; - -use crate::{ - ciphertext::{ - element_at, - logic::{binary_not, binary_or}, - }, - client_key::ClientKey, - server_key::ServerKey, -}; - -use super::{ - logic::{binary_and, binary_if_then_else}, - FheAsciiChar, FheOption, FheString, Uint, -}; +use tfhe::integer::{IntegerCiphertext, RadixCiphertext}; + +use crate::{ciphertext::element_at_bool, client_key::ClientKey, server_key::ServerKey}; + +use super::{FheAsciiChar, FheOption, FheString, Uint}; /// An element of an `FheStringSliceVector`. #[derive(Clone)] @@ -43,7 +33,7 @@ impl FheStringSliceVector { let v = self .v .par_iter() - .map(|vi| vi.is_some.clone()) + .map(|vi| vi.is_some.clone().into_radix(k.num_blocks, &k.k)) .collect::>(); k.k.unchecked_sum_ciphertexts_vec_parallelized(v) .unwrap_or(k.create_zero()) @@ -54,7 +44,7 @@ impl FheStringSliceVector { let mut n = k.create_zero(); let init = FheOption { - is_some: k.create_zero(), + is_some: k.k.create_trivial_boolean_block(false), val: FheStringSlice { start: k.create_zero(), end: k.create_zero(), @@ -64,17 +54,20 @@ impl FheStringSliceVector { let slice = self.v.iter().enumerate().fold(init, |acc, (j, vi)| { // acc = i == n && vi.is_some ? (j, vi.end) : acc let i_eq_n = k.k.eq_parallelized(i, &n); - let is_some = binary_and(k, &i_eq_n, &vi.is_some); + let is_some = k.k.boolean_bitand(&i_eq_n, &vi.is_some); let j_radix = k.create_value(j as Uint); - let start = binary_if_then_else(k, &is_some, &j_radix, &acc.val.start); - let end = binary_if_then_else(k, &is_some, &vi.val.end, &acc.val.end); + let start = + k.k.if_then_else_parallelized(&is_some, &j_radix, &acc.val.start); + let end = + k.k.if_then_else_parallelized(&is_some, &vi.val.end, &acc.val.end); let acc = FheOption { is_some, val: FheStringSlice { start, end }, }; // n += vi.is_some - k.k.add_assign_parallelized(&mut n, &vi.is_some); + let is_some_radix = vi.is_some.clone().into_radix(n.blocks().len(), &k.k); + k.k.add_assign_parallelized(&mut n, &is_some_radix); acc }); @@ -95,7 +88,6 @@ impl FheStringSliceVector { n += v[i].is_start */ let mut n = k.create_zero(); - let zero = k.create_zero(); self.v = self .v @@ -103,10 +95,11 @@ impl FheStringSliceVector { .map(|vi| { // is_some = n < i ? vi.is_some : 0 let n_lt_i = k.k.lt_parallelized(&n, i); - let is_some = binary_if_then_else(k, &n_lt_i, &vi.is_some, &zero); + let is_some = k.k.boolean_bitand(&n_lt_i, &vi.is_some); // n += v[i].is_some - k.k.add_assign_parallelized(&mut n, &vi.is_some); + let is_some_radix = vi.is_some.clone().into_radix(n.blocks().len(), &k.k); + k.k.add_assign_parallelized(&mut n, &is_some_radix); FheOption { is_some, @@ -118,7 +111,7 @@ impl FheStringSliceVector { /// Truncate the last element if it is empty. fn truncate_last_if_empty(&mut self, k: &ServerKey) { - let mut b = k.create_one(); + let mut b = k.k.create_trivial_boolean_block(true); let mut v = self .v .iter() @@ -128,14 +121,14 @@ impl FheStringSliceVector { let is_empty = k.k.ge_parallelized(&vi.val.start, &vi.val.end); // is_some = b && vi.is_some && is_empty ? 0 : vi.is_some - let b_and_start = binary_and(k, &b, &vi.is_some); - let b_and_start_and_empty = binary_and(k, &b_and_start, &is_empty); - let is_some = - binary_if_then_else(k, &b_and_start_and_empty, &k.create_zero(), &vi.is_some); + let b_and_start = k.k.boolean_bitand(&b, &vi.is_some); + let b_and_start_and_empty = k.k.boolean_bitand(&b_and_start, &is_empty); + let not_b_and_start_and_empty = k.k.boolean_bitnot(&b_and_start_and_empty); + let is_some = k.k.boolean_bitand(¬_b_and_start_and_empty, &vi.is_some); // b = b && !vi.is_some - let not_start = binary_not(k, &vi.is_some); - b = binary_and(k, &b, ¬_start); + let not_start = k.k.boolean_bitnot(&vi.is_some); + b = k.k.boolean_bitand(&b, ¬_start); FheOption { is_some, @@ -150,19 +143,20 @@ impl FheStringSliceVector { /// Expand the first slice to the beginning of the string. fn expand_first(&mut self, k: &ServerKey) { // Find the first item and set its start point to 0. - let mut not_found = k.create_one(); + let mut not_found = k.k.create_trivial_boolean_block(true); let zero = k.create_zero(); self.v = self .v .iter() .map(|vi| { // start = not_found && vi.is_some ? 0 : vi.start - let not_found_and_some = binary_and(k, ¬_found, &vi.is_some); - let start = binary_if_then_else(k, ¬_found_and_some, &zero, &vi.val.start); + let not_found_and_some = k.k.boolean_bitand(¬_found, &vi.is_some); + let start = + k.k.if_then_else_parallelized(¬_found_and_some, &zero, &vi.val.start); // not_found = not_found && !vi.is_some - let not_some = binary_not(k, &vi.is_some); - not_found = binary_and(k, ¬_found, ¬_some); + let not_some = k.k.boolean_bitnot(&vi.is_some); + not_found = k.k.boolean_bitand(¬_found, ¬_some); FheOption { is_some: vi.is_some.clone(), @@ -178,7 +172,7 @@ impl FheStringSliceVector { /// Expand the last slice to the end of the string. fn expand_last(&mut self, k: &ServerKey) { // Find the last item and set its end point to s.len. - let mut not_found = k.create_one(); + let mut not_found = k.k.create_trivial_boolean_block(true); let self_len = self.s.len(k); let mut v = self .v @@ -186,12 +180,13 @@ impl FheStringSliceVector { .rev() .map(|vi| { // end = not_found && vi.is_some ? self.s.len : vi.end - let not_found_and_some = binary_and(k, ¬_found, &vi.is_some); - let end = binary_if_then_else(k, ¬_found_and_some, &self_len, &vi.val.end); + let not_found_and_some = k.k.boolean_bitand(¬_found, &vi.is_some); + let end = + k.k.if_then_else_parallelized(¬_found_and_some, &self_len, &vi.val.end); // not_found = not_found && !vi.is_some - let not_some = binary_not(k, &vi.is_some); - not_found = binary_and(k, ¬_found, ¬_some); + let not_some = k.k.boolean_bitnot(&vi.is_some); + not_found = k.k.boolean_bitand(¬_found, ¬_some); FheOption { is_some: vi.is_some.clone(), @@ -212,10 +207,10 @@ impl FheStringSliceVector { self.v .iter() .filter_map(|vi| { - let is_some = k.0.decrypt::(&vi.is_some); + let is_some = k.0.decrypt_bool(&vi.is_some); match is_some { - 0 => None, - _ => { + false => None, + true => { let start = k.0.decrypt::(&vi.val.start) as usize; let end = k.0.decrypt::(&vi.val.end) as usize; let slice = s_dec.get(start..end).unwrap_or_default(); @@ -273,7 +268,6 @@ impl FheString { let n = self.max_len() + 2; // Maximum number of entries. let n_hidden = k.k.scalar_add_parallelized(&self_len, 2 as Uint); // Better bound based on hidden length. let mut next_match = self_len.clone(); - let zero = k.create_zero(); let mut elems = (0..n) .rev() .map(|i| { @@ -281,13 +275,13 @@ impl FheString { // is_some_i = i == 0 || matches[i - p.len] && i < self.len + 2 let is_some = if i == 0 { - k.create_one() + k.k.create_trivial_boolean_block(true) } else { let i_radix = k.create_value(i as Uint); let i_sub_plen = k.k.sub_parallelized(&i_radix, &p_len); - let mi = element_at(k, &matches, &i_sub_plen); + let mi = element_at_bool(k, &matches, &i_sub_plen); let i_lt_n_hidden = k.k.scalar_gt_parallelized(&n_hidden, i as Uint); - binary_if_then_else(k, &i_lt_n_hidden, &mi, &zero) + k.k.boolean_bitand(&i_lt_n_hidden, &mi) }; // next_match_target = i + (inclusive ? p.len : 0) @@ -298,12 +292,15 @@ impl FheString { }; // next_match[i] = matches[i] ? next_match_target : next_match[i+1] - let matches_i = matches.get(i).unwrap_or(&zero); - next_match = binary_if_then_else(k, matches_i, &next_match_target, &next_match); + let false_block = k.k.create_trivial_boolean_block(false); + let matches_i = matches.get(i).unwrap_or(&false_block); + next_match = + k.k.if_then_else_parallelized(matches_i, &next_match_target, &next_match); // start = max(i - p.empty, 0) let start = if i > 0 { - k.k.sub_parallelized(&k.create_value(i as Uint), &pattern_empty) + let pattern_empty_radix = pattern_empty.clone().into_radix(k.num_blocks, &k.k); + k.k.sub_parallelized(&k.create_value(i as Uint), &pattern_empty_radix) } else { k.create_zero() }; @@ -435,14 +432,14 @@ impl FheString { let w = c.is_whitespace(k); // Also check for string termination character. let z = k.k.scalar_eq_parallelized(&c.0, FheString::TERMINATOR); - binary_or(k, &w, &z) + k.k.boolean_bitor(&w, &z) }; let whitespace = self.find_all_pred_unchecked(k, is_whitespace); let next_whitespace = self.find_all_next_pred_unchecked(k, is_whitespace); let zero = k.create_zero(); let opt_default = FheOption { - is_some: zero.clone(), + is_some: k.k.create_trivial_boolean_block(false), val: zero.clone(), }; @@ -453,18 +450,22 @@ impl FheString { .enumerate() .map(|(i, _)| { // is_some = !whitespace[i] && (i == 0 || whitespace[i-1]); - let not_whitespace = binary_not(k, &whitespace[i]); + let not_whitespace = k.k.boolean_bitnot(&whitespace[i]); let i_eq_0_or_prev_whitespace = if i == 0 { - k.create_one() + k.k.create_trivial_boolean_block(true) } else { whitespace[i - 1].clone() }; - let is_some = binary_and(k, ¬_whitespace, &i_eq_0_or_prev_whitespace); + let is_some = + k.k.boolean_bitand(¬_whitespace, &i_eq_0_or_prev_whitespace); // end = s.index_of_next_white_space_or_max_len(i+1); let index_of_next = next_whitespace.get(i + 1).unwrap_or(&opt_default); - let end = - binary_if_then_else(k, &index_of_next.is_some, &index_of_next.val, &self_len); + let end = k.k.if_then_else_parallelized( + &index_of_next.is_some, + &index_of_next.val, + &self_len, + ); FheOption { is_some, diff --git a/src/ciphertext/tests/mod.rs b/src/ciphertext/tests/mod.rs index ed1127d..1debe0b 100644 --- a/src/ciphertext/tests/mod.rs +++ b/src/ciphertext/tests/mod.rs @@ -1,4 +1,7 @@ -use tfhe::{integer::RadixCiphertext, shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS}; +use tfhe::{ + integer::{BooleanBlock, RadixCiphertext}, + shortint::prelude::PARAM_MESSAGE_2_CARRY_2_KS_PBS, +}; use crate::{client_key::ClientKey, generate_keys_with_params, server_key::ServerKey}; @@ -25,40 +28,29 @@ fn encrypt_int(k: &ClientKey, n: u64) -> RadixCiphertext { k.0.encrypt(n) } -fn int_to_bool(x: u64) -> bool { - match x { - 0 => false, - 1 => true, - _ => panic!("expected 0 or 1, got {}", x), - } -} - -fn decrypt_bool(k: &ClientKey, b: &RadixCiphertext) -> bool { - let x = k.0.decrypt::(b); - int_to_bool(x) +fn decrypt_bool(k: &ClientKey, b: &BooleanBlock) -> bool { + k.0.decrypt_bool(b) } fn decrypt_option_int(k: &ClientKey, opt: &FheOption) -> Option { - let is_some = k.0.decrypt::(&opt.is_some); + let is_some = k.0.decrypt_bool(&opt.is_some); match is_some { - 0 => None, - 1 => { + false => None, + true => { let val = k.0.decrypt::(&opt.val); Some(val as usize) } - _ => panic!("expected 0 or 1, got {}", is_some), } } fn decrypt_option_string(k: &ClientKey, opt: &FheOption) -> Option { - let is_some = k.0.decrypt::(&opt.is_some); + let is_some = k.0.decrypt_bool(&opt.is_some); match is_some { - 0 => None, - 1 => { + false => None, + true => { let val = opt.val.decrypt(k); Some(val) } - _ => panic!("expected 0 or 1, got {}", is_some), } } diff --git a/src/ciphertext/tests/split.rs b/src/ciphertext/tests/split.rs index 1506c2e..e5f27a9 100644 --- a/src/ciphertext/tests/split.rs +++ b/src/ciphertext/tests/split.rs @@ -483,12 +483,12 @@ fn split_once() { .map(|v| (v.0.to_string(), v.1.to_string())); let opt_v_enc = input_enc.split_once(&server_key, &pattern_enc); - let b_dec = client_key.0.decrypt::(&opt_v_enc.is_some); + let b_dec = client_key.0.decrypt_bool(&opt_v_enc.is_some); let val0_dec = opt_v_enc.val.0.decrypt(&client_key); let val1_dec = opt_v_enc.val.1.decrypt(&client_key); let opt_v_dec = match b_dec { - 0 => None, - _ => Some((val0_dec, val1_dec)), + false => None, + true => Some((val0_dec, val1_dec)), }; println!("{:?}", t); @@ -543,12 +543,12 @@ fn rsplit_once() { .map(|v| (v.0.to_string(), v.1.to_string())); let opt_v_enc = input_enc.rsplit_once(&server_key, &pattern_enc); - let b_dec = client_key.0.decrypt::(&opt_v_enc.is_some); + let b_dec = client_key.0.decrypt_bool(&opt_v_enc.is_some); let val0_dec = opt_v_enc.val.0.decrypt(&client_key); let val1_dec = opt_v_enc.val.1.decrypt(&client_key); let opt_v_dec = match b_dec { - 0 => None, - _ => Some((val0_dec, val1_dec)), + false => None, + true => Some((val0_dec, val1_dec)), }; println!("{:?}", t); diff --git a/src/ciphertext/trim.rs b/src/ciphertext/trim.rs index 2ef0012..919a3b6 100644 --- a/src/ciphertext/trim.rs +++ b/src/ciphertext/trim.rs @@ -1,27 +1,27 @@ //! Functionality for string trimming. use rayon::{join, prelude::*}; -use tfhe::integer::RadixCiphertext; +use tfhe::integer::{BooleanBlock, RadixCiphertext}; use crate::server_key::ServerKey; use super::{ - binary_and, binary_if_then_else, binary_not, binary_or, index_of_unchecked, - rindex_of_unchecked, FheAsciiChar, FheOption, FheString, Uint, + index_of_unchecked, logic::if_then_else_zero, rindex_of_unchecked, FheAsciiChar, FheOption, + FheString, Uint, }; impl FheAsciiChar { /// Returns whether this is a whitespace character. - pub fn is_whitespace(&self, k: &ServerKey) -> RadixCiphertext { + pub fn is_whitespace(&self, k: &ServerKey) -> BooleanBlock { // Whitespace characters: 9 (Horizontal tab), 10 (Line feed), 11 // (Vertical tab), 12 (Form feed), 13 (Carriage return), 32 (Space) // (9 <= c <= 13) || c == 32 let c_geq_9 = k.k.scalar_ge_parallelized(&self.0, 9 as Uint); let c_leq_13 = k.k.scalar_le_parallelized(&self.0, 13 as Uint); - let c_geq_9_and_c_leq_13 = binary_and(k, &c_geq_9, &c_leq_13); + let c_geq_9_and_c_leq_13 = k.k.boolean_bitand(&c_geq_9, &c_leq_13); let c_eq_32 = k.k.scalar_eq_parallelized(&self.0, 32 as Uint); - binary_or(k, &c_geq_9_and_c_leq_13, &c_eq_32) + k.k.boolean_bitor(&c_geq_9_and_c_leq_13, &c_eq_32) } } @@ -31,10 +31,10 @@ impl FheString { pub fn trim_start(&self, k: &ServerKey) -> FheString { let i_opt = self.find_pred_unchecked(k, |k, c| { let is_whitespace = c.is_whitespace(k); - binary_not(k, &is_whitespace) + k.k.boolean_bitnot(&is_whitespace) }); - let i = binary_if_then_else(k, &i_opt.is_some, &i_opt.val, &k.create_zero()); + let i = if_then_else_zero(k, &i_opt.is_some, &i_opt.val); self.substr_from(k, &i) } @@ -44,15 +44,15 @@ impl FheString { let i_opt = self.rfind_pred_unchecked(k, |k, c| { // !is_terminator(c) && !is_whitespace(c) let is_term = k.k.scalar_eq_parallelized(&c.0, Self::TERMINATOR); - let not_term = binary_not(k, &is_term); + let not_term = k.k.boolean_bitnot(&is_term); let is_whitespace = c.is_whitespace(k); - let not_whitespace = binary_not(k, &is_whitespace); - binary_and(k, ¬_term, ¬_whitespace) + let not_whitespace = k.k.boolean_bitnot(&is_whitespace); + k.k.boolean_bitand(¬_term, ¬_whitespace) }); // i = i_opt.is_some ? i_opt.val + 1 : 0 let val_add_1 = k.k.scalar_add_parallelized(&i_opt.val, 1); - let i = binary_if_then_else(k, &i_opt.is_some, &val_add_1, &k.create_zero()); + let i = if_then_else_zero(k, &i_opt.is_some, &val_add_1); self.truncate(k, &i) } @@ -63,10 +63,10 @@ impl FheString { let found = self.find_all_pred_unchecked(k, |k, c| { // !is_terminator(c) && !is_whitespace(c) let is_term = k.k.scalar_eq_parallelized(&c.0, Self::TERMINATOR); - let not_term = binary_not(k, &is_term); + let not_term = k.k.boolean_bitnot(&is_term); let is_whitespace = c.is_whitespace(k); - let not_whitespace = binary_not(k, &is_whitespace); - binary_and(k, ¬_term, ¬_whitespace) + let not_whitespace = k.k.boolean_bitnot(&is_whitespace); + k.k.boolean_bitand(¬_term, ¬_whitespace) }); let (index_start, index_end) = join( @@ -76,11 +76,11 @@ impl FheString { // Truncate end. let val_add_1 = k.k.scalar_add_parallelized(&index_end.val, 1); - let i = binary_if_then_else(k, &index_end.is_some, &val_add_1, &k.create_zero()); + let i = if_then_else_zero(k, &index_end.is_some, &val_add_1); let s = self.truncate(k, &i); // Truncate start. - let i = binary_if_then_else(k, &index_start.is_some, &index_start.val, &k.create_zero()); + let i = if_then_else_zero(k, &index_start.is_some, &index_start.val); s.substr_from(k, &i) } @@ -107,7 +107,7 @@ impl FheString { let s_len = s.len(k); let i_add_slen = k.k.add_parallelized(&found.val, &s_len); let i_add_slen_eq_selflen = k.k.eq_parallelized(&i_add_slen, &self_len); - let is_some = binary_and(k, &found.is_some, &i_add_slen_eq_selflen); + let is_some = k.k.boolean_bitand(&found.is_some, &i_add_slen_eq_selflen); FheOption { is_some, @@ -124,7 +124,8 @@ impl FheString { .map(|(i, c)| { // a[i] = i < index ? a[i] : 0 let i_lt_index = k.k.scalar_gt_parallelized(index, i as Uint); - FheAsciiChar(k.k.mul_parallelized(&i_lt_index, &c.0)) + let ai = if_then_else_zero(k, &i_lt_index, &c.0); + FheAsciiChar(ai) }) .collect(); FheString(v) diff --git a/src/client_key.rs b/src/client_key.rs index fa80791..f20488b 100644 --- a/src/client_key.rs +++ b/src/client_key.rs @@ -5,7 +5,7 @@ use tfhe::{ core_crypto::prelude::UnsignedNumeric, integer::{ block_decomposition::{DecomposableInto, RecomposableFrom}, - RadixCiphertext, RadixClientKey, + BooleanBlock, RadixCiphertext, RadixClientKey, }, }; @@ -26,6 +26,11 @@ impl ClientKey { pub fn decrypt + UnsignedNumeric>(&self, ct: &RadixCiphertext) -> T { self.0.decrypt(ct) } + + /// Decrypt a boolean value. + pub fn decrypt_bool(&self, ct: &BooleanBlock) -> bool { + self.0.decrypt_bool(ct) + } } /// A trait for operations common on client key and server key.