diff --git a/alu_u32/src/mul/columns.rs b/alu_u32/src/mul/columns.rs index 53c2f9e7..575d1b73 100644 --- a/alu_u32/src/mul/columns.rs +++ b/alu_u32/src/mul/columns.rs @@ -19,6 +19,7 @@ pub struct Mul32Cols { pub is_real: T, pub counter: T, + pub counter_mult: T, } pub const NUM_MUL_COLS: usize = size_of::>(); diff --git a/alu_u32/src/mul/mod.rs b/alu_u32/src/mul/mod.rs index 6d24b1c9..abee5705 100644 --- a/alu_u32/src/mul/mod.rs +++ b/alu_u32/src/mul/mod.rs @@ -3,15 +3,18 @@ extern crate alloc; use alloc::vec; use alloc::vec::Vec; use columns::{Mul32Cols, MUL_COL_MAP, NUM_MUL_COLS}; +use core::iter::Sum; +use core::ops::Mul; +use itertools::iproduct; use valida_bus::MachineWithGeneralBus; use valida_cpu::MachineWithCpuChip; -use valida_machine::{instructions, Chip, Instruction, Interaction, Operands, Word}; +use valida_machine::{instructions, BusArgument, Chip, Instruction, Interaction, Operands, Word}; use valida_opcodes::MUL32; use valida_range::MachineWithRangeChip; use core::borrow::BorrowMut; use p3_air::VirtualPairCol; -use p3_field::PrimeField; +use p3_field::{PrimeField, PrimeField64}; use p3_matrix::dense::RowMajorMatrix; pub mod columns; @@ -29,7 +32,7 @@ pub struct Mul32Chip { impl Chip for Mul32Chip where - F: PrimeField, + F: PrimeField64, M: MachineWithGeneralBus, { fn generate_trace(&self, _machine: &M) -> RowMajorMatrix { @@ -44,6 +47,7 @@ where let row = &mut values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS]; let cols: &mut Mul32Cols = row.borrow_mut(); cols.counter = F::from_canonical_usize(i + 1); + cols.is_real = F::ONE; self.op_to_row(op, cols); } @@ -56,6 +60,16 @@ where self.op_to_row(&dummy_op, cols); } + // Set counter multiplicity + let num_rows = values.len() / NUM_MUL_COLS; + let mut mult = vec![F::ZERO; num_rows]; + for i in 0..num_rows { + let r = values[MUL_COL_MAP.r + i * NUM_MUL_COLS].as_canonical_u64(); + let s = values[MUL_COL_MAP.s + i * NUM_MUL_COLS].as_canonical_u64(); + mult[r as usize] += F::ONE; + mult[s as usize] += F::ONE; + } + RowMajorMatrix { values, width: NUM_MUL_COLS, @@ -82,8 +96,26 @@ where } fn local_sends(&self) -> Vec> { - // TODO - vec![] + let send_r = Interaction { + fields: vec![VirtualPairCol::single_main(MUL_COL_MAP.r)], + count: VirtualPairCol::one(), + argument_index: BusArgument::Local(0), + }; + let send_s = Interaction { + fields: vec![VirtualPairCol::single_main(MUL_COL_MAP.s)], + count: VirtualPairCol::one(), + argument_index: BusArgument::Local(0), + }; + vec![send_r, send_s] + } + + fn local_receives(&self) -> Vec> { + let receives = Interaction { + fields: vec![VirtualPairCol::single_main(MUL_COL_MAP.counter)], + count: VirtualPairCol::single_main(MUL_COL_MAP.counter_mult), + argument_index: BusArgument::Local(0), + }; + vec![receives] } } @@ -97,11 +129,60 @@ impl Mul32Chip { cols.input_1 = b.transform(F::from_canonical_u8); cols.input_2 = c.transform(F::from_canonical_u8); cols.output = a.transform(F::from_canonical_u8); + + // Compute $r$ to satisfy $pi - z = 2^32 r$. + let base_m32: [u64; 4] = [1, 1 << 8, 1 << 16, 1 << 24]; + let pi = pi_m::<4, u64, u64>( + &base_m32, + b.transform(|x| x as u64), + c.transform(|x| x as u64), + ); + let z: u32 = (*a).into(); + let z: u64 = z as u64; + let r = (pi - z) / (1u64 << 32); + let r = r as u32; + cols.r = F::from_canonical_u32(r); + + // Compute $s$ to satisfy $pi' - z' = 2^16 s$. + let base_m16: [u32; 2] = [1, 1 << 8]; + let pi_prime = pi_m::<2, u32, u32>( + &base_m16, + b.transform(|x| x as u32), + c.transform(|x| x as u32), + ); + let z_prime = a[3] as u32 + (1u32 << 8) * a[2] as u32; + let z_prime: u32 = z_prime.into(); + let s = (pi_prime - z_prime) / (1u32 << 16); + cols.s = F::from_canonical_u32(s); } } } } +fn pi_m + Clone + Sum>( + base: &[O; N], + input_1: Word, + input_2: Word, +) -> O { + iproduct!(0..N, 0..N) + .filter(|(i, j)| i + j < N) + .map(|(i, j)| base[i + j].clone() * input_1[3 - i] * input_2[3 - j]) + .sum() +} + +fn sigma_m + Clone + Sum>( + base: &[O], + input: Word, +) -> O { + input + .into_iter() + .rev() + .take(N) + .enumerate() + .map(|(i, x)| base[i].clone() * x) + .sum() +} + pub trait MachineWithMul32Chip: MachineWithCpuChip { fn mul_u32(&self) -> &Mul32Chip; fn mul_u32_mut(&mut self) -> &mut Mul32Chip; diff --git a/alu_u32/src/mul/stark.rs b/alu_u32/src/mul/stark.rs index 8dc90b7a..c25a0ac9 100644 --- a/alu_u32/src/mul/stark.rs +++ b/alu_u32/src/mul/stark.rs @@ -1,8 +1,6 @@ use super::columns::Mul32Cols; -use super::Mul32Chip; +use super::{pi_m, sigma_m, Mul32Chip}; use core::borrow::Borrow; -use itertools::iproduct; -use valida_machine::Word; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, PrimeField}; @@ -21,23 +19,26 @@ where let next: &Mul32Cols = main.row_slice(1).borrow(); // Limb weights modulo 2^32 - let base_m = [1, 1 << 8, 1 << 16, 1 << 24].map(AB::Expr::from_canonical_u32); + let base_m32 = [1, 1 << 8, 1 << 16, 1 << 24].map(AB::Expr::from_canonical_u32); + + // Limb weights modulo 2^16 + let base_m16 = [1, 1 << 8].map(AB::Expr::from_canonical_u32); // Partially reduced summation of input product limbs (mod 2^32) - let pi = pi_m::<4, AB>(&base_m, local.input_1, local.input_2); + let pi = pi_m::<4, AB::Var, AB::Expr>(&base_m32, local.input_1, local.input_2); // Partially reduced summation of output limbs (mod 2^32) - let sigma = sigma_m::<4, AB>(&base_m, local.output); + let sigma = sigma_m::<4, AB::Var, AB::Expr>(&base_m32, local.output); // Partially reduced summation of input product limbs (mod 2^16) - let pi_prime = pi_m::<2, AB>(&base_m[..2], local.input_1, local.input_2); + let pi_prime = pi_m::<2, AB::Var, AB::Expr>(&base_m16, local.input_1, local.input_2); // Partially reduced summation of output limbs (mod 2^16) - let sigma_prime = sigma_m::<2, AB>(&base_m[..2], local.output); + let sigma_prime = sigma_m::<2, AB::Var, AB::Expr>(&base_m16, local.output); // Congruence checks - builder.assert_eq(pi - sigma, local.r * AB::Expr::TWO); - builder.assert_eq(pi_prime - sigma_prime, local.s * base_m[2].clone()); + builder.assert_eq(pi - sigma, local.r * AB::Expr::from_wrapped_u64(1 << 32)); + builder.assert_eq(pi_prime - sigma_prime, local.s * base_m32[2].clone()); // Range check counter builder @@ -52,24 +53,3 @@ where .assert_eq(local.counter, AB::Expr::from_canonical_u32(1 << 10)); } } - -fn pi_m( - base: &[AB::Expr], - input_1: Word, - input_2: Word, -) -> AB::Expr { - iproduct!(0..N, 0..N) - .filter(|(i, j)| i + j < N) - .map(|(i, j)| base[i + j].clone() * input_1[3 - i] * input_2[3 - j]) - .sum() -} - -fn sigma_m(base: &[AB::Expr], input: Word) -> AB::Expr { - input - .into_iter() - .rev() - .take(N) - .enumerate() - .map(|(i, x)| base[i].clone() * x) - .sum() -} diff --git a/basic/Cargo.toml b/basic/Cargo.toml index 095bf00b..2d8742a8 100644 --- a/basic/Cargo.toml +++ b/basic/Cargo.toml @@ -18,6 +18,7 @@ valida-memory = { path = "../memory" } valida-output = { path = "../output" } valida-program = { path = "../program" } valida-range = { path = "../range" } +valida-opcodes = { path = "../opcodes" } p3-maybe-rayon = { path = "../../Plonky3/maybe-rayon" } p3-baby-bear = { path = "../../Plonky3/baby-bear" } byteorder = "1.4.3" diff --git a/basic/tests/test_interpreter.rs b/basic/tests/test_interpreter.rs index dd4f77d1..3efbd5b1 100644 --- a/basic/tests/test_interpreter.rs +++ b/basic/tests/test_interpreter.rs @@ -6,18 +6,8 @@ use std::io::Read; use std::process::{Command, Stdio}; use byteorder::{LittleEndian, WriteBytesExt}; -use valida_alu_u32::{add::Add32Instruction, div::Div32Instruction}; -use valida_basic::BasicMachine; -use valida_cpu::{ - BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction, - ReadAdviceInstruction, StopInstruction, Store32Instruction, -}; -use valida_machine::Instruction; use valida_machine::{InstructionWord, Operands, ProgramROM}; -use valida_output::WriteInstruction; - -type Val = BabyBear; -type Challenge = BabyBear; +use valida_opcodes::{ADD32, BEQ, BNE, DIV32, IMM32, JAL, JALV, READ_ADVICE, STOP, STORE32, WRITE}; #[test] fn run_fibonacci() { @@ -82,63 +72,63 @@ fn build_fibonacci_program_rom() -> ProgramROM { //... program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-4, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: READ_ADVICE, operands: Operands([0, 1, -8, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -16, -8, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-20, 0, 0, 0, 28]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: JAL, operands: Operands([-28, fib_bb0, -28, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -12, -24, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, 4, -12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: WRITE, operands: Operands([0, 4, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: DIV32, operands: Operands([4, 4, 256, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: WRITE, operands: Operands([0, 4, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: DIV32, operands: Operands([4, 4, 256, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: WRITE, operands: Operands([0, 4, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: DIV32, operands: Operands([4, 4, 256, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: WRITE, operands: Operands([0, 4, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STOP, operands: Operands::default(), }, ]); @@ -152,23 +142,23 @@ fn build_fibonacci_program_rom() -> ProgramROM { // beq .LBB0_1, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -4, 12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-8, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-12, 0, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-16, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_1, 0, 0, 0, 0]), }, ]); @@ -178,11 +168,11 @@ fn build_fibonacci_program_rom() -> ProgramROM { // beq .LBB0_4, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: BNE, operands: Operands([fib_bb0_2, -16, -4, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_4, 0, 0, 0, 0]), }, ]); @@ -194,19 +184,19 @@ fn build_fibonacci_program_rom() -> ProgramROM { // beq .LBB0_3, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: ADD32, operands: Operands([-20, -8, -12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -8, -12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -12, -20, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_3, 0, 0, 0, 0]), }, ]); @@ -216,11 +206,11 @@ fn build_fibonacci_program_rom() -> ProgramROM { // beq .LBB0_1, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: ADD32, operands: Operands([-16, -16, 1, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_1, 0, 0, 0, 0]), }, ]); @@ -230,11 +220,11 @@ fn build_fibonacci_program_rom() -> ProgramROM { // jalv -4(fp), 0(fp), 8(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, 4, -8, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: JALV, operands: Operands([-4, 0, 8, 0, 0]), }, ]); diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index f82978f0..0a5a6ff7 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -1,13 +1,11 @@ use p3_baby_bear::BabyBear; -use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip}; +use valida_alu_u32::add::MachineWithAdd32Chip; use valida_basic::BasicMachine; -use valida_cpu::{ - BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction, - MachineWithCpuChip, StopInstruction, Store32Instruction, -}; +use valida_cpu::MachineWithCpuChip; use valida_machine::config::StarkConfigImpl; -use valida_machine::{Instruction, InstructionWord, Machine, Operands, ProgramROM, Word}; +use valida_machine::{InstructionWord, Machine, Operands, ProgramROM, Word}; use valida_memory::MachineWithMemoryChip; +use valida_opcodes::{ADD32, BEQ, BNE, IMM32, JAL, JALV, STOP, STORE32}; use valida_program::MachineWithProgramChip; use p3_challenger::DuplexChallenger; @@ -46,35 +44,35 @@ fn prove_fibonacci() { //... program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-4, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-8, 0, 0, 0, 25]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -16, -8, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-20, 0, 0, 0, 28]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: JAL, operands: Operands([-28, fib_bb0, -28, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -12, -24, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, 4, -12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STOP, operands: Operands::default(), }, ]); @@ -88,23 +86,23 @@ fn prove_fibonacci() { // beq .LBB0_1, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -4, 12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-8, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-12, 0, 0, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: IMM32, operands: Operands([-16, 0, 0, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_1, 0, 0, 0, 0]), }, ]); @@ -114,11 +112,11 @@ fn prove_fibonacci() { // beq .LBB0_4, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: BNE, operands: Operands([fib_bb0_2, -16, -4, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_4, 0, 0, 0, 0]), }, ]); @@ -130,19 +128,19 @@ fn prove_fibonacci() { // beq .LBB0_3, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: ADD32, operands: Operands([-20, -8, -12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -8, -12, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, -12, -20, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_3, 0, 0, 0, 0]), }, ]); @@ -152,11 +150,11 @@ fn prove_fibonacci() { // beq .LBB0_1, 0(fp), 0(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: ADD32, operands: Operands([-16, -16, 1, 0, 1]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: BEQ, operands: Operands([fib_bb0_1, 0, 0, 0, 0]), }, ]); @@ -166,11 +164,11 @@ fn prove_fibonacci() { // jalv -4(fp), 0(fp), 8(fp) program.extend([ InstructionWord { - opcode: >>::OPCODE, + opcode: STORE32, operands: Operands([0, 4, -8, 0, 0]), }, InstructionWord { - opcode: >>::OPCODE, + opcode: JALV, operands: Operands([-4, 0, 8, 0, 0]), }, ]); diff --git a/machine/src/core.rs b/machine/src/core.rs index 2015f495..d3fff351 100644 --- a/machine/src/core.rs +++ b/machine/src/core.rs @@ -74,7 +74,7 @@ impl Mul for Word { fn mul(self, other: Self) -> Self { let b: u32 = self.into(); let c: u32 = other.into(); - let res = b * c; + let res = b.overflowing_mul(c).0; res.into() } } diff --git a/machine/src/lib.rs b/machine/src/lib.rs index c383ec48..76202cbb 100644 --- a/machine/src/lib.rs +++ b/machine/src/lib.rs @@ -32,7 +32,7 @@ pub trait Instruction { fn execute(state: &mut M, ops: Operands); } -#[derive(Copy, Clone, Default)] +#[derive(Copy, Clone, Default, Debug)] pub struct InstructionWord { pub opcode: u32, pub operands: Operands, @@ -47,7 +47,7 @@ impl InstructionWord { } } -#[derive(Copy, Clone, Default)] +#[derive(Copy, Clone, Default, Debug)] pub struct Operands(pub [F; OPERAND_ELEMENTS]); impl Operands {