Skip to content

Commit

Permalink
feat: expmod with variable modulus (#1090)
Browse files Browse the repository at this point in the history
* chore: remove unused arguments

* feat: allow non-default modulus

* feat: implement non-default modulus arithmetic

* test: tests for non-default arithmetic

* remove debug calls

* chore: clean only if set

* feat: add zero check with custom modulus

* feat: change quotient length in case of var modulus

* refactor: subtraction padding into callable function

* feat: implement variable modulus subtraction

* feat: implement variable modulus equality assertion

* test: variable modulus tests

* feat: implement custom mod exp

* refactor: rename methods

* refactor: use single impl of mulmod and checkzero

* feat: add automatic lazy reduction for var-mod addition

* refactor: make var-mod sub private

It is difficult to implement automatic reduction here due to the
recursive dependency of methods. As the goal is to provide only very
limited API for now, then make it private for now.

* refactor: move parameters

* refactor: move var-mod methods to Field

* refactor: subPaddingHint to hints file

* docs: add package documentation

* feat: implement fixed-mod exp

* test: make tests cheaper to run

* fix: handle mul hint edge case when modulus is zero

* feat: implement expmod precompile

* perf: select first bit in expmod instead of mul

* docs: indicate implementation of Expmod

---------

Co-authored-by: Youssef El Housni <[email protected]>
  • Loading branch information
ivokub and yelhousni authored Apr 18, 2024
1 parent c38cdd3 commit 3c506fd
Show file tree
Hide file tree
Showing 14 changed files with 593 additions and 25 deletions.
30 changes: 30 additions & 0 deletions std/evmprecompiles/05-expmod.go
Original file line number Diff line number Diff line change
@@ -1 +1,31 @@
package evmprecompiles

import (
"fmt"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
)

// Expmod implements [MODEXP] precompile contract at address 0x05.
//
// Internally, uses 4k elements for representing the base, exponent and modulus,
// upper bounding the sizes of the inputs. The runtime is constant regardless of
// the actual length of the inputs.
//
// [MODEXP]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/expmod/index.html
func Expmod(api frontend.API, base, exp, modulus *emulated.Element[emparams.Mod1e4096]) *emulated.Element[emparams.Mod1e4096] {
// x^0 = 1
// x mod 0 = 0
f, err := emulated.NewField[emparams.Mod1e4096](api)
if err != nil {
panic(fmt.Sprintf("new field: %v", err))
}
// in case modulus is zero, then need to compute with dummy values and return zero as a result
isZeroMod := f.IsZero(modulus)
modulus = f.Select(isZeroMod, f.One(), modulus)
res := f.ModExp(base, exp, modulus)
res = f.Select(isZeroMod, f.Zero(), res)
return res
}
86 changes: 86 additions & 0 deletions std/evmprecompiles/05-expmod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package evmprecompiles

import (
"crypto/rand"
"fmt"
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
"github.com/consensys/gnark/test"
)

type expmodCircuit struct {
Base emulated.Element[emparams.Mod1e4096]
Exp emulated.Element[emparams.Mod1e4096]
Mod emulated.Element[emparams.Mod1e4096]
Result emulated.Element[emparams.Mod1e4096]
edgeCases bool
}

func (c *expmodCircuit) Define(api frontend.API) error {
res := Expmod(api, &c.Base, &c.Exp, &c.Mod)
f, err := emulated.NewField[emparams.Mod1e4096](api)
if err != nil {
return fmt.Errorf("new field: %w", err)
}
if c.edgeCases {
// cannot use ModAssertIsEqual for edge cases. But the output is either
// 0 or 1 so can use AssertIsEqual
f.AssertIsEqual(res, &c.Result)
} else {
// for random case need to use ModAssertIsEqual
f.ModAssertIsEqual(&c.Result, res, &c.Mod)
}
return nil
}

func testInstance(edgeCases bool, base, exp, modulus, result *big.Int) error {
circuit := &expmodCircuit{edgeCases: edgeCases}
assignment := &expmodCircuit{
Base: emulated.ValueOf[emparams.Mod1e4096](base),
Exp: emulated.ValueOf[emparams.Mod1e4096](exp),
Mod: emulated.ValueOf[emparams.Mod1e4096](modulus),
Result: emulated.ValueOf[emparams.Mod1e4096](result),
}
return test.IsSolved(circuit, assignment, ecc.BLS12_377.ScalarField())
}

func TestRandomInstance(t *testing.T) {
assert := test.NewAssert(t)
for _, bits := range []int{256, 512, 1024, 2048, 4096} {
assert.Run(func(assert *test.Assert) {
modulus := new(big.Int).Lsh(big.NewInt(1), uint(bits))
base, _ := rand.Int(rand.Reader, modulus)
exp, _ := rand.Int(rand.Reader, modulus)
res := new(big.Int).Exp(base, exp, modulus)
err := testInstance(false, base, exp, modulus, res)
assert.NoError(err)
}, fmt.Sprintf("random-%d", bits))
}
}

func TestEdgeCases(t *testing.T) {
assert := test.NewAssert(t)
testCases := []struct {
base, exp, modulus, result *big.Int
}{
{big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)}, // 0^0 = 0 mod 0
{big.NewInt(0), big.NewInt(0), big.NewInt(1), big.NewInt(1)}, // 0^0 = 1 mod 1
{big.NewInt(0), big.NewInt(0), big.NewInt(123), big.NewInt(1)}, // 0^0 = 1 mod 123
{big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 0
{big.NewInt(123), big.NewInt(123), big.NewInt(0), big.NewInt(0)}, // 123^123 = 0 mod 1
{big.NewInt(0), big.NewInt(123), big.NewInt(123), big.NewInt(0)}, // 0^123 = 0 mod 123
{big.NewInt(123), big.NewInt(0), big.NewInt(123), big.NewInt(1)}, // 123^0 = 1 mod 123

}
for i, tc := range testCases {
assert.Run(func(assert *test.Assert) {
err := testInstance(true, tc.base, tc.exp, tc.modulus, tc.result)
assert.NoError(err)
}, fmt.Sprintf("edge-%d", i))
}
}
2 changes: 1 addition & 1 deletion std/evmprecompiles/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// 2. SHA256 ❌ -- in progress
// 3. RIPEMD160 ❌ -- postponed
// 4. ID ❌ -- trivial to implement without function
// 5. EXPMOD -- in progress
// 5. EXPMOD -- function [Expmod]
// 6. BN_ADD ✅ -- function [ECAdd]
// 7. BN_MUL ✅ -- function [ECMul]
// 8. SNARKV ✅ -- function [ECPair]
Expand Down
9 changes: 3 additions & 6 deletions std/math/emulated/composition.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ func decompose(input *big.Int, nbBits uint, res []*big.Int) error {
//
// then no such underflow happens and s = a-b (mod p) as the padding is multiple
// of p.
func subPadding[T FieldParams](overflow uint, nbLimbs uint) []*big.Int {
var fp T
p := fp.Modulus()
bitsPerLimbs := fp.BitsPerLimb()
func subPadding(modulus *big.Int, bitsPerLimbs uint, overflow uint, nbLimbs uint) []*big.Int {

// first, we build a number nLimbs, such that nLimbs > b;
// here b is defined by its bounds, that is b is an element with nbLimbs of (bitsPerLimbs+overflow)
Expand All @@ -86,8 +83,8 @@ func subPadding[T FieldParams](overflow uint, nbLimbs uint) []*big.Int {
panic(fmt.Sprintf("recompose: %v", err))
}
// mod reduce n, and negate it
n.Mod(n, p)
n.Sub(p, n)
n.Mod(n, modulus)
n.Sub(modulus, n)

// construct pad such that:
// pad := n - neg(n mod p) == kp
Expand Down
2 changes: 1 addition & 1 deletion std/math/emulated/composition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func testSubPadding[T FieldParams](t *testing.T) {
assert := test.NewAssert(t)
for i := fp.NbLimbs(); i < 2*fp.NbLimbs(); i++ {
assert.Run(func(assert *test.Assert) {
limbs := subPadding[T](0, i)
limbs := subPadding(fp.Modulus(), fp.BitsPerLimb(), 0, i)
padValue := new(big.Int)
if err := recompose(limbs, fp.BitsPerLimb(), padValue); err != nil {
assert.FailNow("recompose", err)
Expand Down
99 changes: 99 additions & 0 deletions std/math/emulated/custommod.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package emulated

import (
"errors"

"github.com/consensys/gnark/frontend"
)

// ModMul computes a*b mod modulus. Instead of taking modulus as a constant
// parametrized by T, it is passed as an argument. This allows to use a variable
// modulus in the circuit. Type parameter T should be sufficiently big to fit a,
// b and modulus. Recommended to use [emparams.Mod1e512] or
// [emparams.Mod1e4096].
//
// NB! circuit complexity depends on T rather on the actual length of the modulus.
func (f *Field[T]) ModMul(a, b *Element[T], modulus *Element[T]) *Element[T] {
res := f.mulMod(a, b, 0, modulus)
return res
}

// ModAdd computes a+b mod modulus. Instead of taking modulus as a constant
// parametrized by T, it is passed as an argument. This allows to use a variable
// modulus in the circuit. Type parameter T should be sufficiently big to fit a,
// b and modulus. Recommended to use [emparams.Mod1e512] or
// [emparams.Mod1e4096].
//
// NB! circuit complexity depends on T rather on the actual length of the modulus.
func (f *Field[T]) ModAdd(a, b *Element[T], modulus *Element[T]) *Element[T] {
// inlined version of [Field.reduceAndOp] which uses variable-modulus reduction
var nextOverflow uint
var err error
var target overflowError
for nextOverflow, err = f.addPreCond(a, b); errors.As(err, &target); nextOverflow, err = f.addPreCond(a, b) {
if errors.As(err, &target) {
if !target.reduceRight {
a = f.mulMod(a, f.shortOne(), 0, modulus)
} else {
b = f.mulMod(b, f.shortOne(), 0, modulus)
}
}
}
res := f.add(a, b, nextOverflow)
return res
}

func (f *Field[T]) modSub(a, b *Element[T], modulus *Element[T]) *Element[T] {
// like fixed modulus subtraction, but for sub padding need to use hint
// instead of assuming T as a constant. And when doing as a hint, then need
// to assert that the padding is a multiple of the modulus (done inside callSubPaddingHint)
nextOverflow := max(b.overflow+1, a.overflow) + 1
nbLimbs := max(len(a.Limbs), len(b.Limbs))
limbs := make([]frontend.Variable, nbLimbs)
padding := f.computeSubPaddingHint(b.overflow, uint(nbLimbs), modulus)
for i := range limbs {
limbs[i] = padding.Limbs[i]
if i < len(a.Limbs) {
limbs[i] = f.api.Add(limbs[i], a.Limbs[i])
}
if i < len(b.Limbs) {
limbs[i] = f.api.Sub(limbs[i], b.Limbs[i])
}
}
res := f.newInternalElement(limbs, nextOverflow)
return res
}

// ModAssertIsEqual asserts equality of a and b mod modulus. Instead of taking
// modulus as a constant parametrized by T, it is passed as an argument. This
// allows to use a variable modulus in the circuit. Type parameter T should be
// sufficiently big to fit a, b and modulus. Recommended to use
// [emparams.Mod1e512] or [emparams.Mod1e4096].
//
// NB! circuit complexity depends on T rather on the actual length of the modulus.
func (f *Field[T]) ModAssertIsEqual(a, b *Element[T], modulus *Element[T]) {
// like fixed modulus AssertIsEqual, but uses current Sub implementation for
// computing the diff
diff := f.modSub(b, a, modulus)
f.checkZero(diff, modulus)
}

// ModExp computes base^exp mod modulus. Instead of taking modulus as a constant
// parametrized by T, it is passed as an argument. This allows to use a variable
// modulus in the circuit. Type parameter T should be sufficiently big to fit
// base, exp and modulus. Recommended to use [emparams.Mod1e512] or
// [emparams.Mod1e4096].
//
// NB! circuit complexity depends on T rather on the actual length of the modulus.
func (f *Field[T]) ModExp(base, exp, modulus *Element[T]) *Element[T] {
expBts := f.ToBits(exp)
n := len(expBts)
res := f.Select(expBts[0], base, f.One())
base = f.ModMul(base, base, modulus)
for i := 1; i < n-1; i++ {
res = f.Select(expBts[i], f.ModMul(base, res, modulus), res)
base = f.ModMul(base, base, modulus)
}
res = f.Select(expBts[n-1], f.ModMul(base, res, modulus), res)
return res
}
Loading

0 comments on commit 3c506fd

Please sign in to comment.