Skip to content

Commit

Permalink
fix: mle impl tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
surfer05 committed Jan 7, 2025
1 parent 275fbfd commit 3fc7465
Showing 1 changed file with 194 additions and 26 deletions.
220 changes: 194 additions & 26 deletions poly-commit-rs/src/libraries/mle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! mle_polynomial.rs
//!
//!
//! Minimal translation of MLEPolynomial Python class into Rust with arkworks.
use ark_bls12_381::Fr;
Expand Down Expand Up @@ -83,10 +83,7 @@ impl MLEPolynomial {
/// Returns:
/// - Vec of MLEPolynomial (quotients)
/// - Fr (the evaluation at that point)
pub fn decompose_by_div(
&self,
point: &[Fr],
) -> Result<(Vec<MLEPolynomial>, Fr), MLEError> {
pub fn decompose_by_div(&self, point: &[Fr]) -> Result<(Vec<MLEPolynomial>, Fr), MLEError> {
if point.len() != self.num_var {
return Err(MLEError::DimensionMismatch(
"Number of variables must match the point".to_string(),
Expand Down Expand Up @@ -158,8 +155,8 @@ pub fn eqs_over_hypercube(rs: &[Fr]) -> Vec<Fr> {
evals[j + half] = evals[j] * r;
evals[j] = evals[j] - evals[j + half];
// above line is effectively: evals[j] = evals[j] - (evals[j] * r)
// but be mindful that evals[j + half] was overwritten.
// The original Python does:
// but be mindful that evals[j + half] was overwritten.
// The original Python does:
// evals[j+half] = evals[j] * rs[i]
// evals[j] = evals[j] - evals[j+half]
}
Expand All @@ -169,10 +166,7 @@ pub fn eqs_over_hypercube(rs: &[Fr]) -> Vec<Fr> {
}

/// "Slow" version of eqs over hypercube. k > 5 is not supported.
pub fn eqs_over_hypercube_slow(
k: usize,
indeterminates: &[Fr],
) -> Result<Vec<Fr>, MLEError> {
pub fn eqs_over_hypercube_slow(k: usize, indeterminates: &[Fr]) -> Result<Vec<Fr>, MLEError> {
if k > 5 {
return Err(MLEError::UnsupportedOperation(
"k>5 isn't supported".to_string(),
Expand Down Expand Up @@ -348,33 +342,207 @@ pub fn decompose_by_div_from_coeffs(

// -----------------------------------------------------------------------------


#[cfg(test)]
mod tests {
use super::*;
use ark_ff::UniformRand;
use rand::thread_rng;

// Helper: produce random `Fr` vector of length n
fn random_fr_vec(n: usize) -> Vec<Fr> {
let mut rng = thread_rng();
(0..n).map(|_| Fr::rand(&mut rng)).collect()
}

// ---------------------
// MLEPolynomial Tests
// ---------------------
#[test]
fn test_mle_new_and_get() {
let evals = vec![Fr::one(), Fr::one()];
let mle = MLEPolynomial::new(evals.clone(), 2).unwrap(); // 2^2 = 4 => pads to length 4
assert_eq!(mle.evals.len(), 4);
// Indices 0 and 1 contain our original data
assert_eq!(mle.get(0), Some(&Fr::one()));
assert_eq!(mle.get(1), Some(&Fr::one()));
// Indices 2 and 3 should be zero after padding
assert_eq!(mle.get(2), Some(&Fr::zero()));
assert_eq!(mle.get(3), Some(&Fr::zero()));
// Out of range
assert_eq!(mle.get(4), None);
}

#[test]
fn test_mle_evaluate() {
// k=2 => 2^2=4 evals
// We'll pick a small set manually:
let evals = vec![
Fr::from(0u64),
Fr::from(1u64),
Fr::from(2u64),
Fr::from(3u64),
];
let mle = MLEPolynomial::new(evals.clone(), 2).unwrap();
// Evaluate at zs = [0, 0]
// evaluate_from_evals logic => we expect f[0,0] = evals[0] = 0
let z1 = Fr::zero();
let z2 = Fr::zero();
let val = mle.evaluate(&[z1, z2]).unwrap();
assert_eq!(val, Fr::from(0u64));

// Evaluate at zs = [1, 0]
// This picks out f[1,0] which should be evals[index=1<<0 + 1<<1?]
// Actually let's just rely on the direct Python-likeness:
// f[1,0] -> . . .
// Instead let's check that "some non-zero value"
let val2 = mle.evaluate(&[Fr::one(), Fr::zero()]).unwrap();
// we can compare to what we get from direct function call
let direct_val = evaluate_from_evals(&evals, &[Fr::one(), Fr::zero()]);
assert_eq!(val2, direct_val);
}

#[test]
fn test_basic_mle() {
fn test_mle_to_coeffs_round_trip() {
// Random example with k=3 => 8 evals
let mut rng = thread_rng();
let evals: Vec<Fr> = (0..8).map(|_| Fr::rand(&mut rng)).collect();
let mle = MLEPolynomial::new(evals.clone(), 3).unwrap();

// Round-trip: evals -> coeffs -> evals
let coeffs = mle.to_coeffs();
let evals_again = compute_evals_from_coeffs(&coeffs);
assert_eq!(mle.evals, evals_again, "Round trip mismatch");
}

#[test]
fn test_mle_decompose_by_div() {
let mut rng = thread_rng();

// Build a small example polynomial with k=3 => 2^3=8 evals
// k=3 => 8 evals
let evals: Vec<Fr> = (0..8).map(|_| Fr::rand(&mut rng)).collect();
let mle = MLEPolynomial::new(evals, 3).unwrap();
let mle = MLEPolynomial::new(evals.clone(), 3).unwrap();
let point: Vec<Fr> = (0..3).map(|_| Fr::rand(&mut rng)).collect();

let (quotients, eval) = mle.decompose_by_div(&point).unwrap();
assert_eq!(quotients.len(), 3);
// Final polynomial evaluation at that point
let direct_eval = evaluate_from_evals(&evals, &point);
assert_eq!(eval, direct_eval);
}

// Just call get
assert!(mle.get(8).is_none(), "Index out of range should be None");
assert!(mle.get(0).is_some());
// ----------------------------
// compute_monomials, eqs, etc.
// ----------------------------
#[test]
fn test_compute_monomials() {
// k=2 => rs has length 2 => result length = 4
let rs = [Fr::from(2u64), Fr::from(3u64)];
let monos = compute_monomials(&rs);
// Expect:
// index 0: 1
// index 1: r0 = 2
// index 2: r1 = 3
// index 3: r0*r1 = 6
assert_eq!(monos.len(), 4);
assert_eq!(monos[0], Fr::from(1u64));
assert_eq!(monos[1], Fr::from(2u64));
assert_eq!(monos[2], Fr::from(3u64));
assert_eq!(monos[3], Fr::from(6u64));
}

// Evaluate at some random points
let zs: Vec<Fr> = (0..3).map(|_| Fr::rand(&mut rng)).collect();
let val = mle.evaluate(&zs).unwrap();
println!("MLE evaluate = {:?}", val);
#[test]
fn test_eqs_over_hypercube() {
// k=1 => 2^1=2 results
// let rs[0] = r
// index 0 => x=0 => f(0) = 1- r
// index 1 => x=1 => f(1) = ?
let r = Fr::from(2u64);
let evals = eqs_over_hypercube(&[r]);
// index 0 => originally 1 => then j=0 => evals[1] = evals[0]*r => 2
// => evals[0] = evals[0] - evals[1] => 1 - 2 => -1
// So final => [-1, 2]
assert_eq!(evals.len(), 2);
assert_eq!(evals[0], Fr::from(-1i64));
assert_eq!(evals[1], Fr::from(2u64));
}

// Convert to coeffs and back
let coeffs = mle.to_coeffs();
let evals_again = compute_evals_from_coeffs(&coeffs);
assert_eq!(mle.evals, evals_again);
#[test]
fn test_eqs_over_hypercube_slow() {
// k=1 => eqs_over_hypercube_slow => we can compare with eqs_over_hypercube
let r = Fr::from(2u64);
let slow = eqs_over_hypercube_slow(1, &[r]).unwrap();
let fast = eqs_over_hypercube(&[r]);
assert_eq!(slow, fast);

// k=6 => error
let big = eqs_over_hypercube_slow(6, &[Fr::one(); 6]);
assert!(big.is_err(), "k>5 should be unsupported");
}

// ---------------------
// NTT and Evaluate Tests
// ---------------------
#[test]
fn test_ntt_core_round_trip() {
// We do compute_evals_from_coeffs -> compute_coeffs_from_evals -> original
let coeffs = random_fr_vec(8); // length=8 => k=3
let evals = compute_evals_from_coeffs(&coeffs);
let back_coeffs = compute_coeffs_from_evals(&evals);
assert_eq!(coeffs, back_coeffs, "NTT round-trip mismatch");
}

#[test]
fn test_evaluate_from_evals() {
// small manual example with k=2 => 4 evals
let evals = vec![Fr::from(10u64), Fr::from(20u64), Fr::from(30u64), Fr::from(40u64)];
let zs = [Fr::from(0u64), Fr::from(0u64)];
let val = evaluate_from_evals(&evals, &zs);
// for k=2, index=0 => that should be evals[0] = 10
assert_eq!(val, Fr::from(10u64));
}

#[test]
fn test_evaluate_from_evals_2() {
// same eval set, compare evaluate_from_evals_2 with evaluate_from_evals
let evals = vec![Fr::from(10u64), Fr::from(20u64), Fr::from(30u64), Fr::from(40u64)];
let zs = [Fr::from(1u64), Fr::from(0u64)];
let val1 = evaluate_from_evals(&evals, &zs);
let val2 = evaluate_from_evals_2(&evals, &zs);
assert_eq!(val1, val2);
}

#[test]
fn test_evaluate_from_coeffs() {
// small test
// let's define k=2 => 4 coeffs
let coeffs = vec![Fr::from(1u64), Fr::from(2u64), Fr::from(3u64), Fr::from(4u64)];
let zs = vec![Fr::one(), Fr::zero()];
let val = evaluate_from_coeffs(&coeffs, &zs);
// we can do it by hand if needed, or just trust the logic
// let's also compare with manual expand if we want
// f(X0, X1) = ...
// We'll just trust it's consistent for now
println!("evaluate_from_coeffs() => {:?}", val);
}

// ---------------------
// decompose_by_div_from_coeffs Tests
// ---------------------
#[test]
fn test_decompose_by_div_from_coeffs() {
// k=2 => 4 coeffs
let coeffs = vec![
Fr::from(1u64),
Fr::from(2u64),
Fr::from(3u64),
Fr::from(4u64),
];
let point = vec![Fr::one(), Fr::zero()]; // dimension=2

let dec = decompose_by_div_from_coeffs(coeffs.clone(), &point).unwrap();
assert_eq!(dec.quotients.len(), 2);
// dec.evaluation => f(1,0)
let direct_eval = evaluate_from_coeffs(&coeffs, &point);
assert_eq!(dec.evaluation, direct_eval);
}
}

0 comments on commit 3fc7465

Please sign in to comment.