diff --git a/core/trie/key.go b/core/trie/key.go index f8f4f06734..dc946a5ef6 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -23,6 +23,26 @@ func NewKey(length uint8, keyBytes []byte) Key { return k } +func (k *Key) SubKey(n uint8) *Key { + if n > k.len { + panic("n is greater than the length of the key") + } + + newKey := &Key{len: n} + copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:gomnd + + // Shift right by the number of bits that are not needed + shift := k.len - n + for i := len(newKey.bitset) - 1; i >= 0; i-- { + newKey.bitset[i] >>= shift + if i > 0 { + newKey.bitset[i] |= newKey.bitset[i-1] << (8 - shift) + } + } + + return newKey +} + func (k *Key) bytesNeeded() uint { const byteBits = 8 return (uint(k.len) + (byteBits - 1)) / byteBits @@ -114,3 +134,21 @@ func (k *Key) Truncate(length uint8) { inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount } } + +func (k *Key) RemoveLastBit() { + if k.len == 0 { + return + } + + k.len-- + + unusedBytes := k.unusedBytes() + clear(unusedBytes) + + // clear upper bits on the last used byte + inUseBytes := k.inUseBytes() + unusedBitsCount := 8 - (k.len % 8) + if unusedBitsCount != 8 && len(inUseBytes) > 0 { + inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount + } +} diff --git a/core/trie/proof.go b/core/trie/proof.go new file mode 100644 index 0000000000..57e6bdeac0 --- /dev/null +++ b/core/trie/proof.go @@ -0,0 +1,179 @@ +package trie + +import ( + "fmt" + + "github.com/NethermindEth/juno/core/felt" +) + +// https://github.com/starknet-io/starknet-p2p-specs/blob/main/p2p/proto/snapshot.proto#L6 +type ProofNode struct { + Binary *Binary + Edge *Edge +} + +// Note: does not work for leaves +func (pn *ProofNode) Hash(hash hashFunc) *felt.Felt { + switch { + case pn.Binary != nil: + return hash(pn.Binary.LeftHash, pn.Binary.RightHash) + case pn.Edge != nil: + length := make([]byte, len(pn.Edge.Path.bitset)) + length[len(pn.Edge.Path.bitset)-1] = pn.Edge.Path.len + pathFelt := pn.Edge.Path.Felt() + lengthFelt := new(felt.Felt).SetBytes(length) + return new(felt.Felt).Add(hash(pn.Edge.Child, &pathFelt), lengthFelt) + default: + return nil + } +} + +func (pn *ProofNode) PrettyPrint() { + if pn.Binary != nil { + fmt.Printf(" Binary:\n") + fmt.Printf(" LeftHash: %v\n", pn.Binary.LeftHash) + fmt.Printf(" RightHash: %v\n", pn.Binary.RightHash) + } + if pn.Edge != nil { + fmt.Printf(" Edge:\n") + fmt.Printf(" Child: %v\n", pn.Edge.Child) + fmt.Printf(" Path: %v\n", pn.Edge.Path) + fmt.Printf(" Value: %v\n", pn.Edge.Value) + } +} + +type Binary struct { + LeftHash *felt.Felt + RightHash *felt.Felt +} + +type Edge struct { + Child *felt.Felt + Path *Key + Value *felt.Felt +} + +func isEdge(parentKey *Key, sNode storageNode) bool { + sNodeLen := sNode.key.len + if parentKey == nil { // Root + return sNodeLen != 0 + } + return sNodeLen-parentKey.len > 1 +} + +// Note: we need to account for the fact that Junos Trie has nodes that are Binary AND Edge, +// whereas the protocol requires nodes that are Binary XOR Edge +func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary, error) { + isEdgeBool := isEdge(parentKey, sNode) + + var edge *Edge + if isEdgeBool { + edgePath := path(sNode.key, parentKey) + edge = &Edge{ + Path: &edgePath, + Child: sNode.node.Value, + } + } + if sNode.key.len == tri.height { // Leaf + return edge, nil, nil + } + lNode, err := tri.GetNodeFromKey(sNode.node.Left) + if err != nil { + return nil, nil, err + } + rNode, err := tri.GetNodeFromKey(sNode.node.Right) + if err != nil { + return nil, nil, err + } + + rightHash := rNode.Value + if isEdge(sNode.key, storageNode{node: rNode, key: sNode.node.Right}) { + edgePath := path(sNode.node.Right, sNode.key) + rEdge := ProofNode{Edge: &Edge{ + Path: &edgePath, + Child: rNode.Value, + }} + rightHash = rEdge.Hash(tri.hash) + } + leftHash := lNode.Value + if isEdge(sNode.key, storageNode{node: lNode, key: sNode.node.Left}) { + edgePath := path(sNode.node.Left, sNode.key) + lEdge := ProofNode{Edge: &Edge{ + Path: &edgePath, + Child: lNode.Value, + }} + leftHash = lEdge.Hash(tri.hash) + } + binary := &Binary{ + LeftHash: leftHash, + RightHash: rightHash, + } + + return edge, binary, nil +} + +// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L514 +func GetProof(leaf *felt.Felt, tri *Trie) ([]ProofNode, error) { + leafKey := tri.feltToKey(leaf) + nodesToLeaf, err := tri.nodesFromRoot(&leafKey) + if err != nil { + return nil, err + } + proofNodes := []ProofNode{} + + var parentKey *Key + + for i := 0; i < len(nodesToLeaf); i++ { + sNode := nodesToLeaf[i] + sNodeEdge, sNodeBinary, err := transformNode(tri, parentKey, sNode) + if err != nil { + return nil, err + } + isLeaf := sNode.key.len == tri.height + + if sNodeEdge != nil && !isLeaf { // Internal Edge + proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}, {Binary: sNodeBinary}}...) + } else if sNodeEdge == nil && !isLeaf { // Internal Binary + proofNodes = append(proofNodes, []ProofNode{{Binary: sNodeBinary}}...) + } else if sNodeEdge != nil && isLeaf { // Leaf Edge + proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}}...) + } else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf + break + } + parentKey = nodesToLeaf[i].key + } + return proofNodes, nil +} + +// verifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes` +// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 +func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool { + if key.Len() != 251 { //nolint:gomnd + return false + } + + expectedHash := root + remainingPath := key + + for _, proofNode := range proofs { + if !proofNode.Hash(hash).Equal(expectedHash) { + return false + } + switch { + case proofNode.Binary != nil: + if remainingPath.Test(remainingPath.Len() - 1) { + expectedHash = proofNode.Binary.RightHash + } else { + expectedHash = proofNode.Binary.LeftHash + } + remainingPath.RemoveLastBit() + case proofNode.Edge != nil: + if !proofNode.Edge.Path.Equal(remainingPath.SubKey(proofNode.Edge.Path.Len())) { + return false + } + expectedHash = proofNode.Edge.Child + remainingPath.Truncate(proofNode.Edge.Path.Len()) + } + } + return expectedHash.Equal(value) +} diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go new file mode 100644 index 0000000000..0ae9c78513 --- /dev/null +++ b/core/trie/proof_test.go @@ -0,0 +1,343 @@ +package trie_test + +import ( + "fmt" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func buildSimpleTrie(t *testing.T) *trie.Trie { + // (250, 0, x1) + // | + // (0,0,x1) + // / \ + // (2) (3) + // Build trie + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + require.NoError(t, err) + + // Update trie + key1 := new(felt.Felt).SetUint64(0) + key2 := new(felt.Felt).SetUint64(1) + value1 := new(felt.Felt).SetUint64(2) + value2 := new(felt.Felt).SetUint64(3) + + _, err = tempTrie.Put(key1, value1) + require.NoError(t, err) + + _, err = tempTrie.Put(key2, value2) + require.NoError(t, err) + + require.NoError(t, tempTrie.Commit()) + return tempTrie +} + +func buildSimpleBinaryRootTrie(t *testing.T) *trie.Trie { + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) + // Build trie + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + require.NoError(t, err) + + key1 := new(felt.Felt).SetUint64(0) + key2 := utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + value1 := utils.HexToFelt(t, "0xcc") + value2 := utils.HexToFelt(t, "0xdd") + + _, err = tempTrie.Put(key1, value1) + require.NoError(t, err) + + _, err = tempTrie.Put(key2, value2) + require.NoError(t, err) + + require.NoError(t, tempTrie.Commit()) + return tempTrie +} + +func buildSimpleDoubleBinaryTrie(t *testing.T) *trie.Trie { + // Build trie + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + require.NoError(t, err) + + // Update trie + key1 := new(felt.Felt).SetUint64(0) + key2 := new(felt.Felt).SetUint64(1) + key3 := new(felt.Felt).SetUint64(3) + value1 := new(felt.Felt).SetUint64(2) + value2 := new(felt.Felt).SetUint64(3) + value3 := new(felt.Felt).SetUint64(5) + + _, err = tempTrie.Put(key1, value1) + require.NoError(t, err) + + _, err = tempTrie.Put(key2, value2) + require.NoError(t, err) + + _, err = tempTrie.Put(key3, value3) + require.NoError(t, err) + + require.NoError(t, tempTrie.Commit()) + return tempTrie +} + +func TestGetProofs(t *testing.T) { + t.Run("Simple Trie - simple binary", func(t *testing.T) { + tempTrie := buildSimpleTrie(t) + + zero := trie.NewKey(250, []byte{0}) + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), + RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), + }, + }, + } + + proofNodes, err := trie.GetProof(new(felt.Felt).SetUint64(0), tempTrie) + require.NoError(t, err) + + // Better inspection + for _, pNode := range proofNodes { + pNode.PrettyPrint() + } + require.Equal(t, expectedProofNodes, proofNodes) + }) + + t.Run("Simple Trie - simple double binary", func(t *testing.T) { + tempTrie := buildSimpleDoubleBinaryTrie(t) + + zero := trie.NewKey(249, []byte{0}) + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x055C81F6A791FD06FC2E2CCAD922397EC76C3E35F2E06C0C0D43D551005A8DEA"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), + RightHash: utils.HexToFelt(t, "0x07C5BC1CC68B7BC8CA2F632DE98297E6DA9594FA23EDE872DD2ABEAFDE353B43"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), + RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), + }, + }, + } + + proofNodes, err := trie.GetProof(new(felt.Felt).SetUint64(0), tempTrie) + require.NoError(t, err) + + // Better inspection + for _, pNode := range proofNodes { + pNode.PrettyPrint() + } + require.Equal(t, expectedProofNodes, proofNodes) + }) + + t.Run("Simple Trie - simple double binary edge", func(t *testing.T) { + tempTrie := buildSimpleDoubleBinaryTrie(t) + + zero := trie.NewKey(249, []byte{0}) + value3 := new(felt.Felt).SetUint64(5) + key3Bytes := new(felt.Felt).SetUint64(1).Bytes() + path3 := trie.NewKey(1, key3Bytes[:]) + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x055C81F6A791FD06FC2E2CCAD922397EC76C3E35F2E06C0C0D43D551005A8DEA"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), + RightHash: utils.HexToFelt(t, "0x07C5BC1CC68B7BC8CA2F632DE98297E6DA9594FA23EDE872DD2ABEAFDE353B43"), + }, + }, + { + Edge: &trie.Edge{ + Path: &path3, + Child: value3, + }, + }, + } + + proofNodes, err := trie.GetProof(new(felt.Felt).SetUint64(3), tempTrie) + require.NoError(t, err) + + // Better inspection + for _, pNode := range proofNodes { + pNode.PrettyPrint() + } + require.Equal(t, expectedProofNodes, proofNodes) + }) + + t.Run("Simple Trie - simple binary root", func(t *testing.T) { + tempTrie := buildSimpleBinaryRootTrie(t) + + key1Bytes := new(felt.Felt).SetUint64(0).Bytes() + path1 := trie.NewKey(250, key1Bytes[:]) + expectedProofNodes := []trie.ProofNode{ + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x06E08BF82793229338CE60B65D1845F836C8E2FBFE2BC59FF24AEDBD8BA219C4"), + RightHash: utils.HexToFelt(t, "0x04F9B8E66212FB528C0C1BD02F43309C53B895AA7D9DC91180001BDD28A588FA"), + }, + }, + { + Edge: &trie.Edge{ + Path: &path1, + Child: utils.HexToFelt(t, "0xcc"), + }, + }, + } + + proofNodes, err := trie.GetProof(new(felt.Felt).SetUint64(0), tempTrie) + require.NoError(t, err) + + // Better inspection + for _, pNode := range proofNodes { + pNode.PrettyPrint() + } + require.Equal(t, expectedProofNodes, proofNodes) + }) + + t.Run("Simple Trie - left-right edge", func(t *testing.T) { + // (251,0xff,0xaa) + // / + // \ + // (0xaa) + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + require.NoError(t, err) + + key1 := utils.HexToFelt(t, "0xff") + value1 := utils.HexToFelt(t, "0xaa") + + _, err = tempTrie.Put(key1, value1) + require.NoError(t, err) + + require.NoError(t, tempTrie.Commit()) + + key1Bytes := key1.Bytes() + path1 := trie.NewKey(251, key1Bytes[:]) + + child := utils.HexToFelt(t, "0x00000000000000000000000000000000000000000000000000000000000000AA") + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &path1, + Child: child, + }, + }, + } + + proofNodes, err := trie.GetProof(new(felt.Felt).SetUint64(0), tempTrie) + require.NoError(t, err) + + // Better inspection + for i, pNode := range proofNodes { + fmt.Println(i) + pNode.PrettyPrint() + } + require.Equal(t, expectedProofNodes, proofNodes) + }) +} + +func TestVerifyProofs(t *testing.T) { + // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2137 + t.Run("Simple binary trie", func(t *testing.T) { + tempTrie := buildSimpleTrie(t) + zero := trie.NewKey(250, []byte{0}) + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), + RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), + }, + }, + } + + root, err := tempTrie.Root() + require.NoError(t, err) + key1Bytes := new(felt.Felt).SetUint64(0).Bytes() + key1 := trie.NewKey(251, key1Bytes[:]) + val1 := new(felt.Felt).SetUint64(2) + assert.True(t, trie.VerifyProof(root, &key1, val1, expectedProofNodes, crypto.Pedersen)) + }) + + // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2167 + t.Run("Simple double binary trie", func(t *testing.T) { + tempTrie := buildSimpleDoubleBinaryTrie(t) + zero := trie.NewKey(249, []byte{0}) + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x055C81F6A791FD06FC2E2CCAD922397EC76C3E35F2E06C0C0D43D551005A8DEA"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), + RightHash: utils.HexToFelt(t, "0x07C5BC1CC68B7BC8CA2F632DE98297E6DA9594FA23EDE872DD2ABEAFDE353B43"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), + RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), + }, + }, + } + + root, err := tempTrie.Root() + require.NoError(t, err) + key1Bytes := new(felt.Felt).SetUint64(0).Bytes() + key1 := trie.NewKey(251, key1Bytes[:]) + val1 := new(felt.Felt).SetUint64(2) + assert.True(t, trie.VerifyProof(root, &key1, val1, expectedProofNodes, crypto.Pedersen)) + }) +} diff --git a/core/trie/trie.go b/core/trie/trie.go index 1a5ce7f00a..2496b724d4 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -182,6 +182,11 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { return &leafValue, nil } +// GetNodeFromKey returns the node for a given key. +func (t *Trie) GetNodeFromKey(key *Key) (*Node, error) { + return t.storage.Get(key) +} + // check if we are updating an existing leaf, if yes avoid traversing the trie func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) { // Check if we are updating an existing leaf