Skip to content

Commit

Permalink
feat: Fiat-Shamir transcript using a short hash (#900)
Browse files Browse the repository at this point in the history
* chore: take challenge slice instead of vararg

* feat: implement some options for FS

* feat: modify FS behaviour accoring to options

* docs: prettify method documentation

* test: random test for FS

* chore: use domain separation in sumcheck FS

* feat: add shorthand method to init correct transcript

* test: implement FS tests with marshal

* test: fix test to use domain separation

* test: fix test to use domain separation

* fix: correct marshal for zero point

* fix: zero point marshal for emulated curves
  • Loading branch information
ivokub authored Nov 8, 2023
1 parent 6477e51 commit 3f98e9b
Show file tree
Hide file tree
Showing 13 changed files with 297 additions and 71 deletions.
3 changes: 3 additions & 0 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ func (c *Curve[B, S]) MarshalG1(p AffinePoint[B]) []frontend.Variable {
res := make([]frontend.Variable, 2*nbBits)
copy(res, bx)
copy(res[len(bx):], by)
xZ := c.baseApi.IsZero(x)
yZ := c.baseApi.IsZero(y)
res[1] = c.api.Mul(xZ, yZ)
return res
}

Expand Down
53 changes: 33 additions & 20 deletions std/algebra/emulated/sw_emulated/point_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,40 @@ func (c *MarshalG1Test[T, S]) Define(api frontend.API) error {

func TestMarshalG1(t *testing.T) {
assert := test.NewAssert(t)
_, _, g, _ := bw6761.Generators()
gBytes := g.Marshal()
nbBytes := 2 * fp_bw6761.Bytes
nbBits := nbBytes * 8
circuit := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
R: make([]frontend.Variable, nbBits),
}
witness := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
G: AffinePoint[emulated.BW6761Fp]{
X: emulated.ValueOf[emulated.BW6761Fp](g.X),
Y: emulated.ValueOf[emulated.BW6761Fp](g.Y),
},
R: make([]frontend.Variable, nbBits),
}
for i := 0; i < nbBytes; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
testFn := func(r fr_bw6761.Element) {
var P bw6761.G1Affine
P.ScalarMultiplicationBase(r.BigInt(new(big.Int)))
gBytes := P.Marshal()
nbBytes := 2 * fp_bw6761.Bytes
nbBits := nbBytes * 8
circuit := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
R: make([]frontend.Variable, nbBits),
}
}
err := test.IsSolved(circuit, witness, testCurve.ScalarField())
assert.NoError(err)
witness := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
G: AffinePoint[emulated.BW6761Fp]{
X: emulated.ValueOf[emulated.BW6761Fp](P.X),
Y: emulated.ValueOf[emulated.BW6761Fp](P.Y),
},
R: make([]frontend.Variable, nbBits),
}
for i := 0; i < nbBytes; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
}
}
err := test.IsSolved(circuit, witness, testCurve.ScalarField())
assert.NoError(err)
}
assert.Run(func(assert *test.Assert) {
var r fr_bw6761.Element
r.SetRandom()
testFn(r)
})
assert.Run(func(assert *test.Assert) {
var r fr_bw6761.Element
r.SetZero()
testFn(r)
})
}

type NegTest[T, S emulated.FieldParams] struct {
Expand Down
40 changes: 25 additions & 15 deletions std/algebra/native/sw_bls12377/g1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,33 @@ func (c *MarshalG1Test) Define(api frontend.API) error {
func TestMarshalG1(t *testing.T) {
assert := test.NewAssert(t)

// sample a random point
var r fr.Element
r.SetRandom()
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls12377.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 96; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
testfn := func(r fr.Element) {
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls12377.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 96; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
}
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761))
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761))
assert.Run(func(assert *test.Assert) {
// sample a random point
var r fr.Element
r.SetRandom()
testfn(r)
})
assert.Run(func(assert *test.Assert) {
var r fr.Element
r.SetZero()
testfn(r)
})
}

// -------------------------------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
}
xZ := c.api.IsZero(P.X)
yZ := c.api.IsZero(P.Y)
res[1] = c.api.Mul(xZ, yZ)
return res
}

Expand Down
40 changes: 25 additions & 15 deletions std/algebra/native/sw_bls24315/g1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,33 @@ func (c *MarshalG1Test) Define(api frontend.API) error {
func TestMarshalG1(t *testing.T) {
assert := test.NewAssert(t)

// sample a random point
var r fr.Element
r.SetRandom()
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls24315.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 80; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
testfn := func(r fr.Element) {
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls24315.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 80; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
}
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_633))
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_633))
assert.Run(func(assert *test.Assert) {
// sample a random point
var r fr.Element
r.SetRandom()
testfn(r)
})
assert.Run(func(assert *test.Assert) {
var r fr.Element
r.SetZero()
testfn(r)
})
}

// -------------------------------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
}
xZ := c.api.IsZero(P.X)
yZ := c.api.IsZero(P.Y)
res[1] = c.api.Mul(xZ, yZ)
return res
}

Expand Down
2 changes: 1 addition & 1 deletion std/commitments/fri/fri.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (s RadixTwoFri) verifyProofOfProximitySingleRound(api frontend.API, salt fr
xis[i] = paddNaming(fmt.Sprintf("x%d", i), frSize)
}
xis[s.nbSteps] = paddNaming("s0", frSize)
fs := fiatshamir.NewTranscript(api, s.h, xis...)
fs := fiatshamir.NewTranscript(api, s.h, xis, fiatshamir.WithDomainSeparation())
xi := make([]frontend.Variable, s.nbSteps)

// the salt is binded to the first challenge, to ensure the challenges
Expand Down
72 changes: 62 additions & 10 deletions std/fiat-shamir/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ package fiatshamir

import (
"errors"
"slices"

"github.com/consensys/gnark/constant"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/bits"
)

// errChallengeNotFound is returned when a wrong challenge name is provided.
Expand All @@ -41,6 +43,8 @@ type Transcript struct {

// gnark API
api frontend.API

config transcriptConfig
}

type challenge struct {
Expand All @@ -53,19 +57,24 @@ type challenge struct {
// NewTranscript returns a new transcript.
// h is the hash function that is used to compute the challenges.
// challenges are the name of the challenges. The order is important.
func NewTranscript(api frontend.API, h hash.FieldHasher, challengesID ...string) Transcript {
func NewTranscript(api frontend.API, h hash.FieldHasher, challengesID []string, opts ...TranscriptOption) *Transcript {
cfg := transcriptConfig{}
for _, opt := range opts {
opt(&cfg)
}
n := len(challengesID)
t := Transcript{
challenges: make(map[string]challenge, n),
api: api,
h: h,
config: cfg,
}

for i := 0; i < n; i++ {
t.challenges[challengesID[i]] = challenge{position: i}
}

return t
return &t
}

// Bind binds the challenge to value. A challenge can be binded to an
Expand All @@ -92,10 +101,10 @@ func (t *Transcript) Bind(challengeID string, values []frontend.Variable) error

// ComputeChallenge computes the challenge corresponding to the given name.
// The resulting variable is:
// * H(name ∥ previous_challenge ∥ binded_values...) if the challenge is not the first one
// * H(name ∥ binded_values... ) if it's is the first challenge
// - H(name ∥ previous_challenge ∥ binded_values...) if the challenge is not the first one
// - H(name ∥ binded_values... ) if it's is the first challenge
func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, error) {

var err error
challenge, ok := t.challenges[challengeID]

if !ok {
Expand All @@ -110,19 +119,34 @@ func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, er
t.h.Reset()

// write the challenge name, the purpose is to have a domain separator
cChallenge := []byte(challengeID) // if we send a string, it is assumed to be a base10 number
if challengeName, err := constant.HashedBytes(t.api, cChallenge); err == nil {
t.h.Write(challengeName)
challengeInput := []byte(challengeID)
var challengeHashInput frontend.Variable = challengeInput
if t.config.withDomainSeparation {
challengeHashInput, err = constant.HashedBytes(t.api, []byte(challengeID))
if err != nil {
return nil, err
}
}
if t.config.tryBitmode > 0 {
challengeBits := bits.ToBinary(t.api, challengeInput, bits.WithNbDigits(8*len(challengeInput)))
slices.Reverse(challengeBits)
t.h.Write(challengeBits...)
} else {
return nil, err
t.h.Write(challengeHashInput)
}

// write the previous challenge if it's not the first challenge
if challenge.position != 0 {
if t.previous == nil || (t.previous.position != challenge.position-1) {
return nil, errPreviousChallengeNotComputed
}
t.h.Write(t.previous.value)
if t.config.tryBitmode > 0 {
prevBits := bits.ToBinary(t.api, t.previous.value, bits.WithNbDigits(t.config.tryBitmode))
slices.Reverse(prevBits)
t.h.Write(prevBits...)
} else {
t.h.Write(t.previous.value)
}
}

// write the binded values in the order they were added
Expand All @@ -140,3 +164,31 @@ func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, er
return challenge.value, nil

}

type transcriptConfig struct {
tryBitmode int
withDomainSeparation bool
}

// TranscriptOption allows modifying the [Transcript] operation.
type TranscriptOption func(tc *transcriptConfig)

// WithTryBitmode changes the [Transcript] to work on bits instead of field
// elements when writing input to the hasher. Requires that the hasher is also
// set to work in bitmode. This mode of operation is useful in cases where we
// work in mismatching fields and want to avoid overflows.
func WithTryBitmode(nbBits int) TranscriptOption {
return func(tc *transcriptConfig) {
tc.tryBitmode = nbBits
}
}

// WithDomainSeparation adds domain separation string `string:` as defined in
// RCF 9380. This mode of operation is beneficial when seeking for compatibility
// with native Transcript when initialized using gnark-crypto's MiMC
// implementation.
func WithDomainSeparation() TranscriptOption {
return func(tc *transcriptConfig) {
tc.withDomainSeparation = true
}
}
9 changes: 6 additions & 3 deletions std/fiat-shamir/transcript_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package fiatshamir

import (
"crypto/rand"
"math/big"
"testing"

Expand Down Expand Up @@ -45,7 +46,7 @@ func (circuit *FiatShamirCircuit) Define(api frontend.API) error {
}

// New transcript with 3 challenges to be derived
tsSnark := NewTranscript(api, &hSnark, "alpha", "beta", "gamma")
tsSnark := NewTranscript(api, &hSnark, []string{"alpha", "beta", "gamma"}, WithDomainSeparation())

// Bind challenges
if err := tsSnark.Bind("alpha", circuit.Bindings[0][:]); err != nil {
Expand Down Expand Up @@ -83,6 +84,7 @@ func (circuit *FiatShamirCircuit) Define(api frontend.API) error {
}

func TestFiatShamir(t *testing.T) {
var err error
assert := test.NewAssert(t)

testData := map[ecc.ID]hash.Hash{
Expand All @@ -101,10 +103,11 @@ func TestFiatShamir(t *testing.T) {
// instantiate the hash and the transcript in plain go
ts := fiatshamir.NewTranscript(h.New(), "alpha", "beta", "gamma")

var bindings [3][4]big.Int
var bindings [3][4]*big.Int
for i := 0; i < 3; i++ {
for j := 0; j < 4; j++ {
bindings[i][j].SetUint64(uint64(i * j))
bindings[i][j], err = rand.Int(rand.Reader, curveID.ScalarField())
assert.NoError(err)
}
}
frSize := utils.ByteLen(curveID.ScalarField())
Expand Down
7 changes: 3 additions & 4 deletions std/gkr/gkr.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package gkr

import (
"fmt"
"strconv"

"github.com/consensys/gnark/frontend"
fiatshamir "github.com/consensys/gnark/std/fiat-shamir"
"github.com/consensys/gnark/std/polynomial"
"github.com/consensys/gnark/std/sumcheck"
"strconv"
)

// @tabaie TODO: Contains many things copy-pasted from gnark-crypto. Generify somehow?
Expand Down Expand Up @@ -198,9 +199,7 @@ func setup(api frontend.API, c Circuit, assignment WireAssignment, transcriptSet

if transcriptSettings.Transcript == nil {
challengeNames := ChallengeNames(o.sorted, o.nbVars, transcriptSettings.Prefix)
transcript := fiatshamir.NewTranscript(
api, transcriptSettings.Hash, challengeNames...)
o.transcript = &transcript
o.transcript = fiatshamir.NewTranscript(api, transcriptSettings.Hash, challengeNames)
if err = o.transcript.Bind(challengeNames[0], transcriptSettings.BaseChallenges); err != nil {
return o, err
}
Expand Down
Loading

0 comments on commit 3f98e9b

Please sign in to comment.