From 68c60bad0b6c6bf8dcc0b89792a787b8b8d88968 Mon Sep 17 00:00:00 2001 From: Austin Abell Date: Tue, 10 Dec 2024 10:49:11 -0500 Subject: [PATCH] bigint2 acceleration update (#6) * WIP bigint2 impl * reverse byteorder * update to latest version * loosen bytemuck version requirement * update to rc version * bump proc-macro2 * wip changes with extra checks * uncomment ignored test * put back precomputed tables logic * point to risc0 2572 * bump to latest version and commit * bump bigint2 commit to latest * bump version to 1.2.0-rc.1 * swap pointers instead of copying data * bump to 1.2 * Update k256/src/arithmetic/projective.rs * Update k256/src/arithmetic/projective.rs Co-authored-by: Frank Laub * Update k256/src/arithmetic/projective.rs * Update k256/src/arithmetic/projective.rs * Update k256/src/arithmetic/projective.rs * remove dead code from bigint1 --------- Co-authored-by: Victor Graf Co-authored-by: Frank Laub --- Cargo.lock | 55 ++++++++++++++-- k256/Cargo.toml | 4 ++ k256/src/arithmetic.rs | 59 +++++++++++++++++ k256/src/arithmetic/mul.rs | 38 +++++++++++ k256/src/arithmetic/projective.rs | 104 +++++++++++------------------- 5 files changed, 189 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 69470b94..3295dd0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,6 +132,12 @@ version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" +[[package]] +name = "bytemuck" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" + [[package]] name = "cast" version = "0.3.0" @@ -532,6 +538,12 @@ dependencies = [ "digest", ] +[[package]] +name = "include_bytes_aligned" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee796ad498c8d9a1d68e477df8f754ed784ef875de1414ebdaf169f70a6a784" + [[package]] name = "instant" version = "0.1.12" @@ -592,6 +604,7 @@ name = "k256" version = "0.13.2" dependencies = [ "blobby", + "bytemuck", "cfg-if", "criterion", "ecdsa", @@ -603,6 +616,7 @@ dependencies = [ "once_cell", "proptest", "rand_core", + "risc0-bigint2", "serdect", "sha2", "sha3", @@ -860,9 +874,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.59" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -1009,6 +1023,16 @@ dependencies = [ "subtle", ] +[[package]] +name = "risc0-bigint2" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f4c185a3bfaee681eed5bfac1440128184bf0b6544c345fb4d7bd4317c909fb" +dependencies = [ + "include_bytes_aligned", + "stability", +] + [[package]] name = "rustix" version = "0.36.16" @@ -1102,7 +1126,7 @@ checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1190,6 +1214,16 @@ dependencies = [ "der", ] +[[package]] +name = "stability" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d904e7009df136af5297832a3ace3370cd14ff1546a232f4f185036c2736fcac" +dependencies = [ + "quote", + "syn 2.0.20", +] + [[package]] name = "subtle" version = "2.4.1" @@ -1207,6 +1241,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb8d4cebc40aa517dfb69618fa647a346562e67228e2236ae0042ee6ac14775" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "tap" version = "1.0.1" @@ -1307,7 +1352,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-shared", ] @@ -1329,7 +1374,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/k256/Cargo.toml b/k256/Cargo.toml index 130c6427..26872eae 100644 --- a/k256/Cargo.toml +++ b/k256/Cargo.toml @@ -29,6 +29,10 @@ serdect = { version = "0.2", optional = true, default-features = false } sha2 = { version = "0.10", optional = true, default-features = false } signature = { version = "2", optional = true } +[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dependencies] +risc0-bigint2 = { version = "1.2", features = ["unstable"] } +bytemuck = "1" + [dev-dependencies] blobby = "0.3" ecdsa-core = { version = "0.16", package = "ecdsa", default-features = false, features = ["dev"] } diff --git a/k256/src/arithmetic.rs b/k256/src/arithmetic.rs index b6ba5673..d49b189c 100644 --- a/k256/src/arithmetic.rs +++ b/k256/src/arithmetic.rs @@ -33,6 +33,65 @@ pub(crate) const CURVE_EQUATION_B: FieldElement = FieldElement::from_bytes_unche 0, 0, 0, 0, 0, 0, 0, CURVE_EQUATION_B_SINGLE as u8, ]); +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use risc0_bigint2::ec; + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use ecdsa_core::elliptic_curve::group::prime::PrimeCurveAffine; + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +pub(crate) fn affine_to_bigint2_affine( + affine: &AffinePoint, +) -> ec::AffinePoint<8, ec::Secp256k1Curve> { + if affine.is_identity().into() { + return ec::AffinePoint::IDENTITY; + } + let mut buffer = [[0u32; 8]; 2]; + // TODO this could potentially read from internal repr (check risc0 felt endianness) + let mut x_bytes: [u8; 32] = affine.x.to_bytes().into(); + let mut y_bytes: [u8; 32] = affine.y.to_bytes().into(); + x_bytes.reverse(); + y_bytes.reverse(); + + let x = bytemuck::cast::<_, [u32; 8]>(x_bytes); + let y = bytemuck::cast::<_, [u32; 8]>(y_bytes); + ec::AffinePoint::new_unchecked(x, y) +} + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +pub(crate) fn projective_to_affine(p: &ProjectivePoint) -> ec::AffinePoint<8, ec::Secp256k1Curve> { + let aff = p.to_affine(); + affine_to_bigint2_affine(&aff) +} + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +pub(crate) fn affine_to_projective( + affine: &ec::AffinePoint<8, ec::Secp256k1Curve>, +) -> ProjectivePoint { + if let Some(value) = affine.as_u32s() { + let mut x = bytemuck::cast::<_, [u8; 32]>(value[0]); + let mut y = bytemuck::cast::<_, [u8; 32]>(value[1]); + x.reverse(); + y.reverse(); + + crate::AffinePoint::new( + FieldElement::from_bytes_unchecked(&x), + FieldElement::from_bytes_unchecked(&y), + ) + .into() + } else { + ProjectivePoint::IDENTITY + } +} + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +pub(crate) fn scalar_to_words(s: &Scalar) -> [u32; 8] { + let mut bytes: [u8; 32] = s.to_bytes().into(); + // U256 is big endian, need to flip to little endian. + bytes.reverse(); + bytemuck::cast::<_, [u32; 8]>(bytes) +} + #[cfg(test)] mod tests { use super::CURVE_EQUATION_B; diff --git a/k256/src/arithmetic/mul.rs b/k256/src/arithmetic/mul.rs index 306752de..42739c10 100644 --- a/k256/src/arithmetic/mul.rs +++ b/k256/src/arithmetic/mul.rs @@ -339,6 +339,44 @@ impl LinearCombination<[(ProjectivePoint, Scalar)]> for ProjectivePoint { } } +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use risc0_bigint2::ec; + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use super::{affine_to_projective, projective_to_affine, scalar_to_words}; + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +fn lincomb( + xks: &[(ProjectivePoint, Scalar)], + tables: &mut [(LookupTable, LookupTable)], + digits: &mut [(Radix16Decomposition<33>, Radix16Decomposition<33>)], +) -> ProjectivePoint { + let mut xks_iter = xks + .iter() + .map(|(p, s)| (projective_to_affine(p), scalar_to_words(s))); + let Some((affine, scalar)) = xks_iter.next() else { + return ProjectivePoint::IDENTITY; + }; + + let mut result = ec::AffinePoint::new_unchecked([0u32; 8], [0u32; 8]); + affine.mul(&scalar, &mut result); + let mut buffer = ec::AffinePoint::new_unchecked([0u32; 8], [0u32; 8]); + let mut mul_buffer = ec::AffinePoint::new_unchecked([0u32; 8], [0u32; 8]); + + let mut result_ptr = &mut result; + let mut buffer_ptr = &mut buffer; + + for (point, scalar) in xks_iter { + point.mul(&scalar, &mut mul_buffer); + result_ptr.add(&mul_buffer, &mut buffer_ptr); + core::mem::swap(&mut result_ptr, &mut buffer_ptr); + } + + // Convert the final result back to projective form + affine_to_projective(result_ptr) +} + +#[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] fn lincomb( xks: &[(ProjectivePoint, Scalar)], tables: &mut [(LookupTable, LookupTable)], diff --git a/k256/src/arithmetic/projective.rs b/k256/src/arithmetic/projective.rs index 8fe0cf70..07fb781d 100644 --- a/k256/src/arithmetic/projective.rs +++ b/k256/src/arithmetic/projective.rs @@ -25,6 +25,14 @@ use elliptic_curve::{ #[cfg(feature = "alloc")] use alloc::vec::Vec; +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use risc0_bigint2::ec; + +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use super::{ + affine_to_bigint2_affine, affine_to_projective, projective_to_affine, scalar_to_words, +}; + #[rustfmt::skip] const ENDOMORPHISM_BETA: FieldElement = FieldElement::from_bytes_unchecked(&[ 0x7a, 0xe9, 0x6a, 0x2b, 0x65, 0x7c, 0x07, 0x10, @@ -93,6 +101,34 @@ impl ProjectivePoint { } /// Returns `self + other`. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + fn add(&self, other: &ProjectivePoint) -> ProjectivePoint { + let b = other.to_affine(); + self.add_mixed(&b) + } + + /// Returns `self + other`. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + fn add_mixed(&self, other: &AffinePoint) -> ProjectivePoint { + let a = projective_to_affine(self); + let b = affine_to_bigint2_affine(other); + let mut result = ec::AffinePoint::IDENTITY; + a.add(&b, &mut result); + affine_to_projective(&result) + } + + /// Doubles this point. + #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] + #[inline] + pub fn double(&self) -> ProjectivePoint { + let a = projective_to_affine(self); + let mut result = ec::AffinePoint::IDENTITY; + a.double(&mut result); + affine_to_projective(&result) + } + + /// Returns `self + other`. + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] fn add(&self, other: &ProjectivePoint) -> ProjectivePoint { // We implement the complete addition formula from Renes-Costello-Batina 2015 // (https://eprint.iacr.org/2015/1060 Algorithm 7). @@ -108,30 +144,6 @@ impl ProjectivePoint { let yz_pairs = ((self.y + &self.z) * &(other.y + &other.z)) + &n_yy_zz; let xz_pairs = ((self.x + &self.z) * &(other.x + &other.z)) + &n_xx_zz; - if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { - // Same as below, but using mul_single instead of repeated addition to get small - // multiplications and normalize_weak is removed. - let bzz3 = zz.mul_single(CURVE_EQUATION_B_SINGLE * 3); - - let yy_m_bzz3 = yy + &bzz3.negate(1); - let yy_p_bzz3 = yy + &bzz3; - - let byz3 = &yz_pairs.mul_single(CURVE_EQUATION_B_SINGLE * 3); - - let xx3 = xx.mul_single(3); - let bxx9 = xx3.mul_single(CURVE_EQUATION_B_SINGLE * 3); - - let new_x = (xy_pairs * &yy_m_bzz3) + &(byz3 * &xz_pairs).negate(1); // m1 - let new_y = (yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs); - let new_z = (yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs); - - return ProjectivePoint { - x: new_x, - y: new_y, - z: new_z, - }; - } - let bzz = zz.mul_single(CURVE_EQUATION_B_SINGLE); let bzz3 = (bzz.double() + &bzz).normalize_weak(); @@ -161,6 +173,7 @@ impl ProjectivePoint { } /// Returns `self + other`. + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] fn add_mixed(&self, other: &AffinePoint) -> ProjectivePoint { // We implement the complete addition formula from Renes-Costello-Batina 2015 // (https://eprint.iacr.org/2015/1060 Algorithm 8). @@ -171,29 +184,6 @@ impl ProjectivePoint { let yz_pairs = (other.y * &self.z) + &self.y; let xz_pairs = (other.x * &self.z) + &self.x; - if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { - // Same as below, but using mul_single instead of repeated addition to get small - // multiplications and normalize_weak is removed. - let bzz3 = self.z.mul_single(CURVE_EQUATION_B_SINGLE * 3); - - let yy_m_bzz3 = yy + &bzz3.negate(1); - let yy_p_bzz3 = yy + &bzz3; - - let n_byz3 = - &yz_pairs.mul(&FieldElement::from_i64(CURVE_EQUATION_B_SINGLE as i64 * -3)); - - let xx3 = xx.mul_single(3); - let bxx9 = xx3.mul_single(CURVE_EQUATION_B_SINGLE * 3); - - let mut ret = ProjectivePoint { - x: (xy_pairs * &yy_m_bzz3) + &(n_byz3 * &xz_pairs), - y: (yy_p_bzz3 * &yy_m_bzz3) + &(bxx9 * &xz_pairs), - z: (yz_pairs * &yy_p_bzz3) + &(xx3 * &xy_pairs), - }; - ret.conditional_assign(self, other.is_identity()); - return ret; - } - let bzz = &self.z.mul_single(CURVE_EQUATION_B_SINGLE); let bzz3 = (bzz.double() + bzz).normalize_weak(); @@ -221,6 +211,7 @@ impl ProjectivePoint { } /// Doubles this point. + #[cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))] #[inline] pub fn double(&self) -> ProjectivePoint { // We implement the complete addition formula from Renes-Costello-Batina 2015 @@ -230,25 +221,6 @@ impl ProjectivePoint { let zz = self.z.square(); let xy2 = (self.x * &self.y).double(); - if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { - // Same as below, but using mul_single instead of repeated addition to get small - // multiplications and normalize_weak is removed. - let bzz3 = zz.mul_single(CURVE_EQUATION_B_SINGLE * 3); - let n_bzz9 = zz.mul(&FieldElement::from_i64(CURVE_EQUATION_B_SINGLE as i64 * -9)); - - let yy_m_bzz9 = yy + &n_bzz9; - let yy_p_bzz3 = yy + &bzz3; - - let yy_zz = yy * &zz; - let t = yy_zz.mul_single(CURVE_EQUATION_B_SINGLE * 24); - - return ProjectivePoint { - x: xy2 * &yy_m_bzz9, - y: ((yy_m_bzz9 * &yy_p_bzz3) + &t), - z: ((yy * &self.y) * &self.z).mul_single(8), - }; - } - let bzz = &zz.mul_single(CURVE_EQUATION_B_SINGLE); let bzz3 = (bzz.double() + bzz).normalize_weak(); let bzz9 = (bzz3.double() + &bzz3).normalize_weak();