From 3c506fd9526105596c77beec2c316a0997af4431 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Thu, 18 Apr 2024 16:35:16 +0200 Subject: [PATCH] feat: expmod with variable modulus (#1090) * chore: remove unused arguments * feat: allow non-default modulus * feat: implement non-default modulus arithmetic * test: tests for non-default arithmetic * remove debug calls * chore: clean only if set * feat: add zero check with custom modulus * feat: change quotient length in case of var modulus * refactor: subtraction padding into callable function * feat: implement variable modulus subtraction * feat: implement variable modulus equality assertion * test: variable modulus tests * feat: implement custom mod exp * refactor: rename methods * refactor: use single impl of mulmod and checkzero * feat: add automatic lazy reduction for var-mod addition * refactor: make var-mod sub private It is difficult to implement automatic reduction here due to the recursive dependency of methods. As the goal is to provide only very limited API for now, then make it private for now. * refactor: move parameters * refactor: move var-mod methods to Field * refactor: subPaddingHint to hints file * docs: add package documentation * feat: implement fixed-mod exp * test: make tests cheaper to run * fix: handle mul hint edge case when modulus is zero * feat: implement expmod precompile * perf: select first bit in expmod instead of mul * docs: indicate implementation of Expmod --------- Co-authored-by: Youssef El Housni --- std/evmprecompiles/05-expmod.go | 30 +++++ std/evmprecompiles/05-expmod_test.go | 86 +++++++++++++ std/evmprecompiles/doc.go | 2 +- std/math/emulated/composition.go | 9 +- std/math/emulated/composition_test.go | 2 +- std/math/emulated/custommod.go | 99 ++++++++++++++ std/math/emulated/custommod_test.go | 171 +++++++++++++++++++++++++ std/math/emulated/doc.go | 12 ++ std/math/emulated/element_test.go | 39 ++++++ std/math/emulated/emparams/emparams.go | 38 ++++++ std/math/emulated/field_assert.go | 2 +- std/math/emulated/field_mul.go | 74 +++++++++-- std/math/emulated/field_ops.go | 5 +- std/math/emulated/hints.go | 49 +++++++ 14 files changed, 593 insertions(+), 25 deletions(-) create mode 100644 std/evmprecompiles/05-expmod_test.go create mode 100644 std/math/emulated/custommod.go create mode 100644 std/math/emulated/custommod_test.go diff --git a/std/evmprecompiles/05-expmod.go b/std/evmprecompiles/05-expmod.go index 6b1eb16123..442ba88095 100644 --- a/std/evmprecompiles/05-expmod.go +++ b/std/evmprecompiles/05-expmod.go @@ -1 +1,31 @@ package evmprecompiles + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" +) + +// Expmod implements [MODEXP] precompile contract at address 0x05. +// +// Internally, uses 4k elements for representing the base, exponent and modulus, +// upper bounding the sizes of the inputs. The runtime is constant regardless of +// the actual length of the inputs. +// +// [MODEXP]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/expmod/index.html +func Expmod(api frontend.API, base, exp, modulus *emulated.Element[emparams.Mod1e4096]) *emulated.Element[emparams.Mod1e4096] { + // x^0 = 1 + // x mod 0 = 0 + f, err := emulated.NewField[emparams.Mod1e4096](api) + if err != nil { + panic(fmt.Sprintf("new field: %v", err)) + } + // in case modulus is zero, then need to compute with dummy values and return zero as a result + isZeroMod := f.IsZero(modulus) + modulus = f.Select(isZeroMod, f.One(), modulus) + res := f.ModExp(base, exp, modulus) + res = f.Select(isZeroMod, f.Zero(), res) + return res +} diff --git a/std/evmprecompiles/05-expmod_test.go b/std/evmprecompiles/05-expmod_test.go new file mode 100644 index 0000000000..5de7d95bcb --- /dev/null +++ b/std/evmprecompiles/05-expmod_test.go @@ -0,0 +1,86 @@ +package evmprecompiles + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/test" +) + +type expmodCircuit struct { + Base emulated.Element[emparams.Mod1e4096] + Exp emulated.Element[emparams.Mod1e4096] + Mod emulated.Element[emparams.Mod1e4096] + Result emulated.Element[emparams.Mod1e4096] + edgeCases bool +} + +func (c *expmodCircuit) Define(api frontend.API) error { + res := Expmod(api, &c.Base, &c.Exp, &c.Mod) + f, err := emulated.NewField[emparams.Mod1e4096](api) + if err != nil { + return fmt.Errorf("new field: %w", err) + } + if c.edgeCases { + // cannot use ModAssertIsEqual for edge cases. But the output is either + // 0 or 1 so can use AssertIsEqual + f.AssertIsEqual(res, &c.Result) + } else { + // for random case need to use ModAssertIsEqual + f.ModAssertIsEqual(&c.Result, res, &c.Mod) + } + return nil +} + +func testInstance(edgeCases bool, base, exp, modulus, result *big.Int) error { + circuit := &expmodCircuit{edgeCases: edgeCases} + assignment := &expmodCircuit{ + Base: emulated.ValueOf[emparams.Mod1e4096](base), + Exp: emulated.ValueOf[emparams.Mod1e4096](exp), + Mod: emulated.ValueOf[emparams.Mod1e4096](modulus), + Result: emulated.ValueOf[emparams.Mod1e4096](result), + } + return test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) +} + +func TestRandomInstance(t *testing.T) { + assert := test.NewAssert(t) + for _, bits := range []int{256, 512, 1024, 2048, 4096} { + assert.Run(func(assert *test.Assert) { + modulus := new(big.Int).Lsh(big.NewInt(1), uint(bits)) + base, _ := rand.Int(rand.Reader, modulus) + exp, _ := rand.Int(rand.Reader, modulus) + res := new(big.Int).Exp(base, exp, modulus) + err := testInstance(false, base, exp, modulus, res) + assert.NoError(err) + }, fmt.Sprintf("random-%d", bits)) + } +} + +func TestEdgeCases(t *testing.T) { + assert := test.NewAssert(t) + testCases := []struct { + base, exp, modulus, result *big.Int + }{ + {big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)}, // 0^0 = 0 mod 0 + {big.NewInt(0), big.NewInt(0), big.NewInt(1), big.NewInt(1)}, // 0^0 = 1 mod 1 + {big.NewInt(0), big.NewInt(0), big.NewInt(123), big.NewInt(1)}, // 0^0 = 1 mod 123 + {big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 0 + {big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 1 + {big.NewInt(0), big.NewInt(123), big.NewInt(123), big.NewInt(0)}, // 0^123 = 0 mod 123 + {big.NewInt(123), big.NewInt(0), big.NewInt(123), big.NewInt(1)}, // 123^0 = 1 mod 123 + + } + for i, tc := range testCases { + assert.Run(func(assert *test.Assert) { + err := testInstance(true, tc.base, tc.exp, tc.modulus, tc.result) + assert.NoError(err) + }, fmt.Sprintf("edge-%d", i)) + } +} diff --git a/std/evmprecompiles/doc.go b/std/evmprecompiles/doc.go index 7c515eaa51..9b7dc431a8 100644 --- a/std/evmprecompiles/doc.go +++ b/std/evmprecompiles/doc.go @@ -7,7 +7,7 @@ // 2. SHA256 ❌ -- in progress // 3. RIPEMD160 ❌ -- postponed // 4. ID ❌ -- trivial to implement without function -// 5. EXPMOD ❌ -- in progress +// 5. EXPMOD ✅ -- function [Expmod] // 6. BN_ADD ✅ -- function [ECAdd] // 7. BN_MUL ✅ -- function [ECMul] // 8. SNARKV ✅ -- function [ECPair] diff --git a/std/math/emulated/composition.go b/std/math/emulated/composition.go index 79b0216a80..bce35e7442 100644 --- a/std/math/emulated/composition.go +++ b/std/math/emulated/composition.go @@ -66,10 +66,7 @@ func decompose(input *big.Int, nbBits uint, res []*big.Int) error { // // then no such underflow happens and s = a-b (mod p) as the padding is multiple // of p. -func subPadding[T FieldParams](overflow uint, nbLimbs uint) []*big.Int { - var fp T - p := fp.Modulus() - bitsPerLimbs := fp.BitsPerLimb() +func subPadding(modulus *big.Int, bitsPerLimbs uint, overflow uint, nbLimbs uint) []*big.Int { // first, we build a number nLimbs, such that nLimbs > b; // here b is defined by its bounds, that is b is an element with nbLimbs of (bitsPerLimbs+overflow) @@ -86,8 +83,8 @@ func subPadding[T FieldParams](overflow uint, nbLimbs uint) []*big.Int { panic(fmt.Sprintf("recompose: %v", err)) } // mod reduce n, and negate it - n.Mod(n, p) - n.Sub(p, n) + n.Mod(n, modulus) + n.Sub(modulus, n) // construct pad such that: // pad := n - neg(n mod p) == kp diff --git a/std/math/emulated/composition_test.go b/std/math/emulated/composition_test.go index d1c59289cc..25ef0430f9 100644 --- a/std/math/emulated/composition_test.go +++ b/std/math/emulated/composition_test.go @@ -54,7 +54,7 @@ func testSubPadding[T FieldParams](t *testing.T) { assert := test.NewAssert(t) for i := fp.NbLimbs(); i < 2*fp.NbLimbs(); i++ { assert.Run(func(assert *test.Assert) { - limbs := subPadding[T](0, i) + limbs := subPadding(fp.Modulus(), fp.BitsPerLimb(), 0, i) padValue := new(big.Int) if err := recompose(limbs, fp.BitsPerLimb(), padValue); err != nil { assert.FailNow("recompose", err) diff --git a/std/math/emulated/custommod.go b/std/math/emulated/custommod.go new file mode 100644 index 0000000000..2f5cbaca1b --- /dev/null +++ b/std/math/emulated/custommod.go @@ -0,0 +1,99 @@ +package emulated + +import ( + "errors" + + "github.com/consensys/gnark/frontend" +) + +// ModMul computes a*b mod modulus. Instead of taking modulus as a constant +// parametrized by T, it is passed as an argument. This allows to use a variable +// modulus in the circuit. Type parameter T should be sufficiently big to fit a, +// b and modulus. Recommended to use [emparams.Mod1e512] or +// [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModMul(a, b *Element[T], modulus *Element[T]) *Element[T] { + res := f.mulMod(a, b, 0, modulus) + return res +} + +// ModAdd computes a+b mod modulus. Instead of taking modulus as a constant +// parametrized by T, it is passed as an argument. This allows to use a variable +// modulus in the circuit. Type parameter T should be sufficiently big to fit a, +// b and modulus. Recommended to use [emparams.Mod1e512] or +// [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModAdd(a, b *Element[T], modulus *Element[T]) *Element[T] { + // inlined version of [Field.reduceAndOp] which uses variable-modulus reduction + var nextOverflow uint + var err error + var target overflowError + for nextOverflow, err = f.addPreCond(a, b); errors.As(err, &target); nextOverflow, err = f.addPreCond(a, b) { + if errors.As(err, &target) { + if !target.reduceRight { + a = f.mulMod(a, f.shortOne(), 0, modulus) + } else { + b = f.mulMod(b, f.shortOne(), 0, modulus) + } + } + } + res := f.add(a, b, nextOverflow) + return res +} + +func (f *Field[T]) modSub(a, b *Element[T], modulus *Element[T]) *Element[T] { + // like fixed modulus subtraction, but for sub padding need to use hint + // instead of assuming T as a constant. And when doing as a hint, then need + // to assert that the padding is a multiple of the modulus (done inside callSubPaddingHint) + nextOverflow := max(b.overflow+1, a.overflow) + 1 + nbLimbs := max(len(a.Limbs), len(b.Limbs)) + limbs := make([]frontend.Variable, nbLimbs) + padding := f.computeSubPaddingHint(b.overflow, uint(nbLimbs), modulus) + for i := range limbs { + limbs[i] = padding.Limbs[i] + if i < len(a.Limbs) { + limbs[i] = f.api.Add(limbs[i], a.Limbs[i]) + } + if i < len(b.Limbs) { + limbs[i] = f.api.Sub(limbs[i], b.Limbs[i]) + } + } + res := f.newInternalElement(limbs, nextOverflow) + return res +} + +// ModAssertIsEqual asserts equality of a and b mod modulus. Instead of taking +// modulus as a constant parametrized by T, it is passed as an argument. This +// allows to use a variable modulus in the circuit. Type parameter T should be +// sufficiently big to fit a, b and modulus. Recommended to use +// [emparams.Mod1e512] or [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModAssertIsEqual(a, b *Element[T], modulus *Element[T]) { + // like fixed modulus AssertIsEqual, but uses current Sub implementation for + // computing the diff + diff := f.modSub(b, a, modulus) + f.checkZero(diff, modulus) +} + +// ModExp computes base^exp mod modulus. Instead of taking modulus as a constant +// parametrized by T, it is passed as an argument. This allows to use a variable +// modulus in the circuit. Type parameter T should be sufficiently big to fit +// base, exp and modulus. Recommended to use [emparams.Mod1e512] or +// [emparams.Mod1e4096]. +// +// NB! circuit complexity depends on T rather on the actual length of the modulus. +func (f *Field[T]) ModExp(base, exp, modulus *Element[T]) *Element[T] { + expBts := f.ToBits(exp) + n := len(expBts) + res := f.Select(expBts[0], base, f.One()) + base = f.ModMul(base, base, modulus) + for i := 1; i < n-1; i++ { + res = f.Select(expBts[i], f.ModMul(base, res, modulus), res) + base = f.ModMul(base, base, modulus) + } + res = f.Select(expBts[n-1], f.ModMul(base, res, modulus), res) + return res +} diff --git a/std/math/emulated/custommod_test.go b/std/math/emulated/custommod_test.go new file mode 100644 index 0000000000..a399769322 --- /dev/null +++ b/std/math/emulated/custommod_test.go @@ -0,0 +1,171 @@ +package emulated + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/test" +) + +type variableEquality[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] +} + +func (c *variableEquality[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + f.ModAssertIsEqual(&c.A, &c.B, &c.Modulus) + return nil +} + +func TestVariableEquality(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + a := big.NewInt(10) + b := new(big.Int).Add(a, modulus) + circuit := &variableEquality[emparams.Mod1e512]{} + assignment := &variableEquality[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](a), + B: ValueOf[emparams.Mod1e512](b), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableAddition[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] + Expected Element[T] +} + +func (c *variableAddition[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.ModAdd(&c.A, &c.B, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableAddition(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + circuit := &variableAddition[emparams.Mod1e512]{} + assignment := &variableAddition[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](10), + B: ValueOf[emparams.Mod1e512](20), + Expected: ValueOf[emparams.Mod1e512](30), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableSubtraction[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] + Expected Element[T] +} + +func (c *variableSubtraction[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.modSub(&c.A, &c.B, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableSubtraction(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + circuit := &variableSubtraction[emparams.Mod1e512]{} + res := new(big.Int).Sub(modulus, big.NewInt(10)) + assignment := &variableSubtraction[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](10), + B: ValueOf[emparams.Mod1e512](20), + Expected: ValueOf[emparams.Mod1e512](res), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableMultiplication[T FieldParams] struct { + Modulus Element[T] + A, B Element[T] + Expected Element[T] +} + +func (c *variableMultiplication[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.ModMul(&c.A, &c.B, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableMultiplication(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + a, _ := rand.Int(rand.Reader, modulus) + b, _ := rand.Int(rand.Reader, modulus) + exp := new(big.Int).Mul(a, b) + exp.Mod(exp, modulus) + circuit := &variableMultiplication[emparams.Mod1e512]{} + assignment := &variableMultiplication[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + A: ValueOf[emparams.Mod1e512](a), + B: ValueOf[emparams.Mod1e512](b), + Expected: ValueOf[emparams.Mod1e512](exp), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} + +type variableExp[T FieldParams] struct { + Modulus Element[T] + Base Element[T] + Exp Element[T] + Expected Element[T] +} + +func (c *variableExp[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.ModExp(&c.Base, &c.Exp, &c.Modulus) + f.ModAssertIsEqual(&c.Expected, res, &c.Modulus) + return nil +} + +func TestVariableExp(t *testing.T) { + assert := test.NewAssert(t) + modulus, _ := new(big.Int).SetString("4294967311", 10) + base, _ := rand.Int(rand.Reader, modulus) + exp, _ := rand.Int(rand.Reader, modulus) + expected := new(big.Int).Exp(base, exp, modulus) + circuit := &variableExp[emparams.Mod1e512]{} + assignment := &variableExp[emparams.Mod1e512]{ + Modulus: ValueOf[emparams.Mod1e512](modulus), + Base: ValueOf[emparams.Mod1e512](base), + Exp: ValueOf[emparams.Mod1e512](exp), + Expected: ValueOf[emparams.Mod1e512](expected), + } + err := test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} diff --git a/std/math/emulated/doc.go b/std/math/emulated/doc.go index 61a6e54288..6f4685f253 100644 --- a/std/math/emulated/doc.go +++ b/std/math/emulated/doc.go @@ -184,5 +184,17 @@ The package currently does not explicitly differentiate between constant and variable elements. The builder may track some elements as being constants. Some operations have a fast track path for cases when all inputs are constants. There is [Field.MulConst], which provides variable by constant multiplication. + +# Variable-modulus operations + +The package also exposes methods for performing operations with variable +modulus. The modulus is represented as an element and is not required to be +prime. The methods for variable-modulus operations are [Field.ModMul], +[Field.ModAdd], [Field.ModExp] and [Field.ModAssertIsEqual]. The modulus is +passed as an argument to the operation. + +The type parameter for the [Field] should be sufficiently big to allow to fit +the inputs and the modulus. Recommended to use predefined [emparams.Mod1e512] or +[emparams.Mod1e4096]. */ package emulated diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index 8954fc4d69..675f296596 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -13,6 +13,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/math/emulated/emparams" "github.com/consensys/gnark/test" ) @@ -1059,3 +1060,41 @@ func testSum[T FieldParams](t *testing.T) { assert.CheckCircuit(circuit, test.WithValidAssignment(witness)) }, testName[T]()) } + +type expCircuit[T FieldParams] struct { + Base Element[T] + Exp Element[T] + Expected Element[T] +} + +func (c *expCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + res := f.Exp(&c.Base, &c.Exp) + f.AssertIsEqual(&c.Expected, res) + return nil +} + +func testExp[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit expCircuit[T] + base, _ := rand.Int(rand.Reader, fp.Modulus()) + exp, _ := rand.Int(rand.Reader, fp.Modulus()) + expected := new(big.Int).Exp(base, exp, fp.Modulus()) + assignment := &expCircuit[T]{ + Base: ValueOf[T](base), + Exp: ValueOf[T](exp), + Expected: ValueOf[T](expected), + } + assert.CheckCircuit(&circuit, test.WithValidAssignment(assignment)) + }, testName[T]()) +} +func TestExp(t *testing.T) { + testExp[Goldilocks](t) + testExp[BN254Fr](t) + testExp[emparams.Mod1e512](t) +} diff --git a/std/math/emulated/emparams/emparams.go b/std/math/emulated/emparams/emparams.go index b07fb6e96b..bbee1afabf 100644 --- a/std/math/emulated/emparams/emparams.go +++ b/std/math/emulated/emparams/emparams.go @@ -281,3 +281,41 @@ func (fp BLS24315Fp) Modulus() *big.Int { return ecc.BLS24_315.BaseField() } type BLS24315Fr struct{ fourLimbPrimeField } func (fr BLS24315Fr) Modulus() *big.Int { return ecc.BLS24_315.ScalarField() } + +// Mod1e4096 provides type parametrization for emulated aritmetic: +// - limbs: 64 +// - limb width: 64 bits +// +// The modulus for type parametrisation is 2^4096-1. +// +// This is non-prime modulus. It is mainly targeted for using variable-modulus +// operations (ModAdd, ModMul, ModExp, ModAssertIsEqual) for variable modulus +// arithmetic. +type Mod1e4096 struct{} + +func (Mod1e4096) NbLimbs() uint { return 64 } +func (Mod1e4096) BitsPerLimb() uint { return 64 } +func (Mod1e4096) IsPrime() bool { return false } +func (Mod1e4096) Modulus() *big.Int { + val, _ := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + return val +} + +// Mod1e512 provides type parametrization for emulated aritmetic: +// - limbs: 8 +// - limb width: 64 bits +// +// The modulus for type parametrisation is 2^512-1. +// +// This is non-prime modulus. It is mainly targeted for using variable-modulus +// operations (ModAdd, ModMul, ModExp, ModAssertIsEqual) for variable modulus +// arithmetic. +type Mod1e512 struct{} + +func (Mod1e512) NbLimbs() uint { return 8 } +func (Mod1e512) BitsPerLimb() uint { return 64 } +func (Mod1e512) IsPrime() bool { return false } +func (Mod1e512) Modulus() *big.Int { + val, _ := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + return val +} diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index a2809e4eb9..5c2c700663 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -46,7 +46,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { } diff := f.Sub(b, a) - f.checkZero(diff) + f.checkZero(diff, nil) } // AssertIsLessOrEqual ensures that e is less or equal than a. For proper diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 9a2671d08a..278b9a5024 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -58,21 +58,25 @@ type mulCheck[T FieldParams] struct { r *Element[T] // reduced value k *Element[T] // coefficient c *Element[T] // carry + p *Element[T] // modulus if non-nil } // evalRound1 evaluates first c(X), r(X) and k(X) at a given random point at[0]. // In the first round we do not assume that any of them is already evaluated as // they come directly from hint. -func (mc *mulCheck[T]) evalRound1(api frontend.API, at []frontend.Variable) { +func (mc *mulCheck[T]) evalRound1(at []frontend.Variable) { mc.c = mc.f.evalWithChallenge(mc.c, at) mc.r = mc.f.evalWithChallenge(mc.r, at) mc.k = mc.f.evalWithChallenge(mc.k, at) + if mc.p != nil { + mc.p = mc.f.evalWithChallenge(mc.p, at) + } } // evalRound2 now evaluates a and b at a given random point at[0]. However, it // may happen that a or b is equal to r from a previous mulcheck. In that case // we can reuse the evaluation to save constraints. -func (mc *mulCheck[T]) evalRound2(api frontend.API, at []frontend.Variable) { +func (mc *mulCheck[T]) evalRound2(at []frontend.Variable) { mc.a = mc.f.evalWithChallenge(mc.a, at) mc.b = mc.f.evalWithChallenge(mc.b, at) } @@ -81,6 +85,9 @@ func (mc *mulCheck[T]) evalRound2(api frontend.API, at []frontend.Variable) { // computation of p(ch) and (2^t-ch) can be shared over all mulCheck instances, // then we get them already evaluated as peval and coef. func (mc *mulCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { + if mc.p != nil { + peval = mc.p.evaluation + } ls := api.Mul(mc.a.evaluation, mc.b.evaluation) rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) api.AssertIsEqual(ls, rs) @@ -99,14 +106,19 @@ func (mc *mulCheck[T]) cleanEvaluations() { mc.k.isEvaluated = false mc.c.evaluation = 0 mc.c.isEvaluated = false + if mc.p != nil { + mc.p.evaluation = 0 + mc.p.isEvaluated = false + } } // mulMod returns a*b mod r. In practice it computes the result using a hint and // defers the actual multiplication check. -func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { +func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { f.enforceWidthConditional(a) f.enforceWidthConditional(b) - k, r, c, err := f.callMulHint(a, b, true) + f.enforceWidthConditional(p) + k, r, c, err := f.callMulHint(a, b, true, p) if err != nil { panic(err) } @@ -117,18 +129,20 @@ func (f *Field[T]) mulMod(a, b *Element[T], _ uint) *Element[T] { c: c, k: k, r: r, + p: p, } f.mulChecks = append(f.mulChecks, mc) return r } // checkZero creates multiplication check a * 1 = 0 + k*p. -func (f *Field[T]) checkZero(a *Element[T]) { +func (f *Field[T]) checkZero(a *Element[T], p *Element[T]) { // the method works similarly to mulMod, but we know that we are multiplying // by one and expected result should be zero. f.enforceWidthConditional(a) + f.enforceWidthConditional(p) b := f.shortOne() - k, r, c, err := f.callMulHint(a, b, false) + k, r, c, err := f.callMulHint(a, b, false, p) if err != nil { panic(err) } @@ -139,6 +153,7 @@ func (f *Field[T]) checkZero(a *Element[T]) { c: c, k: k, r: r, // expected to be zero on zero limbs. + p: p, } f.mulChecks = append(f.mulChecks, mc) } @@ -191,6 +206,9 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { toCommit = append(toCommit, f.mulChecks[i].r.Limbs...) toCommit = append(toCommit, f.mulChecks[i].k.Limbs...) toCommit = append(toCommit, f.mulChecks[i].c.Limbs...) + if f.mulChecks[i].p != nil { + toCommit = append(toCommit, f.mulChecks[i].p.Limbs...) + } } // we give all the inputs as inputs to obtain random verifier challenge. multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { @@ -207,11 +225,11 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { } // evaluate all r, k, c for i := range f.mulChecks { - f.mulChecks[i].evalRound1(api, at) + f.mulChecks[i].evalRound1(at) } // assuming r is input to some other multiplication, then is already evaluated for i := range f.mulChecks { - f.mulChecks[i].evalRound2(api, at) + f.mulChecks[i].evalRound2(at) } // evaluate p(X) at challenge pval := f.evalWithChallenge(f.Modulus(), at) @@ -234,7 +252,7 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { } // callMulHint uses hint to compute r, k and c. -func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carries *Element[T], err error) { +func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool, customMod *Element[T]) (quo, rem, carries *Element[T], err error) { // compute the expected overflow after the multiplication of a*b to be able // to estimate the number of bits required to represent the result. nextOverflow, _ := f.mulPreCond(a, b) @@ -249,8 +267,15 @@ func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carri // we compute the width of the product of a*b, then we divide it by the // width of the modulus. We add 1 to the result to ensure that we have // enough space for the quotient. + modbits := uint(f.fParams.Modulus().BitLen()) + if customMod != nil { + // when we're using custom modulus, then we do not really know its + // length ahead of time. We assume worst case scenario and assume that + // the quotient can be the total length of the multiplication result. + modbits = 0 + } nbQuoLimbs := (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - // - uint(f.fParams.Modulus().BitLen()) + // + modbits + // nbBits - 1) / nbBits // the remainder is always less than modulus so can represent on the same @@ -267,7 +292,11 @@ func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool) (quo, rem, carri len(a.Limbs), nbQuoLimbs, } - hintInputs = append(hintInputs, f.Modulus().Limbs...) + modulusLimbs := f.Modulus().Limbs + if customMod != nil { + modulusLimbs = customMod.Limbs + } + hintInputs = append(hintInputs, modulusLimbs...) hintInputs = append(hintInputs, a.Limbs...) hintInputs = append(hintInputs, b.Limbs...) ret, err := f.api.NewHint(mulHint, int(nbQuoLimbs)+int(nbRemLimbs)+int(nbCarryLimbs), hintInputs...) @@ -328,7 +357,9 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { quo := new(big.Int) rem := new(big.Int) ab := new(big.Int).Mul(a, b) - quo.QuoRem(ab, p, rem) + if p.Cmp(new(big.Int)) != 0 { + quo.QuoRem(ab, p, rem) + } if err := decompose(quo, uint(nbBits), quoLimbs); err != nil { return fmt.Errorf("decompose quo: %w", err) } @@ -380,7 +411,7 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { // For multiplying by a constant, use [Field[T].MulConst] method which is more // efficient. func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] { - return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) + return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b) } // MulMod computes a*b and reduces it modulo the field order. The returned Element @@ -388,7 +419,7 @@ func (f *Field[T]) Mul(a, b *Element[T]) *Element[T] { // // Equivalent to [Field[T].Mul], kept for backwards compatibility. func (f *Field[T]) MulMod(a, b *Element[T]) *Element[T] { - return f.reduceAndOp(f.mulMod, f.mulPreCond, a, b) + return f.reduceAndOp(func(a, b *Element[T], u uint) *Element[T] { return f.mulMod(a, b, u, nil) }, f.mulPreCond, a, b) } // MulConst multiplies a by a constant c and returns it. We assume that the @@ -463,3 +494,18 @@ func (f *Field[T]) mulNoReduce(a, b *Element[T], nextoverflow uint) *Element[T] } return f.newInternalElement(resLimbs, nextoverflow) } + +// Exp computes base^exp modulo the field order. The returned Element has default +// number of limbs and zero overflow. +func (f *Field[T]) Exp(base, exp *Element[T]) *Element[T] { + expBts := f.ToBits(exp) + n := len(expBts) + res := f.Select(expBts[0], base, f.One()) + base = f.Mul(base, base) + for i := 1; i < n-1; i++ { + res = f.Select(expBts[i], f.Mul(base, res), res) + base = f.Mul(base, base) + } + res = f.Select(expBts[n-1], f.Mul(base, res), res) + return res +} diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index aeaf2c3059..a9f0d9cda3 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -176,7 +176,7 @@ func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { panic("trying to reduce a constant, which happen to have an overflow flag set") } // slow path - use hint to reduce value - return f.mulMod(a, f.One(), 0) + return f.mulMod(a, f.One(), 0, nil) } // Sub subtracts b from a and returns it. Reduces locally if wouldn't fit into @@ -204,9 +204,10 @@ func (f *Field[T]) sub(a, b *Element[T], nextOverflow uint) *Element[T] { // first we have to compute padding to ensure that the subtraction does not // underflow. + var fp T nbLimbs := max(len(a.Limbs), len(b.Limbs)) limbs := make([]frontend.Variable, nbLimbs) - padLimbs := subPadding[T](b.overflow, uint(nbLimbs)) + padLimbs := subPadding(fp.Modulus(), fp.BitsPerLimb(), b.overflow, uint(nbLimbs)) for i := range limbs { limbs[i] = padLimbs[i] if i < len(a.Limbs) { diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 6c1644c407..eab14b47e9 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -22,6 +22,7 @@ func GetHints() []solver.Hint { InverseHint, SqrtHint, mulHint, + subPaddingHint, } } @@ -153,3 +154,51 @@ func SqrtHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil }) } + +// subPaddingHint computes the padding for the subtraction of two numbers. It +// ensures that the padding is a multiple of the modulus. Can be used to avoid +// underflow. +// +// In case of fixed modulus use subPadding instead. +func subPaddingHint(mod *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) < 4 { + return fmt.Errorf("input must be at least four elements") + } + nbLimbs := int(inputs[0].Int64()) + bitsPerLimbs := uint(inputs[1].Uint64()) + overflow := uint(inputs[2].Uint64()) + retLimbs := int(inputs[3].Int64()) + if len(inputs[4:]) != nbLimbs { + return fmt.Errorf("input length mismatch") + } + if len(outputs) != retLimbs { + return fmt.Errorf("result does not fit into output") + } + pLimbs := inputs[4 : 4+nbLimbs] + p := new(big.Int) + if err := recompose(pLimbs, bitsPerLimbs, p); err != nil { + return fmt.Errorf("recompose modulus: %w", err) + } + padLimbs := subPadding(p, bitsPerLimbs, overflow, uint(nbLimbs)) + for i := range padLimbs { + outputs[i].Set(padLimbs[i]) + } + + return nil +} + +func (f *Field[T]) computeSubPaddingHint(overflow uint, nbLimbs uint, modulus *Element[T]) *Element[T] { + var fp T + inputs := []frontend.Variable{fp.NbLimbs(), fp.BitsPerLimb(), overflow, nbLimbs} + inputs = append(inputs, modulus.Limbs...) + res, err := f.api.NewHint(subPaddingHint, int(nbLimbs), inputs...) + if err != nil { + panic(fmt.Sprintf("sub padding hint: %v", err)) + } + for i := range res { + f.checker.Check(res[i], int(fp.BitsPerLimb()+overflow+1)) + } + padding := f.newInternalElement(res, fp.BitsPerLimb()+overflow+1) + f.checkZero(padding, modulus) + return padding +}