Skip to content

Commit

Permalink
implement getProof and verifyProof on trie (#1852)
Browse files Browse the repository at this point in the history
* implement getProof and verifyProof trie logic 
Co-authored-by: rian <[email protected]>
  • Loading branch information
rianhughes authored May 13, 2024
1 parent 576d342 commit 264eb96
Show file tree
Hide file tree
Showing 4 changed files with 565 additions and 0 deletions.
38 changes: 38 additions & 0 deletions core/trie/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@ func NewKey(length uint8, keyBytes []byte) Key {
return k
}

func (k *Key) SubKey(n uint8) *Key {
if n > k.len {
panic("n is greater than the length of the key")
}

newKey := &Key{len: n}
copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:gomnd

// Shift right by the number of bits that are not needed
shift := k.len - n
for i := len(newKey.bitset) - 1; i >= 0; i-- {
newKey.bitset[i] >>= shift
if i > 0 {
newKey.bitset[i] |= newKey.bitset[i-1] << (8 - shift)
}
}

return newKey
}

func (k *Key) bytesNeeded() uint {
const byteBits = 8
return (uint(k.len) + (byteBits - 1)) / byteBits
Expand Down Expand Up @@ -114,3 +134,21 @@ func (k *Key) Truncate(length uint8) {
inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount
}
}

func (k *Key) RemoveLastBit() {
if k.len == 0 {
return
}

k.len--

unusedBytes := k.unusedBytes()
clear(unusedBytes)

// clear upper bits on the last used byte
inUseBytes := k.inUseBytes()
unusedBitsCount := 8 - (k.len % 8)
if unusedBitsCount != 8 && len(inUseBytes) > 0 {
inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount
}
}
179 changes: 179 additions & 0 deletions core/trie/proof.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package trie

import (
"fmt"

"github.com/NethermindEth/juno/core/felt"
)

// https://github.com/starknet-io/starknet-p2p-specs/blob/main/p2p/proto/snapshot.proto#L6
type ProofNode struct {
Binary *Binary
Edge *Edge
}

// Note: does not work for leaves
func (pn *ProofNode) Hash(hash hashFunc) *felt.Felt {
switch {
case pn.Binary != nil:
return hash(pn.Binary.LeftHash, pn.Binary.RightHash)
case pn.Edge != nil:
length := make([]byte, len(pn.Edge.Path.bitset))
length[len(pn.Edge.Path.bitset)-1] = pn.Edge.Path.len
pathFelt := pn.Edge.Path.Felt()
lengthFelt := new(felt.Felt).SetBytes(length)
return new(felt.Felt).Add(hash(pn.Edge.Child, &pathFelt), lengthFelt)
default:
return nil
}
}

func (pn *ProofNode) PrettyPrint() {
if pn.Binary != nil {
fmt.Printf(" Binary:\n")
fmt.Printf(" LeftHash: %v\n", pn.Binary.LeftHash)
fmt.Printf(" RightHash: %v\n", pn.Binary.RightHash)
}
if pn.Edge != nil {
fmt.Printf(" Edge:\n")
fmt.Printf(" Child: %v\n", pn.Edge.Child)
fmt.Printf(" Path: %v\n", pn.Edge.Path)
fmt.Printf(" Value: %v\n", pn.Edge.Value)
}
}

type Binary struct {
LeftHash *felt.Felt
RightHash *felt.Felt
}

type Edge struct {
Child *felt.Felt
Path *Key
Value *felt.Felt
}

func isEdge(parentKey *Key, sNode storageNode) bool {
sNodeLen := sNode.key.len
if parentKey == nil { // Root
return sNodeLen != 0
}
return sNodeLen-parentKey.len > 1
}

// Note: we need to account for the fact that Junos Trie has nodes that are Binary AND Edge,
// whereas the protocol requires nodes that are Binary XOR Edge
func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary, error) {
isEdgeBool := isEdge(parentKey, sNode)

var edge *Edge
if isEdgeBool {
edgePath := path(sNode.key, parentKey)
edge = &Edge{
Path: &edgePath,
Child: sNode.node.Value,
}
}
if sNode.key.len == tri.height { // Leaf
return edge, nil, nil
}
lNode, err := tri.GetNodeFromKey(sNode.node.Left)
if err != nil {
return nil, nil, err
}
rNode, err := tri.GetNodeFromKey(sNode.node.Right)
if err != nil {
return nil, nil, err
}

rightHash := rNode.Value
if isEdge(sNode.key, storageNode{node: rNode, key: sNode.node.Right}) {
edgePath := path(sNode.node.Right, sNode.key)
rEdge := ProofNode{Edge: &Edge{
Path: &edgePath,
Child: rNode.Value,
}}
rightHash = rEdge.Hash(tri.hash)
}
leftHash := lNode.Value
if isEdge(sNode.key, storageNode{node: lNode, key: sNode.node.Left}) {
edgePath := path(sNode.node.Left, sNode.key)
lEdge := ProofNode{Edge: &Edge{
Path: &edgePath,
Child: lNode.Value,
}}
leftHash = lEdge.Hash(tri.hash)
}
binary := &Binary{
LeftHash: leftHash,
RightHash: rightHash,
}

return edge, binary, nil
}

// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L514
func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) {
leafKey := tri.feltToKey(leaf)
nodesToLeaf, err := tri.nodesFromRoot(&leafKey)
if err != nil {
return nil, err
}
proofNodes := []ProofNode{}

var parentKey *Key

for i := 0; i < len(nodesToLeaf); i++ {
sNode := nodesToLeaf[i]
sNodeEdge, sNodeBinary, err := transformNode(tri, parentKey, sNode)
if err != nil {
return nil, err
}
isLeaf := sNode.key.len == tri.height

if sNodeEdge != nil && !isLeaf { // Internal Edge
proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}, {Binary: sNodeBinary}}...)
} else if sNodeEdge == nil && !isLeaf { // Internal Binary
proofNodes = append(proofNodes, []ProofNode{{Binary: sNodeBinary}}...)
} else if sNodeEdge != nil && isLeaf { // Leaf Edge
proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}}...)
} else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf
break
}
parentKey = nodesToLeaf[i].key
}
return proofNodes, nil
}

// verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes`
// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006
func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool {
if key.Len() != 251 { //nolint:gomnd
return false
}

expectedHash := root
remainingPath := key

for _, proofNode := range proofs {
if !proofNode.Hash(hash).Equal(expectedHash) {
return false
}
switch {
case proofNode.Binary != nil:
if remainingPath.Test(remainingPath.Len() - 1) {
expectedHash = proofNode.Binary.RightHash
} else {
expectedHash = proofNode.Binary.LeftHash
}
remainingPath.RemoveLastBit()
case proofNode.Edge != nil:
if !proofNode.Edge.Path.Equal(remainingPath.SubKey(proofNode.Edge.Path.Len())) {
return false
}
expectedHash = proofNode.Edge.Child
remainingPath.Truncate(proofNode.Edge.Path.Len())
}
}
return expectedHash.Equal(value)
}
Loading

0 comments on commit 264eb96

Please sign in to comment.