From ba8239dd5efda2a6babb571f2d41cf6229c6c506 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 14 Jan 2025 18:18:23 +0100 Subject: [PATCH] feat: improve Falcon signature verification --- assembly/src/ast/instruction/advice.rs | 3 + assembly/src/parser/grammar.lalrpop | 2 + assembly/src/parser/token.rs | 4 + core/src/sys_events.rs | 20 ++ .../operations/sys_ops/sys_event_handlers.rs | 43 ++++ stdlib/asm/crypto/dsa/rpo_falcon512.masm | 214 ++++++++++++------ stdlib/tests/crypto/falcon.rs | 50 +++- 7 files changed, 252 insertions(+), 84 deletions(-) diff --git a/assembly/src/ast/instruction/advice.rs b/assembly/src/ast/instruction/advice.rs index f6bfe807d0..8ddcef7584 100644 --- a/assembly/src/ast/instruction/advice.rs +++ b/assembly/src/ast/instruction/advice.rs @@ -13,6 +13,7 @@ use vm_core::sys_events::SystemEvent; #[derive(Clone, PartialEq, Eq, Debug)] pub enum SystemEventNode { PushU64Div, + PushFalconDiv, PushExt2intt, PushSmtPeek, PushMapVal, @@ -30,6 +31,7 @@ impl From<&SystemEventNode> for SystemEvent { use SystemEventNode::*; match value { PushU64Div => Self::U64Div, + PushFalconDiv => Self::FalconDiv, PushExt2intt => Self::Ext2Intt, PushSmtPeek => Self::SmtPeek, PushMapVal => Self::MapValueToStack, @@ -56,6 +58,7 @@ impl fmt::Display for SystemEventNode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::PushU64Div => write!(f, "push_u64div"), + Self::PushFalconDiv => write!(f, "push_falcon_div"), Self::PushExt2intt => write!(f, "push_ext2intt"), Self::PushSmtPeek => write!(f, "push_smtpeek"), Self::PushMapVal => write!(f, "push_mapval"), diff --git a/assembly/src/parser/grammar.lalrpop b/assembly/src/parser/grammar.lalrpop index 378ae1ec50..06af15b57e 100644 --- a/assembly/src/parser/grammar.lalrpop +++ b/assembly/src/parser/grammar.lalrpop @@ -56,6 +56,7 @@ extern { "push_sig" => Token::PushSig, "push_smtpeek" => Token::PushSmtpeek, "push_u64div" => Token::PushU64Div, + "push_falcon_div" => Token::PushFalconDiv, "and" => Token::And, "assert" => Token::Assert, "assertz" => Token::Assertz, @@ -677,6 +678,7 @@ SystemEvent: Instruction = { "adv" "." "push_sig" "." => Instruction::SysEvent(SystemEventNode::PushSignature { kind }), "adv" "." "push_smtpeek" => Instruction::SysEvent(SystemEventNode::PushSmtPeek), "adv" "." "push_u64div" => Instruction::SysEvent(SystemEventNode::PushU64Div), + "adv" "." "push_falcon_div" => Instruction::SysEvent(SystemEventNode::PushFalconDiv), } #[inline] diff --git a/assembly/src/parser/token.rs b/assembly/src/parser/token.rs index 94c4f73fd4..688a19f78f 100644 --- a/assembly/src/parser/token.rs +++ b/assembly/src/parser/token.rs @@ -153,6 +153,7 @@ pub enum Token<'input> { PushSmtset, PushSmtget, PushU64Div, + PushFalconDiv, And, Assert, Assertz, @@ -338,6 +339,7 @@ impl fmt::Display for Token<'_> { Token::PushSmtset => write!(f, "push_smtset"), Token::PushSmtget => write!(f, "push_smtget"), Token::PushU64Div => write!(f, "push_u64div"), + Token::PushFalconDiv => write!(f, "push_falcon_div"), Token::And => write!(f, "and"), Token::Assert => write!(f, "assert"), Token::Assertz => write!(f, "assertz"), @@ -531,6 +533,7 @@ impl<'input> Token<'input> { | Token::PushSmtset | Token::PushSmtget | Token::PushU64Div + | Token::PushFalconDiv | Token::And | Token::Assert | Token::Assertz @@ -676,6 +679,7 @@ impl<'input> Token<'input> { ("push_smtset", Token::PushSmtset), ("push_smtget", Token::PushSmtget), ("push_u64div", Token::PushU64Div), + ("push_falcon_div", Token::PushFalconDiv), ("and", Token::And), ("assert", Token::Assert), ("assertz", Token::Assertz), diff --git a/core/src/sys_events.rs b/core/src/sys_events.rs index 0a4c8df2a7..d554163582 100644 --- a/core/src/sys_events.rs +++ b/core/src/sys_events.rs @@ -28,6 +28,7 @@ mod constants { pub const EVENT_HDWORD_TO_MAP_WITH_DOMAIN: u32 = 2822590340; pub const EVENT_HPERM_TO_MAP: u32 = 3297060969; pub const EVENT_FALCON_SIG_TO_STACK: u32 = 3419226139; + pub const EVENT_FALCON_DIV: u32 = 3419226155; } /// Defines a set of actions which can be initiated from the VM to inject new data into the advice @@ -119,6 +120,22 @@ pub enum SystemEvent { /// the remainder respectively. U64Div, + /// Pushes the result of divison (both the quotient and the remainder) of a [u64] by the Falcon + /// prime (M = 12289) onto the advice stack. + /// + /// Inputs: + /// Operand stack: [a1, a0, ...] + /// Advice stack: [...] + /// + /// Outputs: + /// Operand stack: [a1, a0, ...] + /// Advice stack: [q0, q1, r, ...] + /// + /// Where (a0, a1) are the 32-bit limbs of the dividend (with a0 representing the 32 least + /// significant bits and a1 representing the 32 most significant bits). + /// Similarly, (q0, q1) represent the quotient and r the remainder. + FalconDiv, + /// Given an element in a quadratic extension field on the top of the stack (i.e., a0, b1), /// computes its multiplicative inverse and push the result onto the advice stack. /// @@ -310,6 +327,7 @@ impl SystemEvent { SystemEvent::MapValueToStack => EVENT_MAP_VALUE_TO_STACK, SystemEvent::MapValueToStackN => EVENT_MAP_VALUE_TO_STACK_N, SystemEvent::U64Div => EVENT_U64_DIV, + SystemEvent::FalconDiv => EVENT_FALCON_DIV, SystemEvent::Ext2Inv => EVENT_EXT2_INV, SystemEvent::Ext2Intt => EVENT_EXT2_INTT, SystemEvent::SmtPeek => EVENT_SMT_PEEK, @@ -335,6 +353,7 @@ impl SystemEvent { EVENT_MAP_VALUE_TO_STACK => Some(SystemEvent::MapValueToStack), EVENT_MAP_VALUE_TO_STACK_N => Some(SystemEvent::MapValueToStackN), EVENT_U64_DIV => Some(SystemEvent::U64Div), + EVENT_FALCON_DIV => Some(SystemEvent::FalconDiv), EVENT_EXT2_INV => Some(SystemEvent::Ext2Inv), EVENT_EXT2_INTT => Some(SystemEvent::Ext2Intt), EVENT_SMT_PEEK => Some(SystemEvent::SmtPeek), @@ -367,6 +386,7 @@ impl fmt::Display for SystemEvent { Self::MapValueToStack => write!(f, "map_value_to_stack"), Self::MapValueToStackN => write!(f, "map_value_to_stack_with_len"), Self::U64Div => write!(f, "div_u64"), + Self::FalconDiv => write!(f, "falcon_div"), Self::Ext2Inv => write!(f, "ext2_inv"), Self::Ext2Intt => write!(f, "ext2_intt"), Self::SmtPeek => write!(f, "smt_peek"), diff --git a/processor/src/operations/sys_ops/sys_event_handlers.rs b/processor/src/operations/sys_ops/sys_event_handlers.rs index 13e4a796be..9544a85212 100644 --- a/processor/src/operations/sys_ops/sys_event_handlers.rs +++ b/processor/src/operations/sys_ops/sys_event_handlers.rs @@ -19,6 +19,9 @@ use crate::{ /// The offset of the domain value on the stack in the `hdword_to_map_with_domain` system event. const HDWORD_TO_MAP_WITH_DOMAIN_DOMAIN_OFFSET: usize = 8; +/// Falcon signature prime. +const M: u64 = 12289; + impl Process { pub(super) fn handle_sytem_event( &self, @@ -39,6 +42,7 @@ impl Process { copy_map_value_to_adv_stack(advice_provider, process_state, true) }, SystemEvent::U64Div => push_u64_div_result(advice_provider, process_state), + SystemEvent::FalconDiv => push_falcon_mod_result(advice_provider, process_state), SystemEvent::Ext2Inv => push_ext2_inv_result(advice_provider, process_state), SystemEvent::Ext2Intt => push_ext2_intt_result(advice_provider, process_state), SystemEvent::SmtPeek => push_smtpeek_result(advice_provider, process_state), @@ -342,6 +346,45 @@ pub fn push_u64_div_result( Ok(()) } +/// Pushes the result of divison (both the quotient and the remainder) of a [u64] by the Falcon +/// prime (M = 12289) onto the advice stack. +/// +/// Inputs: +/// Operand stack: [a1, a0, ...] +/// Advice stack: [...] +/// +/// Outputs: +/// Operand stack: [a1, a0, ...] +/// Advice stack: [q0, q1, r, ...] +/// +/// Where (a0, a1) are the 32-bit limbs of the dividend (with a0 representing the 32 least +/// significant bits and a1 representing the 32 most significant bits). +/// Similarly, (q0, q1) represent the quotient and r the remainder. +/// +/// # Errors +/// Returns an error if the divisor is ZERO. +pub fn push_falcon_mod_result( + advice_provider: &mut impl AdviceProvider, + process: ProcessState, +) -> Result<(), ExecutionError> { + let dividend_hi = process.get_stack_item(0).as_int(); + let dividend_lo = process.get_stack_item(1).as_int(); + let dividend = (dividend_hi << 32) + dividend_lo; + + let quotient = dividend / M; + let remainder = dividend - quotient * M; + + let (q_hi, q_lo) = u64_to_u32_elements(quotient); + let (r_hi, r_lo) = u64_to_u32_elements(remainder); + assert_eq!(r_hi, ZERO); + + advice_provider.push_stack(AdviceSource::Value(r_lo))?; + advice_provider.push_stack(AdviceSource::Value(q_lo))?; + advice_provider.push_stack(AdviceSource::Value(q_hi))?; + + Ok(()) +} + /// Given an element in a quadratic extension field on the top of the stack (i.e., a0, b1), /// computes its multiplicative inverse and push the result onto the advice stack. /// diff --git a/stdlib/asm/crypto/dsa/rpo_falcon512.masm b/stdlib/asm/crypto/dsa/rpo_falcon512.masm index 637f490bf9..5b1b479eca 100644 --- a/stdlib/asm/crypto/dsa/rpo_falcon512.masm +++ b/stdlib/asm/crypto/dsa/rpo_falcon512.masm @@ -1,64 +1,72 @@ use.std::crypto::hashes::rpo +use.std::math::u64 # CONSTANTS # ================================================================================================= const.J=77321994752 const.M=12289 +const.M_HALF=6144 # (M-1) / 2 const.SQUARE_NORM_BOUND=34034726 # MODULAR REDUCTION FALCON PRIME # ============================================================================================= -#! Given dividend ( i.e. field element a ) on stack top, this routine computes c = a % 12289 +#! Given dividend ( i.e. a u64 given by its lower and higher u32 decomposition ) on the stack, +#! this routine computes c = a % 12289 #! #! Expected stack state #! -#! [a, ...] +#! [a_hi, a_lo, ...] #! #! Output stack state looks like #! #! [c, ...] | c = a % 12289 +#! +#! Cycles: 31 export.mod_12289 - u32split - push.M.0 - - adv.push_u64div + adv.push_falcon_div + # the advice stack contains now [qhi, qlo, r, ...] where q = qhi * 2^32 + qlo is quotient + # and r is remainder adv_push.2 u32assert2 + # => [qlo, qhi, a_hi, a_lo, ...] - swap push.M u32overflowing_mul + # => [overflow, M * qlo % 2^32, qhi, a_hi, a_lo, ...] movup.2 push.M + # => [M, qhi, overflow, M * qlo % 2^32, a_hi, a_lo, ...] u32overflowing_madd + # => [t1, t0, M * qlo % 32, a_hi, a_lo, ...] where t = t1 * 2^32 + t0 and t = M * qhi + overflow + # Note by the bound on x - r = q * M, we are guaranteed that t1 = 0 drop + # => [M * q / 2^32, (M * q) % 2^32, a_hi, a_lo, ...] + # => [res_hi, res_lo, a_hi, a_lo, ...] - adv_push.2 - drop - u32assert + adv_push.1 + dup u32lt.M assert + # => [r, res_hi, res_lo, a_hi, a_lo, ...] dup - movup.3 u32overflowing_add + # => [flag, (res_lo + r) % 2^32, r, res_hi, a_hi, a_lo, ...] where u = uhi * 2^32 + ulo and u = (res_lo + r) / 2^32 movup.3 u32overflowing_add + # => [flag, final_res_hi, final_res_lo, r, a_hi, a_lo, ...] flag should be 0 by the bound on inputs drop - - movup.5 + # => [final_res_hi, final_res_lo, r, a_hi, a_lo, ...] + + movup.3 assert_eq - movup.4 + movup.2 assert_eq - - swap - drop - swap - drop + # => [r, ...] end # HASH-TO-POINT @@ -399,20 +407,19 @@ end #! Input: [e, ...] #! Output [norm(e)^2, ...] #! -#! Cycles: 21 +#! Cycles: 20 export.norm_sq dup dup mul - swap - #=> [e, e^2, ...] + #=> [e^2, e, ...] - dup - push.6144 - u32gt - #=> [phi, e, e^2, ...] + push.M_HALF + dup.2 + u32lt + #=> [phi, e^2, e, ...] - swap + movup.2 mul.24578 # 2*q push.151019521 # q^2 sub @@ -422,33 +429,61 @@ export.norm_sq #=> [norm(e)^2, ...] end -#! On input a tuple (u, w, v), the following computes (v - (u + (- w % q) % q) % q). +#! Given a tuple (u, w, v), we want to compute (v - (u + (- w % q) % q) % q), where: +#! +#! 1. v is a field element given by its u32 decomposition i.e., (c_lo, c_hi) such that +#! v = c_hi * 2**32 + c_lo +#! 2. w is a field element representing the (i+512)-th coefficient of the product polynomial +#! pi (i.e., h * s2). We are guaranteed that w is at most 512 * (q-1)^2. +#! 3. u is a field element representing the i-th coefficient of the product polynomial +#! pi (i.e., h * s2). We are guaranteed that u is at most 512 * (q-1)^2. +#! #! We can avoid doing three modular reductions by using the following facts: #! -#! 1. q is much smaller than the Miden prime. Precisely, q * 2^50 < Q -#! 2. The coefficients of the product polynomial, u and w, are less than J := 512 * q^2 -#! 3. The coefficients of c are less than q. +#! 1. q is much smaller than the Miden prime Q. Precisely, q * 2^50 < Q +#! 2. The coefficients of the product polynomial, u and w, are strictly less than J := 512 * q^2. +#! 3. The coefficients of c are at most q - 1. #! -#! This means that we can substitute (v - (u + (- w % q) % q) % q) with v + w + J - u without -#! risking Q-overflow since |v + w + J - u| < 1025 * q^2 +#! This means that we can substitute (v - (u + (- w % q) % q) % q) with v + w + J - u +#! (note J % q= 0) without risking Q-underflow but we can still overflow. +#! For this reason, we use the u32 decomposition of v and perform the addition of +#! v and w + J - u as u64. Note that |w + J - u| <= 1024 * (q - 1)^2 +#! and hence there is the possibility of an overflow when we add v and w + J - u as u64. +#! When there is an overflow, we add 10952, which is equal to 2^32 % q, to the upper u32 limb of +#! the result of (v + (w + J - u)). Note that since |w + J - u| <= 1024 * (q-1)^2 < 2^38, and +#! 10952 < q, we are guaranteed that this final u32 addition to the upper limb will not overflow. #! #! To get the final result we reduce (v + w + J - u) modulo q. #! -#! Input: [v, w, u, ...] +#! Input: [pi0, pi512 + J, c_hi, c_lo, ...] #! Output: [e, ...] #! -#! Cycles: 44 +#! Cycles: 45 export.diff_mod_q - # 1) v + w + J - add push.J add - #=> [v + w + J, u] + # 1) Subtract + sub + #=> [pi512 + J - pi, c_hi, c_lo, ...] - # 2) v + w + J - u - swap sub - #=> [v + w + J - u] + # 2) u32split first u64 + u32split + #=> [tmp_hi, tmp_lo, c_hi, c_lo, ...] + + # 3) Add the two u64-s + exec.u64::overflowing_add + #=> [flag, res_hi, res_lo, ..] + + # 4) Handle potential overflow in the u64 addition + push.10952 # 2^32 mod q + push.0 + #=> [0, 10952, flag, res_hi, res_lo, ..] + swap.2 + #=> [flag, 10952, 0, res_hi, res_lo, ..] + cdrop + add + #=> [res_hi, res_lo, ..] - # 3) Reduce modulo q + # 5) Reduce modulo q exec.mod_12289 #=> [e, ...] end @@ -484,65 +519,89 @@ end #! Input: [pi_ptr, ...] #! Output: [norm_sq(s1), ...] #! -#! Cycles: 58888 +#! Cycles: 40966 export.compute_s1_norm_sq repeat.128 # 1) Load the next 4 * 3 coefficients - # load c_i + # load the next four pi_i padw - dup.4 add.1281 + dup.4 mem_loadw - # load pi_{i+512} + # load the next four pi_{i+512} padw dup.8 add.128 mem_loadw - # load pi_i + # load the next four c_i padw - dup.12 + dup.12 add.1281 mem_loadw - #=> [PI, PI_{i+512}, C, pi_ptr, ...] + #=> [C, PI_{i+512}, PI, pi_ptr, ...] # 2) Compute the squared norm of (i + 0)-th coefficient of s1 - movup.8 - exec.mod_12289 - movup.5 - #=> [v, w, u, ...] where u is the i-th coefficient of `pi`, v is the i-th - # coefficient of `c` and w is the (512 + i)-th coefficient of `pi` polynomial. + u32split + #=> [c0_hi, c0_lo, c1, c2, c3, PI_{i+512}, PI, pi_ptr, ...] + movup.5 + push.J add + #=> [pi512_0, c0_hi, c0_lo, c1, c2, c3, pi512_1, pi512_2, pi512_3, PI, pi_ptr, ...] + movup.9 + #=> [pi0, pi512_0, c_hi, c_lo, c1, c2, c3, pi512_1, pi512_2, pi512_3, pi1, pi2, pi3, pi_ptr, ...] exec.diff_mod_q - #=> [e, ...] - + #=> [e, c1, c2, c3, pi512_1, pi512_2, pi512_3, pi1, pi2, pi3, pi_ptr, ...] exec.norm_sq #=> [norm(e)^2, ...] # Move the result out of the way so that we can process the remaining coefficients movdn.10 + #=> [c1, c2, c3, pi512_1, pi512_2, pi512_3, pi1, pi2, pi3, pi_ptr, e0, ...] # 3) Compute the squared norm of (i + 1)-th coefficient of s1 - movup.6 - exec.mod_12289 + + u32split + #=> [c1_hi, c1_lo, c2, c3, pi512_1, pi512_2, pi512_3, pi1, pi2, pi3, pi_ptr, e0, ...] movup.4 + #=> [pi512_1, c1_hi, c1_lo, c2, c3, pi512_2, pi512_3, pi2, pi3, pi_ptr, e0, ...] + push.J add + movup.7 + #=> [pi1, pi512_1, c1_hi, c1_lo, c2, c3, pi512_2, pi512_3, pi2, pi3, pi_ptr, e0, ...] exec.diff_mod_q exec.norm_sq + #=> [e, c2, c3, pi512_2, pi512_3, pi2, pi3, pi_ptr, e0, ...] + movdn.7 + #=> [c2, c3, pi512_2, pi512_3, pi2, pi3, pi_ptr, e0, e1, ...] # 4) Compute the squared norm of (i + 2)-th coefficient of s1 - movup.4 - exec.mod_12289 + + u32split + #=> [c2_hi, c2_lo, c3, pi512_2, pi512_3, pi2, pi3, pi_ptr, e0, e1, ...] movup.3 + push.J add + #=> [pi512_2, c2_hi, c2_lo, c3, pi512_3, pi2, pi3, pi_ptr, e0, e1, ...] + movup.5 + #=> [pi2, pi512_2, c2_hi, c2_lo, c3, pi512_3, pi3, pi_ptr, e0, e1, ...] exec.diff_mod_q exec.norm_sq + movdn.4 + #=> [c3, pi512_3, pi3, pi_ptr, e, e, e, ...] # 5) Compute the squared norm of (i + 3)-th coefficient of s1 - movup.2 - exec.mod_12289 - movup.2 + + u32split + #=> [c3_hi, c3_lo, pi512_3, pi3, pi_ptr, e0, e1, e2, ...] + movup.2 push.J add + movup.3 + #=> [pi3, pi512_3, c3_hi, c3_lo, pi_ptr, e0, e1, e2, ...] exec.diff_mod_q + #=> [e3, pi_ptr, e0, e1, e2, ...] exec.norm_sq + #=> [e3, pi_ptr, e0, e1, e2, ...] + swap + #=> [pi_ptr, e3, e0, e1, e2, ...] # 6) Increment the pointer add.1 @@ -561,21 +620,28 @@ end #! Input: [s2_ptr, ...] #! Output: [norm_sq(s2), ...] #! -#! Cycles: 13322 +#! Cycles: 11137 export.compute_s2_norm_sq repeat.128 padw - dup.4 add.1 swap.5 + dup.4 mem_loadw - repeat.4 - exec.norm_sq - movdn.4 - end - + exec.norm_sq + swap + exec.norm_sq + add + swap + exec.norm_sq + add + swap + exec.norm_sq + add + swap + add.1 end drop - repeat.511 + repeat.127 add end end @@ -592,7 +658,7 @@ end #! Input: [PK, MSG, ...] #! Output: [...] #! -#! Cycles: ~ 92029 +#! Cycles: ~ 70977 export.verify.1665 # 1) Generate a Falcon signature using the secret key associated to PK on message MSG. @@ -645,7 +711,7 @@ export.verify.1665 #=> [pi_ptr, ...] exec.compute_s1_norm_sq - #=> [norm_sq(s1), ...] (Cycles: 58888) + #=> [norm_sq(s1), ...] (Cycles: 40966) # 7) Compute the squared norm of s2 @@ -653,7 +719,7 @@ export.verify.1665 #=> [s2_ptr, norm_sq(s1), ...] exec.compute_s2_norm_sq - #=> [norm_sq(s2), norm_sq(s1), ...] (Cycles: 13322) + #=> [norm_sq(s2), norm_sq(s1), ...] (Cycles: 11137) # 8) Check that ||(s1, s2)||^2 < K diff --git a/stdlib/tests/crypto/falcon.rs b/stdlib/tests/crypto/falcon.rs index 5c41c90052..18b586eebe 100644 --- a/stdlib/tests/crypto/falcon.rs +++ b/stdlib/tests/crypto/falcon.rs @@ -14,9 +14,11 @@ use test_utils::{ MerkleStore, Rpo256, }, expect_exec_error_matches, + proptest::proptest, rand::{rand_value, rand_vector}, FieldElement, QuadFelt, Word, WORD_SIZE, }; +use vm_core::StarkField; /// Modulus used for rpo falcon 512. const M: u64 = 12289; @@ -83,21 +85,49 @@ fn test_falcon512_diff_mod_q() { exec.rpo_falcon512::diff_mod_q end "; - - let u = rand::thread_rng().gen_range(0..J); - let v = rand::thread_rng().gen_range(Q..M); + let v = rand::thread_rng().gen_range(0..Felt::MODULUS); + let (v_lo, v_hi) = (v as u32, v >> 32); let w = rand::thread_rng().gen_range(0..J); + let u = rand::thread_rng().gen_range(0..J); - let test1 = build_test!(source, &[u, v, w]); + let test1 = build_test!(source, &[v_lo as u64, v_hi as u64, w + J, u]); // Calculating (v - (u + (- w % q) % q) % q) should be the same as (v + w + J - u) % q. - let expanded_answer = (v as i64 - - (u as i64 + -(w as i64).rem_euclid(M as i64)).rem_euclid(M as i64)) - .rem_euclid(M as i64); - let simplified_answer = (v + w + J - u).rem_euclid(M); - assert_eq!(expanded_answer, i64::try_from(simplified_answer).unwrap()); + let expanded_answer = (v as i128 + - ((u as i64 + -(w as i64).rem_euclid(M as i64)).rem_euclid(M as i64) as i128)) + .rem_euclid(M as i128); + let simplified_answer = (v as i128 + w as i128 + J as i128 - u as i128).rem_euclid(M as i128); + assert_eq!(expanded_answer, i128::try_from(simplified_answer).unwrap()); + + test1.expect_stack(&[simplified_answer as u64]); +} + +proptest! { + #[test] + fn diff_mod_q_proptest(v in 0..Felt::MODULUS, w in 0..J, u in 0..J) { + + let source = " + use.std::crypto::dsa::rpo_falcon512 + + begin + exec.rpo_falcon512::diff_mod_q + end + "; + + let (v_lo, v_hi) = (v as u32, v >> 32); + + let test1 = build_test!(source, &[v_lo as u64, v_hi as u64, w + J, u]); + + // Calculating (v - (u + (- w % q) % q) % q) should be the same as (v + w + J - u) % q. + let expanded_answer = (v as i128 + - ((u as i64 + -(w as i64).rem_euclid(M as i64)).rem_euclid(M as i64) as i128)) + .rem_euclid(M as i128); + let simplified_answer = (v as i128 + w as i128 + J as i128 - u as i128).rem_euclid(M as i128); + assert_eq!(expanded_answer, i128::try_from(simplified_answer).unwrap()); + + test1.prop_expect_stack(&[simplified_answer as u64])?; + } - test1.expect_stack(&[simplified_answer]); } #[test]