Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: bounded scalar multiplication #934

Merged
merged 13 commits into from
Nov 28, 2023
56 changes: 56 additions & 0 deletions std/algebra/algopts/algopts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Package algopts provides shareable options for modifying algebraic operations.
//
// This package is separate to avoid cyclic imports and sharing the structures
// between interface definition, implementation getters and actual
// implementations.
package algopts

import "fmt"

type algebraCfg struct {
NbScalarBits int
FoldMulti bool
}

// AlgebraOption allows modifying algebraic operation behaviour.
type AlgebraOption func(*algebraCfg) error

// WithNbScalarBits defines the number bits when doing scalar multiplication.
// May be used when it is known that only bits least significant bits are
// non-zero. Reduces the cost for scalar multiplication. If not set then full
// width of scalars used.
func WithNbScalarBits(bits int) AlgebraOption {
return func(ac *algebraCfg) error {
if ac.NbScalarBits != 0 {
return fmt.Errorf("WithNbBits already set")
}
ac.NbScalarBits = bits
return nil
}
}

// WithFoldingScalarMul can be used when calling MultiScalarMul. By using this
// option we assume that the scalars are `1, scalar, scalar^2, ...`. We use the
// first element as the scalar to be used as a folding coefficients. By using
// this option we avoid one scalar multiplication and do not need to compute the
// powers of the folding coefficient.
func WithFoldingScalarMul() AlgebraOption {
return func(ac *algebraCfg) error {
if ac.FoldMulti {
return fmt.Errorf("withFoldingScalarMul already set")
}
ac.FoldMulti = true
return nil
}
}

// NewConfig applies all given options and returns a configuration to be used.
func NewConfig(opts ...AlgebraOption) (*algebraCfg, error) {
ret := new(algebraCfg)
for i := range opts {
if err := opts[i](ret); err != nil {
return nil, err
}
}
return ret, nil
}
65 changes: 51 additions & 14 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"math/big"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/algopts"
"github.com/consensys/gnark/std/math/emulated"
"golang.org/x/exp/slices"
)
Expand Down Expand Up @@ -465,7 +466,11 @@ func (c *Curve[B, S]) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 *AffinePo
// [ELM03]: https://arxiv.org/pdf/math/0208038.pdf
// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf
// [Joye07]: https://www.iacr.org/archive/ches2007/47270135/47270135.pdf
func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *AffinePoint[B] {
func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}

// if p=(0,0) we assign a dummy (0,1) to p and continue
selector := c.api.And(c.baseApi.IsZero(&p.X), c.baseApi.IsZero(&p.Y))
Expand All @@ -476,6 +481,9 @@ func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *Affi
sr := c.scalarApi.Reduce(s)
sBits := c.scalarApi.ToBits(sr)
n := st.Modulus().BitLen()
if cfg.NbScalarBits > 2 && cfg.NbScalarBits < n {
n = cfg.NbScalarBits
}

// i = 1
Rb := c.triple(p)
Expand Down Expand Up @@ -518,19 +526,27 @@ func (c *Curve[B, S]) ScalarMul(p *AffinePoint[B], s *emulated.Element[S]) *Affi
//
// [HMV04]: https://link.springer.com/book/10.1007/b97644
// [EVM]: https://ethereum.github.io/yellowpaper/paper.pdf
func (c *Curve[B, S]) ScalarMulBase(s *emulated.Element[S]) *AffinePoint[B] {
g := c.Generator()
gm := c.GeneratorMultiples()
func (c *Curve[B, S]) ScalarMulBase(s *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}

var st S
sr := c.scalarApi.Reduce(s)
sBits := c.scalarApi.ToBits(sr)
n := st.Modulus().BitLen()
if cfg.NbScalarBits > 2 && cfg.NbScalarBits < n {
n = cfg.NbScalarBits
}
g := c.Generator()
gm := c.GeneratorMultiples()

// i = 1, 2
// gm[0] = 3g, gm[1] = 5g, gm[2] = 7g
res := c.Lookup2(sBits[1], sBits[2], g, &gm[0], &gm[1], &gm[2])

for i := 3; i < st.Modulus().BitLen(); i++ {
for i := 3; i < n; i++ {
// gm[i] = [2^i]g
tmp := c.add(res, &gm[i])
res = c.Select(sBits[i], tmp, res)
Expand Down Expand Up @@ -634,20 +650,41 @@ func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Ele
//
// For the points and scalars the same considerations apply as for
// [Curve.AddUnified] and [Curve.SalarMul].
func (c *Curve[B, S]) MultiScalarMul(p []*AffinePoint[B], s []*emulated.Element[S]) (*AffinePoint[B], error) {
if len(p) != len(s) {
return nil, fmt.Errorf("mismatching points and scalars slice lengths")
}
func (c *Curve[B, S]) MultiScalarMul(p []*AffinePoint[B], s []*emulated.Element[S], opts ...algopts.AlgebraOption) (*AffinePoint[B], error) {

if len(p) == 0 {
return &AffinePoint[B]{
X: *c.baseApi.Zero(),
Y: *c.baseApi.Zero(),
}, nil
}
res := c.ScalarMul(p[0], s[0])
for i := 1; i < len(p); i++ {
q := c.ScalarMul(p[i], s[i])
res = c.AddUnified(res, q)
cfg, err := algopts.NewConfig(opts...)
if err != nil {
return nil, fmt.Errorf("new config: %w", err)
}
if !cfg.FoldMulti {
// the scalars are unique
if len(p) != len(s) {
return nil, fmt.Errorf("mismatching points and scalars slice lengths")
}
res := c.ScalarMul(p[0], s[0])
for i := 1; i < len(p); i++ {
q := c.ScalarMul(p[i], s[i], opts...)
res = c.AddUnified(res, q)
}
return res, nil
} else {
// scalars are powers
if len(s) == 0 {
return nil, fmt.Errorf("need scalar for folding")
}
gamma := s[0]
res := c.ScalarMul(p[len(p)-1], gamma, opts...)
for i := len(p) - 2; i > 0; i-- {
res = c.Add(p[i], res)
res = c.ScalarMul(res, gamma, opts...)
}
res = c.Add(p[0], res)
return res, nil
}
return res, nil
}
50 changes: 50 additions & 0 deletions std/algebra/emulated/sw_emulated/point_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
fp_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fp"
fr_secp "github.com/consensys/gnark-crypto/ecc/secp256k1/fr"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/algopts"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
"github.com/consensys/gnark/test"
Expand Down Expand Up @@ -987,3 +988,52 @@ func TestMultiScalarMul(t *testing.T) {
}, &assignment, ecc.BN254.ScalarField())
assert.NoError(err)
}

type ScalarMulTestBounded[T, S emulated.FieldParams] struct {
P, Q AffinePoint[T]
S emulated.Element[S]
bits int
}

func (c *ScalarMulTestBounded[T, S]) Define(api frontend.API) error {
cr, err := New[T, S](api, GetCurveParams[T]())
if err != nil {
return err
}
res := cr.ScalarMul(&c.P, &c.S, algopts.WithNbScalarBits(c.bits))
cr.AssertIsEqual(res, &c.Q)
return nil
}

func TestScalarMulBounded(t *testing.T) {
assert := test.NewAssert(t)
_, g := secp256k1.Generators()
var r fr_secp.Element
_, _ = r.SetRandom()
s := new(big.Int)
r.BigInt(s)
nbBits := 13
mask := big.NewInt(1)
mask.Lsh(mask, uint(nbBits))
mask.Sub(mask, big.NewInt(1))
s.And(s, mask)
var S secp256k1.G1Affine
S.ScalarMultiplication(&g, s)

circuit := ScalarMulTestBounded[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{
bits: nbBits,
}
witness := ScalarMulTestBounded[emulated.Secp256k1Fp, emulated.Secp256k1Fr]{
S: emulated.ValueOf[emulated.Secp256k1Fr](s),
P: AffinePoint[emulated.Secp256k1Fp]{
X: emulated.ValueOf[emulated.Secp256k1Fp](g.X),
Y: emulated.ValueOf[emulated.Secp256k1Fp](g.Y),
},
Q: AffinePoint[emulated.Secp256k1Fp]{
X: emulated.ValueOf[emulated.Secp256k1Fp](S.X),
Y: emulated.ValueOf[emulated.Secp256k1Fp](S.Y),
},
}
err := test.IsSolved(&circuit, &witness, testCurve.ScalarField())
assert.NoError(err)
}
7 changes: 4 additions & 3 deletions std/algebra/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package algebra

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/algopts"
"github.com/consensys/gnark/std/math/emulated"
)

Expand All @@ -25,16 +26,16 @@ type Curve[FR emulated.FieldParams, G1El G1ElementT] interface {

// ScalarMul returns the scalar multiplication of the point by a scalar. It
// does not modify the inputs.
ScalarMul(*G1El, *emulated.Element[FR]) *G1El
ScalarMul(*G1El, *emulated.Element[FR], ...algopts.AlgebraOption) *G1El

// ScalarMulBase returns the scalar multiplication of the curve base point
// by a scalar. It does not modify the scalar.
ScalarMulBase(*emulated.Element[FR]) *G1El
ScalarMulBase(*emulated.Element[FR], ...algopts.AlgebraOption) *G1El

// MultiScalarMul computes the sum ∑ s_i P_i for the input
// scalars s_i and points P_i. It returns an error if the input lengths
// mismatch.
MultiScalarMul([]*G1El, []*emulated.Element[FR]) (*G1El, error)
MultiScalarMul([]*G1El, []*emulated.Element[FR], ...algopts.AlgebraOption) (*G1El, error)

// MarshalG1 returns the binary decomposition G1.X || G1.Y. It matches the
// output of gnark-crypto's Marshal method on G1 points.
Expand Down
53 changes: 38 additions & 15 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/algopts"
"github.com/consensys/gnark/std/algebra/native/fields_bls12377"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
Expand Down Expand Up @@ -90,7 +91,7 @@ func (c *Curve) Neg(P *G1Affine) *G1Affine {

// ScalarMul computes scalar*P and returns the result. It doesn't modify the
// inputs.
func (c *Curve) ScalarMul(P *G1Affine, s *Scalar) *G1Affine {
func (c *Curve) ScalarMul(P *G1Affine, s *Scalar, opts ...algopts.AlgebraOption) *G1Affine {
res := &G1Affine{
X: P.X,
Y: P.Y,
Expand All @@ -102,7 +103,7 @@ func (c *Curve) ScalarMul(P *G1Affine, s *Scalar) *G1Affine {

// ScalarMulBase computes scalar*G where G is the standard base point of the
// curve. It doesn't modify the scalar.
func (c *Curve) ScalarMulBase(s *Scalar) *G1Affine {
func (c *Curve) ScalarMulBase(s *Scalar, opts ...algopts.AlgebraOption) *G1Affine {
res := new(G1Affine)
varScalar := c.packScalarToVar(s)
res.ScalarMulBase(c.api, varScalar)
Expand All @@ -112,27 +113,49 @@ func (c *Curve) ScalarMulBase(s *Scalar) *G1Affine {
// MultiScalarMul computes ∑scalars_i * P_i and returns it. It doesn't modify
// the inputs. It returns an error if there is a mismatch in the lengths of the
// inputs.
func (c *Curve) MultiScalarMul(P []*G1Affine, scalars []*Scalar) (*G1Affine, error) {
if len(P) != len(scalars) {
return nil, fmt.Errorf("mismatching points and scalars slice lengths")
}
func (c *Curve) MultiScalarMul(P []*G1Affine, scalars []*Scalar, opts ...algopts.AlgebraOption) (*G1Affine, error) {
if len(P) == 0 {
return &G1Affine{
X: 0,
Y: 0,
}, nil
}
res := c.ScalarMul(P[0], scalars[0])
for i := 1; i < len(P); i++ {
q := c.ScalarMul(P[i], scalars[i])
cfg, err := algopts.NewConfig(opts...)
if err != nil {
return nil, fmt.Errorf("new config: %w", err)
}
if !cfg.FoldMulti {
if len(P) != len(scalars) {
return nil, fmt.Errorf("mismatching points and scalars slice lengths")
}
res := c.ScalarMul(P[0], scalars[0])
for i := 1; i < len(P); i++ {
q := c.ScalarMul(P[i], scalars[i], opts...)

// check for infinity
isInfinity := c.api.And(c.api.IsZero(P[i].X), c.api.IsZero(P[i].Y))
tmp := c.Add(res, q)
res.X = c.api.Select(isInfinity, res.X, tmp.X)
res.Y = c.api.Select(isInfinity, res.Y, tmp.Y)
// check for infinity...
isInfinity := c.api.And(c.api.IsZero(P[i].X), c.api.IsZero(P[i].Y))
tmp := c.Add(res, q)
res.X = c.api.Select(isInfinity, res.X, tmp.X)
res.Y = c.api.Select(isInfinity, res.Y, tmp.Y)
}
return res, nil
} else {
// scalars are powers
if len(scalars) == 0 {
return nil, fmt.Errorf("need scalar for folding")
}
gamma := scalars[0]
res := c.ScalarMul(P[len(P)-1], gamma, opts...)
for i := len(P) - 2; i > 0; i-- {
isInfinity := c.api.And(c.api.IsZero(P[i].X), c.api.IsZero(P[i].Y))
tmp := c.Add(P[i], res)
res.X = c.api.Select(isInfinity, res.X, tmp.X)
res.Y = c.api.Select(isInfinity, res.Y, tmp.Y)
res = c.ScalarMul(res, gamma, opts...)
}
res = c.Add(P[0], res)
return res, nil
}
return res, nil
}

// Pairing allows computing pairing-related operations in BLS12-377.
Expand Down
Loading