Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ns/chore/zk use stable isqrt #1954

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tfhe-zk-pok/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ documentation = "https://docs.zama.ai/tfhe-rs"
repository = "https://github.com/zama-ai/tfhe-rs"
license = "BSD-3-Clause-Clear"
description = "tfhe-zk-pok: An implementation of zero-knowledge proofs of encryption for TFHE."
rust-version = "1.84"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
6 changes: 3 additions & 3 deletions tfhe-zk-pok/src/backward_compatibility/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::fmt::Display;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};

use crate::curve_api::Curve;
use crate::four_squares::{isqrt, sqr};
use crate::four_squares::sqr;
use crate::proofs::pke_v2::Bound;
use crate::proofs::GroupElements;
use crate::serialization::{
Expand Down Expand Up @@ -102,15 +102,15 @@ impl Upgrade<SerializablePKEv2PublicParams> for SerializablePKEv2PublicParamsV0
type Error = Infallible;

fn upgrade(self) -> Result<SerializablePKEv2PublicParams, Self::Error> {
let slack_factor = isqrt((self.d + self.k) as u128) as u64;
let slack_factor = (self.d + self.k).isqrt() as u64;
let B_inf = self.B / slack_factor;
Ok(SerializablePKEv2PublicParams {
g_lists: self.g_lists,
D: self.D,
n: self.n,
d: self.d,
k: self.k,
B_bound_squared: sqr(self.B_bound as u128),
B_bound_squared: sqr(self.B_bound),
B_inf,
q: self.q,
t: self.t,
Expand Down
43 changes: 8 additions & 35 deletions tfhe-zk-pok/src/four_squares.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,18 @@
use ark_ff::biginteger::arithmetic::widening_mul;
use rand::prelude::*;

pub fn sqr(x: u128) -> u128 {
/// Avoid overflows for squares of u64
pub fn sqr(x: u64) -> u128 {
let x = x as u128;
x * x
}

pub fn checked_sqr(x: u128) -> Option<u128> {
x.checked_mul(x)
}

// copied from the standard library
// since isqrt is unstable at the moment
pub fn isqrt(this: u128) -> u128 {
if this < 2 {
return this;
}

// The algorithm is based on the one presented in
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
// which cites as source the following C code:
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.

let mut op = this;
let mut res = 0;
let mut one = 1 << (this.ilog2() & !1);

while one != 0 {
if op >= res + one {
op -= res + one;
res = (res >> 1) + one;
} else {
res >>= 1;
}
one >>= 2;
}

res
}

fn half_gcd(p: u128, s: u128) -> u128 {
let sq_p = isqrt(p as _);
let sq_p = p.isqrt();
let mut a = p;
let mut b = s;
while b > sq_p {
Expand Down Expand Up @@ -225,13 +198,13 @@ pub fn four_squares(v: u128) -> [u64; 4] {

let f = v % 4;
if f == 2 {
let b = isqrt(v as _) as u64;
let b = v.isqrt() as u64;

'main_loop: loop {
let x = 2 + rng.gen::<u64>() % (b - 2);
let y = 2 + rng.gen::<u64>() % (b - 2);

let (sum, o) = u128::overflowing_add(sqr(x as u128), sqr(y as u128));
let (sum, o) = u128::overflowing_add(sqr(x), sqr(y));
if o || sum > v {
continue 'main_loop;
}
Expand Down Expand Up @@ -288,9 +261,9 @@ pub fn four_squares(v: u128) -> [u64; 4] {
let i = mont.natural_from_mont(sqrt);
let i = if i <= p / 2 { p - i } else { i };
let z = half_gcd(p, i) as u64;
let w = isqrt(p - sqr(z as u128)) as u64;
let w = (p - sqr(z)).isqrt() as u64;

if p != sqr(z as u128) + sqr(w as u128) {
if p != sqr(z) + sqr(w) {
continue 'main_loop;
}

Expand Down
21 changes: 11 additions & 10 deletions tfhe-zk-pok/src/proofs/pke_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ than the lwe dimension d. Please pick a smaller k: k = {k}, d = {d}"
Bound::GHL => 950625,
Bound::CS => 2 * (d as u128 + k as u128) + 4,
})
.checked_mul(B_squared + (sqr((d + 2) as u128) * (d + k) as u128) / 4)
.checked_mul(B_squared + (sqr((d + 2) as u64) * (d + k) as u128) / 4)
.unwrap_or_else(|| {
panic!(
"Invalid parameters for zk_pok, B_squared: {B_squared}, d: {d}, k: {k}. \
Expand Down Expand Up @@ -552,8 +552,9 @@ The computed m parameter is {m_bound} > 64. Please select a smaller B, d and/or
/// Use the relationship: `||x||_2 <= sqrt(dim)*||x||_inf`. Since we are only interested in the
/// squared bound, we avoid the sqrt by returning dim*(||x||_inf)^2.
fn inf_norm_bound_to_euclidean_squared(B_inf: u64, dim: usize) -> u128 {
checked_sqr(B_inf as u128)
.and_then(|norm_squared| norm_squared.checked_mul(dim as u128))
let norm_squared = sqr(B_inf);
norm_squared
.checked_mul(dim as u128)
.unwrap_or_else(|| panic!("Invalid parameters for zk_pok, B_inf: {B_inf}, d+k: {dim}"))
}

Expand Down Expand Up @@ -765,7 +766,7 @@ fn prove_impl<G: Curve>(
let e_sqr_norm = e1
.iter()
.chain(e2)
.map(|x| sqr(x.unsigned_abs() as u128))
.map(|x| sqr(x.unsigned_abs()))
.sum::<u128>();

if sanity_check_mode == ProofSanityCheckMode::Panic {
Expand Down Expand Up @@ -940,7 +941,7 @@ fn prove_impl<G: Curve>(
assert!(
checked_sqr(acc.unsigned_abs()).unwrap() <= B_bound_squared,
"sqr(acc) ({}) > B_bound_squared ({B_bound_squared})",
sqr(acc as u128)
checked_sqr(acc.unsigned_abs()).unwrap()
);
}
acc as i64
Expand Down Expand Up @@ -2786,7 +2787,7 @@ mod tests {
};

let B_with_slack_squared = inf_norm_bound_to_euclidean_squared(B, d + k);
let B_with_slack = isqrt(B_with_slack_squared) as u64;
let B_with_slack = B_with_slack_squared.isqrt() as u64;

let bound = match slack_mode {
// The slack is maximal, any term above B+slack should be refused
Expand All @@ -2797,7 +2798,7 @@ mod tests {
.e1
.iter()
.chain(&testcase.e2)
.map(|x| sqr(x.unsigned_abs() as u128))
.map(|x| sqr(x.unsigned_abs()))
.sum::<u128>();

let orig_value = match coeff_type {
Expand All @@ -2806,8 +2807,8 @@ mod tests {
};

let bound_squared =
B_with_slack_squared - (e_sqr_norm - sqr(orig_value as u128));
isqrt(bound_squared) as i64
B_with_slack_squared - (e_sqr_norm - sqr(orig_value as u64));
bound_squared.isqrt() as i64
}
// There is no slack effect, any term above B should be refused
BoundTestSlackMode::Min => B as i64,
Expand Down Expand Up @@ -2849,7 +2850,7 @@ mod tests {
let crs_max_k = crs_gen::<Curve>(d, d, B, q, t, msbs_zero_padding_bit_count, rng);

let B_with_slack_squared = inf_norm_bound_to_euclidean_squared(B, d + k);
let B_with_slack_upper = isqrt(B_with_slack_squared) as u64 + 1;
let B_with_slack_upper = B_with_slack_squared.isqrt() as u64 + 1;

// Generate test noise vectors with random coeffs and one completely out of bounds

Expand Down
2 changes: 1 addition & 1 deletion tfhe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ exclude = [
"/js_on_wasm_tests/",
"/web_wasm_parallel_tests/",
]
rust-version = "1.83"
rust-version = "1.84"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
Loading