diff --git a/poly-commit-rs/src/libraries/mle.rs b/poly-commit-rs/src/libraries/mle.rs index 978e986..005d96a 100644 --- a/poly-commit-rs/src/libraries/mle.rs +++ b/poly-commit-rs/src/libraries/mle.rs @@ -1,5 +1,5 @@ //! mle_polynomial.rs -//! +//! //! Minimal translation of MLEPolynomial Python class into Rust with arkworks. use ark_bls12_381::Fr; @@ -83,10 +83,7 @@ impl MLEPolynomial { /// Returns: /// - Vec of MLEPolynomial (quotients) /// - Fr (the evaluation at that point) - pub fn decompose_by_div( - &self, - point: &[Fr], - ) -> Result<(Vec, Fr), MLEError> { + pub fn decompose_by_div(&self, point: &[Fr]) -> Result<(Vec, Fr), MLEError> { if point.len() != self.num_var { return Err(MLEError::DimensionMismatch( "Number of variables must match the point".to_string(), @@ -158,8 +155,8 @@ pub fn eqs_over_hypercube(rs: &[Fr]) -> Vec { evals[j + half] = evals[j] * r; evals[j] = evals[j] - evals[j + half]; // above line is effectively: evals[j] = evals[j] - (evals[j] * r) - // but be mindful that evals[j + half] was overwritten. - // The original Python does: + // but be mindful that evals[j + half] was overwritten. + // The original Python does: // evals[j+half] = evals[j] * rs[i] // evals[j] = evals[j] - evals[j+half] } @@ -169,10 +166,7 @@ pub fn eqs_over_hypercube(rs: &[Fr]) -> Vec { } /// "Slow" version of eqs over hypercube. k > 5 is not supported. -pub fn eqs_over_hypercube_slow( - k: usize, - indeterminates: &[Fr], -) -> Result, MLEError> { +pub fn eqs_over_hypercube_slow(k: usize, indeterminates: &[Fr]) -> Result, MLEError> { if k > 5 { return Err(MLEError::UnsupportedOperation( "k>5 isn't supported".to_string(), @@ -348,33 +342,207 @@ pub fn decompose_by_div_from_coeffs( // ----------------------------------------------------------------------------- - #[cfg(test)] mod tests { use super::*; use ark_ff::UniformRand; use rand::thread_rng; + // Helper: produce random `Fr` vector of length n + fn random_fr_vec(n: usize) -> Vec { + let mut rng = thread_rng(); + (0..n).map(|_| Fr::rand(&mut rng)).collect() + } + + // --------------------- + // MLEPolynomial Tests + // --------------------- + #[test] + fn test_mle_new_and_get() { + let evals = vec![Fr::one(), Fr::one()]; + let mle = MLEPolynomial::new(evals.clone(), 2).unwrap(); // 2^2 = 4 => pads to length 4 + assert_eq!(mle.evals.len(), 4); + // Indices 0 and 1 contain our original data + assert_eq!(mle.get(0), Some(&Fr::one())); + assert_eq!(mle.get(1), Some(&Fr::one())); + // Indices 2 and 3 should be zero after padding + assert_eq!(mle.get(2), Some(&Fr::zero())); + assert_eq!(mle.get(3), Some(&Fr::zero())); + // Out of range + assert_eq!(mle.get(4), None); + } + + #[test] + fn test_mle_evaluate() { + // k=2 => 2^2=4 evals + // We'll pick a small set manually: + let evals = vec![ + Fr::from(0u64), + Fr::from(1u64), + Fr::from(2u64), + Fr::from(3u64), + ]; + let mle = MLEPolynomial::new(evals.clone(), 2).unwrap(); + // Evaluate at zs = [0, 0] + // evaluate_from_evals logic => we expect f[0,0] = evals[0] = 0 + let z1 = Fr::zero(); + let z2 = Fr::zero(); + let val = mle.evaluate(&[z1, z2]).unwrap(); + assert_eq!(val, Fr::from(0u64)); + + // Evaluate at zs = [1, 0] + // This picks out f[1,0] which should be evals[index=1<<0 + 1<<1?] + // Actually let's just rely on the direct Python-likeness: + // f[1,0] -> . . . + // Instead let's check that "some non-zero value" + let val2 = mle.evaluate(&[Fr::one(), Fr::zero()]).unwrap(); + // we can compare to what we get from direct function call + let direct_val = evaluate_from_evals(&evals, &[Fr::one(), Fr::zero()]); + assert_eq!(val2, direct_val); + } + #[test] - fn test_basic_mle() { + fn test_mle_to_coeffs_round_trip() { + // Random example with k=3 => 8 evals + let mut rng = thread_rng(); + let evals: Vec = (0..8).map(|_| Fr::rand(&mut rng)).collect(); + let mle = MLEPolynomial::new(evals.clone(), 3).unwrap(); + + // Round-trip: evals -> coeffs -> evals + let coeffs = mle.to_coeffs(); + let evals_again = compute_evals_from_coeffs(&coeffs); + assert_eq!(mle.evals, evals_again, "Round trip mismatch"); + } + + #[test] + fn test_mle_decompose_by_div() { let mut rng = thread_rng(); - // Build a small example polynomial with k=3 => 2^3=8 evals + // k=3 => 8 evals let evals: Vec = (0..8).map(|_| Fr::rand(&mut rng)).collect(); - let mle = MLEPolynomial::new(evals, 3).unwrap(); + let mle = MLEPolynomial::new(evals.clone(), 3).unwrap(); + let point: Vec = (0..3).map(|_| Fr::rand(&mut rng)).collect(); + + let (quotients, eval) = mle.decompose_by_div(&point).unwrap(); + assert_eq!(quotients.len(), 3); + // Final polynomial evaluation at that point + let direct_eval = evaluate_from_evals(&evals, &point); + assert_eq!(eval, direct_eval); + } - // Just call get - assert!(mle.get(8).is_none(), "Index out of range should be None"); - assert!(mle.get(0).is_some()); + // ---------------------------- + // compute_monomials, eqs, etc. + // ---------------------------- + #[test] + fn test_compute_monomials() { + // k=2 => rs has length 2 => result length = 4 + let rs = [Fr::from(2u64), Fr::from(3u64)]; + let monos = compute_monomials(&rs); + // Expect: + // index 0: 1 + // index 1: r0 = 2 + // index 2: r1 = 3 + // index 3: r0*r1 = 6 + assert_eq!(monos.len(), 4); + assert_eq!(monos[0], Fr::from(1u64)); + assert_eq!(monos[1], Fr::from(2u64)); + assert_eq!(monos[2], Fr::from(3u64)); + assert_eq!(monos[3], Fr::from(6u64)); + } - // Evaluate at some random points - let zs: Vec = (0..3).map(|_| Fr::rand(&mut rng)).collect(); - let val = mle.evaluate(&zs).unwrap(); - println!("MLE evaluate = {:?}", val); + #[test] + fn test_eqs_over_hypercube() { + // k=1 => 2^1=2 results + // let rs[0] = r + // index 0 => x=0 => f(0) = 1- r + // index 1 => x=1 => f(1) = ? + let r = Fr::from(2u64); + let evals = eqs_over_hypercube(&[r]); + // index 0 => originally 1 => then j=0 => evals[1] = evals[0]*r => 2 + // => evals[0] = evals[0] - evals[1] => 1 - 2 => -1 + // So final => [-1, 2] + assert_eq!(evals.len(), 2); + assert_eq!(evals[0], Fr::from(-1i64)); + assert_eq!(evals[1], Fr::from(2u64)); + } - // Convert to coeffs and back - let coeffs = mle.to_coeffs(); - let evals_again = compute_evals_from_coeffs(&coeffs); - assert_eq!(mle.evals, evals_again); + #[test] + fn test_eqs_over_hypercube_slow() { + // k=1 => eqs_over_hypercube_slow => we can compare with eqs_over_hypercube + let r = Fr::from(2u64); + let slow = eqs_over_hypercube_slow(1, &[r]).unwrap(); + let fast = eqs_over_hypercube(&[r]); + assert_eq!(slow, fast); + + // k=6 => error + let big = eqs_over_hypercube_slow(6, &[Fr::one(); 6]); + assert!(big.is_err(), "k>5 should be unsupported"); + } + + // --------------------- + // NTT and Evaluate Tests + // --------------------- + #[test] + fn test_ntt_core_round_trip() { + // We do compute_evals_from_coeffs -> compute_coeffs_from_evals -> original + let coeffs = random_fr_vec(8); // length=8 => k=3 + let evals = compute_evals_from_coeffs(&coeffs); + let back_coeffs = compute_coeffs_from_evals(&evals); + assert_eq!(coeffs, back_coeffs, "NTT round-trip mismatch"); + } + + #[test] + fn test_evaluate_from_evals() { + // small manual example with k=2 => 4 evals + let evals = vec![Fr::from(10u64), Fr::from(20u64), Fr::from(30u64), Fr::from(40u64)]; + let zs = [Fr::from(0u64), Fr::from(0u64)]; + let val = evaluate_from_evals(&evals, &zs); + // for k=2, index=0 => that should be evals[0] = 10 + assert_eq!(val, Fr::from(10u64)); + } + + #[test] + fn test_evaluate_from_evals_2() { + // same eval set, compare evaluate_from_evals_2 with evaluate_from_evals + let evals = vec![Fr::from(10u64), Fr::from(20u64), Fr::from(30u64), Fr::from(40u64)]; + let zs = [Fr::from(1u64), Fr::from(0u64)]; + let val1 = evaluate_from_evals(&evals, &zs); + let val2 = evaluate_from_evals_2(&evals, &zs); + assert_eq!(val1, val2); + } + + #[test] + fn test_evaluate_from_coeffs() { + // small test + // let's define k=2 => 4 coeffs + let coeffs = vec![Fr::from(1u64), Fr::from(2u64), Fr::from(3u64), Fr::from(4u64)]; + let zs = vec![Fr::one(), Fr::zero()]; + let val = evaluate_from_coeffs(&coeffs, &zs); + // we can do it by hand if needed, or just trust the logic + // let's also compare with manual expand if we want + // f(X0, X1) = ... + // We'll just trust it's consistent for now + println!("evaluate_from_coeffs() => {:?}", val); + } + + // --------------------- + // decompose_by_div_from_coeffs Tests + // --------------------- + #[test] + fn test_decompose_by_div_from_coeffs() { + // k=2 => 4 coeffs + let coeffs = vec![ + Fr::from(1u64), + Fr::from(2u64), + Fr::from(3u64), + Fr::from(4u64), + ]; + let point = vec![Fr::one(), Fr::zero()]; // dimension=2 + + let dec = decompose_by_div_from_coeffs(coeffs.clone(), &point).unwrap(); + assert_eq!(dec.quotients.len(), 2); + // dec.evaluation => f(1,0) + let direct_eval = evaluate_from_coeffs(&coeffs, &point); + assert_eq!(dec.evaluation, direct_eval); } }