diff --git a/core/trie/proof.go b/core/trie/proof.go index 5da9fee7f8..5d46382f85 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -1,6 +1,7 @@ package trie import ( + "errors" "fmt" "github.com/NethermindEth/juno/core/felt" @@ -28,6 +29,13 @@ func (pn *ProofNode) Hash(hash hashFunc) *felt.Felt { } } +func (pn *ProofNode) Len() uint8 { + if pn.Binary != nil { + return 1 + } + return pn.Edge.Path.len +} + func (pn *ProofNode) PrettyPrint() { if pn.Binary != nil { fmt.Printf(" Binary:\n") @@ -48,12 +56,27 @@ type Binary struct { } type Edge struct { - Child *felt.Felt - Path *Key - Value *felt.Felt + Child *felt.Felt // child hash + Path *Key // path from parent to child + Value *felt.Felt // this nodes hash +} + +func GetBoundaryProofs(leftBoundary, rightBoundary *Key, tri *Trie) ([2][]ProofNode, error) { + proofs := [2][]ProofNode{} + leftProof, err := GetProof(leftBoundary, tri) + if err != nil { + return proofs, err + } + rightProof, err := GetProof(rightBoundary, tri) + if err != nil { + return proofs, err + } + proofs[0] = leftProof + proofs[1] = rightProof + return proofs, nil } -func isEdge(parentKey *Key, sNode storageNode) bool { +func isEdge(parentKey *Key, sNode StorageNode) bool { sNodeLen := sNode.key.len if parentKey == nil { // Root return sNodeLen != 0 @@ -63,7 +86,7 @@ func isEdge(parentKey *Key, sNode storageNode) bool { // Note: we need to account for the fact that Junos Trie has nodes that are Binary AND Edge, // whereas the protocol requires nodes that are Binary XOR Edge -func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary, error) { +func transformNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { isEdgeBool := isEdge(parentKey, sNode) var edge *Edge @@ -87,7 +110,7 @@ func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary } rightHash := rNode.Value - if isEdge(sNode.key, storageNode{node: rNode, key: sNode.node.Right}) { + if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) { edgePath := path(sNode.node.Right, sNode.key) rEdge := ProofNode{Edge: &Edge{ Path: &edgePath, @@ -96,7 +119,7 @@ func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary rightHash = rEdge.Hash(tri.hash) } leftHash := lNode.Value - if isEdge(sNode.key, storageNode{node: lNode, key: sNode.node.Left}) { + if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) { edgePath := path(sNode.node.Left, sNode.key) lEdge := ProofNode{Edge: &Edge{ Path: &edgePath, @@ -147,7 +170,7 @@ func GetProof(key *Key, tri *Trie) ([]ProofNode, error) { // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool { expectedHash := root - remainingPath := key + remainingPath := NewKey(key.len, key.bitset[:]) for _, proofNode := range proofs { if !proofNode.Hash(hash).Equal(expectedHash) { return false @@ -172,5 +195,261 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode remainingPath.Truncate(251 - proofNode.Edge.Path.Len()) //nolint:gomnd } } + return expectedHash.Equal(value) } + +// VerifyRangeProof verifies the range proof for the given range of keys. +// This is achieved by constructing a trie from the boundary proofs, and the supplied key-values. +// If the root of the reconstructed trie matches the supplied root, then the verification passes. +// If the trie is constructed incorrectly then the root will have an incorrect key(len,path), and value, +// and therefore it's hash won't match the expected root +// ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484 +func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*Key, proofValues [2]*felt.Felt, + proofs [2][]ProofNode, hash hashFunc, +) (bool, error) { + // Step 0: checks + if len(keys) != len(values) { + return false, fmt.Errorf("inconsistent proof data, number of keys: %d, number of values: %d", len(keys), len(values)) + } + + // Ensure all keys are monotonic increasing + if err := ensureMonotonicIncreasing(proofKeys, keys); err != nil { + return false, err + } + + // Ensure the inner values contain no deletions + for _, value := range values { + if value.Equal(&felt.Zero) { + return false, errors.New("range contains deletion") + } + } + + // Step 1: Verify proofs, and get proof paths + var proofPaths [2][]StorageNode + var err error + for i := 0; i < 2; i++ { + if proofs[i] != nil { + if !VerifyProof(root, proofKeys[i], proofValues[i], proofs[i], hash) { + return false, fmt.Errorf("invalid proof for key %x", proofKeys[i].String()) + } + + proofPaths[i], err = ProofToPath(proofs[i], proofKeys[i], hash) + if err != nil { + return false, err + } + } + } + + // Step 2: Build trie from proofPaths and keys + tmpTrie, err := BuildTrie(proofPaths[0], proofPaths[1], keys, values) + if err != nil { + return false, err + } + + // Verify that the recomputed root hash matches the provided root hash + recomputedRoot, err := tmpTrie.Root() + if err != nil { + return false, err + } + if !recomputedRoot.Equal(root) { + return false, errors.New("root hash mismatch") + } + + return true, nil +} + +func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error { + if proofKeys[0] != nil { + leftProofFelt := proofKeys[0].Felt() + if leftProofFelt.Cmp(keys[0]) >= 0 { + return errors.New("range is not monotonically increasing") + } + } + if proofKeys[1] != nil { + rightProofFelt := proofKeys[1].Felt() + if keys[len(keys)-1].Cmp(&rightProofFelt) >= 0 { + return errors.New("range is not monotonically increasing") + } + } + if len(keys) >= 2 { + for i := 0; i < len(keys)-1; i++ { + if keys[i].Cmp(keys[i+1]) >= 0 { + return errors.New("range is not monotonically increasing") + } + } + } + return nil +} + +// shouldSquish determines if the node needs compressed, and if so, the len needed to arrive at the next key +func shouldSquish(idx int, proofNodes []ProofNode) (int, uint8) { + parent := &proofNodes[idx] + var child *ProofNode + // The child is nil of the current node is a leaf + if idx != len(proofNodes)-1 { + child = &proofNodes[idx+1] + } + + if child == nil { + return 0, 0 + } + + if parent.Edge != nil && child.Binary != nil { + return 1, parent.Edge.Path.len + } + + if parent.Binary != nil && child.Edge != nil { + return 1, child.Edge.Path.len + } + + return 0, 0 +} + +func assignChild(crntNode *Node, nilKey, childKey *Key, isRight bool) { + if isRight { + crntNode.Right = childKey + crntNode.Left = nilKey + } else { + crntNode.Right = nilKey + crntNode.Left = childKey + } +} + +// ProofToPath returns the set of storage nodes along the proofNodes towards the leaf. +// Note that only the nodes and children along this path will be set correctly. +func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]StorageNode, error) { + pathNodes := []StorageNode{} + + // Hack: this allows us to store a right without an existing left node. + zeroFeltBytes := new(felt.Felt).Bytes() + nilKey := NewKey(0, zeroFeltBytes[:]) + + i, offset := 0, 0 + for i <= len(proofNodes)-1 { + var crntKey *Key + crntNode := Node{} + + height := getHeight(i, pathNodes, proofNodes) + + // Set the key of the current node + var err error + squishedParent, squishParentOffset := shouldSquish(i, proofNodes) + if proofNodes[i].Binary != nil { + crntKey, err = leafKey.SubKey(height) + } else { + crntKey, err = leafKey.SubKey(height + squishParentOffset) + } + if err != nil { + return nil, err + } + offset += squishedParent + + // Set the value of the current node + crntNode.Value = proofNodes[i].Hash(hashF) + + // Set the children of the current node + childIdx := i + squishedParent + 1 + childIsRight := leafKey.Test(leafKey.len - crntKey.len - 1) + if i+2+squishedParent < len(proofNodes)-1 { // The child will be compressed, so point to its compressed form + squishedChild, squishChildOffset := shouldSquish(childIdx, proofNodes) + childKey, err := leafKey.SubKey(height + squishParentOffset + squishChildOffset + uint8(squishedChild)) + if err != nil { + return nil, err + } + assignChild(&crntNode, &nilKey, childKey, childIsRight) + } else if i+1+offset == len(proofNodes)-1 { // The child points to a leaf, keep it as is + if proofNodes[childIdx].Edge != nil { + assignChild(&crntNode, &nilKey, leafKey, childIsRight) + } else { + childKey, err := leafKey.SubKey(crntKey.len + proofNodes[childIdx].Len()) + if err != nil { + return nil, err + } + assignChild(&crntNode, &nilKey, childKey, childIsRight) + } + } else { // Current node points directly to leaf + if proofNodes[i].Edge != nil && len(pathNodes) > 0 { + break + } + assignChild(&crntNode, &nilKey, leafKey, childIsRight) + } + + pathNodes = append(pathNodes, StorageNode{key: crntKey, node: &crntNode}) + i += 1 + offset + } + pathNodes = addLeafNode(proofNodes, pathNodes, leafKey) + return pathNodes, nil +} + +// getHeight returns the height of the current node, which depends on the previous +// height and whether the current proofnode is edge or binary +func getHeight(idx int, pathNodes []StorageNode, proofNodes []ProofNode) uint8 { + if len(pathNodes) > 0 { + if proofNodes[idx].Edge != nil { + return pathNodes[len(pathNodes)-1].key.len + proofNodes[idx].Edge.Path.len + } else { + return pathNodes[len(pathNodes)-1].key.len + 1 + } + } else { + return 0 + } +} + +// addLeafNode appends the leaf node, if the final node in pathNodes points to a leaf. +func addLeafNode(proofNodes []ProofNode, pathNodes []StorageNode, leafKey *Key) []StorageNode { + lastNode := pathNodes[len(pathNodes)-1].node + lastProof := proofNodes[len(proofNodes)-1] + if lastNode.Left.Equal(leafKey) || lastNode.Right.Equal(leafKey) { + leafNode := Node{} + if lastProof.Edge != nil { + leafNode.Value = lastProof.Edge.Child + } else if lastNode.Left.Equal(leafKey) { + leafNode.Value = lastProof.Binary.LeftHash + } else { + leafNode.Value = lastProof.Binary.RightHash + } + pathNodes = append(pathNodes, StorageNode{key: leafKey, node: &leafNode}) + } + return pathNodes +} + +// BuildTrie builds a trie using the proof paths (including inner nodes), and then sets all the keys-values (leaves) +func BuildTrie(leftProofPath, rightProofPath []StorageNode, keys, values []*felt.Felt) (*Trie, error) { + tempTrie, err := NewTriePedersen(newMemStorage(), 251) //nolint:gomnd + if err != nil { + return nil, err + } + + // merge proof paths + for i := range min(len(leftProofPath), len(rightProofPath)) { + if leftProofPath[i].key.Equal(rightProofPath[i].key) { + leftProofPath[i].node.Right = rightProofPath[i].node.Right + rightProofPath[i].node.Left = leftProofPath[i].node.Left + } else { + break + } + } + + for _, sNode := range leftProofPath { + _, err := tempTrie.PutInner(sNode.key, sNode.node) + if err != nil { + return nil, err + } + } + + for _, sNode := range rightProofPath { + _, err := tempTrie.PutInner(sNode.key, sNode.node) + if err != nil { + return nil, err + } + } + + for i := range len(keys) { + _, err := tempTrie.PutWithProof(keys[i], values[i], leftProofPath, rightProofPath) + if err != nil { + return nil, err + } + } + return tempTrie, nil +} diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 9356db86a4..6757a75225 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -1,7 +1,6 @@ package trie_test import ( - "fmt" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -24,7 +23,7 @@ func buildSimpleTrie(t *testing.T) *trie.Trie { txn, err := memdb.NewTransaction(true) require.NoError(t, err) - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) require.NoError(t, err) // Update trie @@ -44,12 +43,18 @@ func buildSimpleTrie(t *testing.T) *trie.Trie { } func buildSimpleBinaryRootTrie(t *testing.T) *trie.Trie { + // PF // (0, 0, x) // / \ // (250, 0, cc) (250, 11111.., dd) // | | // (cc) (dd) - // Build trie + + // JUNO + // (0, 0, x) + // / \ + // (251, 0, cc) (251, 11111.., dd) + memdb := pebble.NewMemTest(t) txn, err := memdb.NewTransaction(true) require.NoError(t, err) @@ -135,6 +140,53 @@ func buildSimpleDoubleBinaryTrie(t *testing.T) (*trie.Trie, []trie.ProofNode) { return tempTrie, expectedProofNodes } +func build3KeyTrie(t *testing.T) *trie.Trie { + // Starknet + // + // Edge + // | + // Binary with len 249 + // / \ + // Binary (250) Edge with len 250 (?) + // / \ \ + // 0x4 0x5 0x6 (edge?) + + // Juno + // + // Node (path 249) + // / \ + // Node (binary) 0x6 + // / \ + // 0x4 0x5 + + // Build trie + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + + // Update trie + key1 := new(felt.Felt).SetUint64(0) + key2 := new(felt.Felt).SetUint64(1) + key3 := new(felt.Felt).SetUint64(2) + value1 := new(felt.Felt).SetUint64(4) + value2 := new(felt.Felt).SetUint64(5) + value3 := new(felt.Felt).SetUint64(6) + + _, err = tempTrie.Put(key1, value1) + require.NoError(t, err) + + _, err = tempTrie.Put(key3, value3) + require.NoError(t, err) + _, err = tempTrie.Put(key2, value2) + require.NoError(t, err) + + require.NoError(t, tempTrie.Commit()) + return tempTrie +} + func TestGetProof(t *testing.T) { t.Run("Simple Trie - simple binary", func(t *testing.T) { tempTrie := buildSimpleTrie(t) @@ -160,9 +212,9 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) @@ -182,9 +234,9 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) @@ -196,9 +248,9 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) @@ -228,9 +280,9 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) @@ -272,10 +324,9 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for i, pNode := range proofNodes { - fmt.Println(i) - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) @@ -347,10 +398,11 @@ func TestVerifyProof(t *testing.T) { root, err := tempTrie.Root() require.NoError(t, err) - key1Bytes := new(felt.Felt).SetUint64(0).Bytes() - key1 := trie.NewKey(251, key1Bytes[:]) val1 := new(felt.Felt).SetUint64(2) - assert.True(t, trie.VerifyProof(root, &key1, val1, expectedProofNodes, crypto.Pedersen)) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) }) // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2167 @@ -380,10 +432,49 @@ func TestVerifyProof(t *testing.T) { root, err := tempTrie.Root() require.NoError(t, err) - key1Bytes := new(felt.Felt).SetUint64(0).Bytes() - key1 := trie.NewKey(251, key1Bytes[:]) val1 := new(felt.Felt).SetUint64(2) - require.True(t, trie.VerifyProof(root, &key1, val1, expectedProofNodes, crypto.Pedersen)) + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) + }) + + t.Run("three key trie", func(t *testing.T) { + tempTrie := build3KeyTrie(t) + zero := trie.NewKey(249, []byte{0}) + felt2 := new(felt.Felt).SetUint64(0).Bytes() + lastPath := trie.NewKey(1, felt2[:]) + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x0768DEB8D0795D80AAAC2E5E326141F33044759F97A1BF092D8EB9C4E4BE9234"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x057166F9476D0A2D6875124251841EB85A9AE37462FAE3CBF7304BCD593938E7"), + RightHash: utils.HexToFelt(t, "0x060FBDE29F96F706498EFD132DC7F312A4C99A9AE051BF152C2AF2B3CAF31E5B"), + }, + }, + { + Edge: &trie.Edge{ + Path: &lastPath, + Child: utils.HexToFelt(t, "0x6"), + }, + }, + } + + root, err := tempTrie.Root() + require.NoError(t, err) + val6 := new(felt.Felt).SetUint64(6) + + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + leafkey := trie.NewKey(251, twoFeltBytes[:]) + gotProof, err := trie.GetProof(&leafkey, tempTrie) + require.NoError(t, err) + require.Equal(t, expectedProofNodes, gotProof) + + assert.True(t, trie.VerifyProof(root, &leafkey, val6, expectedProofNodes, crypto.Pedersen)) }) t.Run("non existent key - less than root edge", func(t *testing.T) { @@ -415,3 +506,282 @@ func TestVerifyProof(t *testing.T) { require.False(t, trie.VerifyProof(root, &nonExistentKey, nonExistentKeyValue, proofNodes, crypto.Pedersen)) }) } + +func TestProofToPath(t *testing.T) { + t.Run("Simple binary trie proof to path", func(t *testing.T) { + // Todo check leaf + tempTrie := buildSimpleTrie(t) + zero := trie.NewKey(250, []byte{0}) + proofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), + RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), + }, + }, + } + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) // Todo : we should be able to set the leaf as well + require.NoError(t, err) + + rootKey := tempTrie.RootKey() + rootNodes, err := tempTrie.GetNodeFromKey(rootKey) + require.NoError(t, err) + + require.Equal(t, 2, len(sns)) + require.Equal(t, rootKey.Len(), sns[0].Key().Len()) + require.Equal(t, rootNodes.Left, sns[0].Node().Left) + require.NotEqual(t, rootNodes.Right, sns[0].Node().Right) + }) + + t.Run("Simple double binary trie proof to path", func(t *testing.T) { + // Todo: check leaf + tempTrie := buildSimpleBinaryRootTrie(t) + + key1Bytes := new(felt.Felt).SetUint64(0).Bytes() + path1 := trie.NewKey(250, key1Bytes[:]) + proofNodes := []trie.ProofNode{ + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x06E08BF82793229338CE60B65D1845F836C8E2FBFE2BC59FF24AEDBD8BA219C4"), + RightHash: utils.HexToFelt(t, "0x04F9B8E66212FB528C0C1BD02F43309C53B895AA7D9DC91180001BDD28A588FA"), + }, + }, + { + Edge: &trie.Edge{ + Path: &path1, + Child: utils.HexToFelt(t, "0xcc"), + }, + }, + } + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) + require.NoError(t, err) + + rootKey := tempTrie.RootKey() + + rootNodes, err := tempTrie.GetNodeFromKey(rootKey) + require.NoError(t, err) + require.Equal(t, 2, len(sns)) + require.Equal(t, rootKey.Len(), sns[0].Key().Len()) + require.Equal(t, rootNodes.Left, sns[0].Node().Left) + require.NotEqual(t, rootNodes.Right, sns[0].Node().Right) + }) + + t.Run("boundary proofs wth three key trie", func(t *testing.T) { + tri := build3KeyTrie(t) + rootKey := tri.RootKey() + rootNode, err := tri.GetNodeFromKey(rootKey) + require.NoError(t, err) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri) + require.NoError(t, err) + + leftProofPath, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen) + require.Equal(t, 3, len(leftProofPath)) + require.NoError(t, err) + require.Equal(t, rootKey, leftProofPath[0].Key()) + require.Equal(t, rootNode.Left, leftProofPath[0].Node().Left) + require.NotEqual(t, rootNode.Right, leftProofPath[0].Node().Right) + + leftNode, err := tri.GetNodeFromKey(rootNode.Left) + require.NoError(t, err) + require.Equal(t, rootNode.Left, leftProofPath[1].Key()) + require.Equal(t, leftNode.Left, leftProofPath[1].Node().Left) + require.NotEqual(t, leftNode.Right, leftProofPath[0].Node().Right) + + rightProofPath, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen) + require.Equal(t, 2, len(rightProofPath)) + require.NoError(t, err) + require.Equal(t, rootKey, rightProofPath[0].Key()) + require.Equal(t, rootNode.Right, rightProofPath[0].Node().Right) + require.NotEqual(t, rootNode.Left, rightProofPath[0].Node().Left) + }) +} + +func TestBuildTrie(t *testing.T) { + t.Run("Simple binary trie proof to path", func(t *testing.T) { + compareLeftRight := func(t *testing.T, want, got *trie.Node) { + require.Equal(t, want.Left, got.Left, "left fail") + require.Equal(t, want.Right, got.Right, "right fail") + } + + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + + tri := build3KeyTrie(t) + rootKey := tri.RootKey() + rootCommitment, err := tri.Root() + require.NoError(t, err) + rootNode, err := tri.GetNodeFromKey(rootKey) + require.NoError(t, err) + leftNode, err := tri.GetNodeFromKey(rootNode.Left) + require.NoError(t, err) + rightNode, err := tri.GetNodeFromKey(rootNode.Right) + require.NoError(t, err) + leftleftNode, err := tri.GetNodeFromKey(leftNode.Left) + require.NoError(t, err) + leftrightNode, err := tri.GetNodeFromKey(leftNode.Right) + require.NoError(t, err) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri) + require.NoError(t, err) + + leftProof, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen) + require.NoError(t, err) + + rightProof, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen) + require.NoError(t, err) + + keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} + values := []*felt.Felt{new(felt.Felt).SetUint64(5)} + builtTrie, err := trie.BuildTrie(leftProof, rightProof, keys, values) + require.NoError(t, err) + + builtRootKey := builtTrie.RootKey() + builtRootNode, err := builtTrie.GetNodeFromKey(builtRootKey) + require.NoError(t, err) + builtLeftNode, err := builtTrie.GetNodeFromKey(builtRootNode.Left) + require.NoError(t, err) + builtRightNode, err := builtTrie.GetNodeFromKey(builtRootNode.Right) + require.NoError(t, err) + builtLeftLeftNode, err := builtTrie.GetNodeFromKey(builtLeftNode.Left) + require.NoError(t, err) + builtLeftRightNode, err := builtTrie.GetNodeFromKey(builtLeftNode.Right) + require.NoError(t, err) + + // Assert the structure is correct + require.Equal(t, rootKey, builtRootKey) + compareLeftRight(t, rootNode, builtRootNode) + compareLeftRight(t, leftNode, builtLeftNode) + compareLeftRight(t, rightNode, builtRightNode) + compareLeftRight(t, leftleftNode, builtLeftLeftNode) + compareLeftRight(t, leftrightNode, builtLeftRightNode) + + // Assert the leaf nodes have the correct values + require.Equal(t, leftleftNode.Value.String(), builtLeftLeftNode.Value.String(), "should be 0x4") + require.Equal(t, leftrightNode.Value.String(), builtLeftRightNode.Value.String(), "should be 0x5") + require.Equal(t, rightNode.Value.String(), builtRightNode.Value.String(), "should be 0x6") + + // Given the above two asserts pass, we should be able to reconstruct the correct commitment + reconstructedRootCommitment, err := builtTrie.Root() + require.NoError(t, err) + require.Equal(t, rootCommitment.String(), reconstructedRootCommitment.String()) + }) +} + +func TestVerifyRangeProof(t *testing.T) { + t.Run("two proofs, single key trie", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} + values := []*felt.Felt{new(felt.Felt).SetUint64(5)} + proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(6)} + rootCommitment, err := tri.Root() + require.NoError(t, err) + proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) + + t.Run("all keys provided, no proofs needed", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} + values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} + proofKeys := [2]*trie.Key{} + proofValues := [2]*felt.Felt{} + proofs := [2][]trie.ProofNode{} + rootCommitment, err := tri.Root() + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) + + t.Run("left proof, all right keys", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} + values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} + proofKeys := [2]*trie.Key{&zeroLeafkey} + proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4)} + leftProof, err := trie.GetProof(proofKeys[0], tri) + require.NoError(t, err) + proofs := [2][]trie.ProofNode{leftProof} + rootCommitment, err := tri.Root() + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) + + t.Run("right proof, all left keys", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1)} + values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5)} + proofKeys := [2]*trie.Key{nil, &twoLeafkey} + proofValues := [2]*felt.Felt{nil, new(felt.Felt).SetUint64(6)} + rightProof, err := trie.GetProof(proofKeys[1], tri) + require.NoError(t, err) + proofs := [2][]trie.ProofNode{nil, rightProof} + rootCommitment, err := tri.Root() + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) +} diff --git a/core/trie/trie.go b/core/trie/trie.go index 2496b724d4..98080edae7 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -130,24 +130,41 @@ func path(key, parentKey *Key) Key { // storageNode is the on-disk representation of a [Node], // where key is the storage key and node is the value. -type storageNode struct { +type StorageNode struct { key *Key node *Node } +func (sn *StorageNode) Key() *Key { + return sn.key +} + +func (sn *StorageNode) Node() *Node { + return sn.node +} + +func NewStorageNode(key *Key, node *Node) *StorageNode { + return &StorageNode{key: key, node: node} +} + // nodesFromRoot enumerates the set of [Node] objects that need to be traversed from the root // of the Trie to the node which is given by the key. // The [storageNode]s are returned in descending order beginning with the root. -func (t *Trie) nodesFromRoot(key *Key) ([]storageNode, error) { - var nodes []storageNode +func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { + var nodes []StorageNode cur := t.rootKey for cur != nil { + // proof nodes set "nil" nodes to zero + if len(nodes) > 0 && cur.len == 0 { + return nodes, nil + } + node, err := t.storage.Get(cur) if err != nil { return nil, err } - nodes = append(nodes, storageNode{ + nodes = append(nodes, StorageNode{ key: cur, node: node, }) @@ -217,7 +234,7 @@ func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *fe return &old, nil } -func (t *Trie) deleteExistingKey(sibling storageNode, nodeKey Key, nodes []storageNode) (*felt.Felt, error) { +func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []StorageNode) (*felt.Felt, error) { if nodeKey.Equal(sibling.key) { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -230,7 +247,7 @@ func (t *Trie) deleteExistingKey(sibling storageNode, nodeKey Key, nodes []stora return nil, nil } -func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent storageNode) { +func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent StorageNode) { if siblingParent.node.Left.Equal(key) { *siblingParent.node.Left = commonKey } else { @@ -238,7 +255,7 @@ func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent s } } -func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []storageNode, sibling storageNode) error { +func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode, sibling StorageNode, isProof bool) error { commonKey, _ := findCommonKey(nodeKey, sibling.key) newParent := &Node{} @@ -248,7 +265,7 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []storageNode newParent.Left, newParent.Right = sibling.key, nodeKey leftChild, rightChild = sibling.node, node } else { - newParent.Left, newParent.Right = nodeKey, sibling.key + newParent.Left, newParent.Right = nodeKey, sibling.key // leftChild, rightChild = node, sibling.node } @@ -260,16 +277,19 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []storageNode return err } - if len(nodes) > 1 { // sibling has a parent + // Don't modify the structure outlined by the proof paths + if len(nodes) > 1 && !isProof { // sibling has a parent siblingParent := (nodes)[len(nodes)-2] - t.replaceLinkWithNewParent(sibling.key, commonKey, siblingParent) + t.replaceLinkWithNewParent(sibling.key, commonKey, siblingParent) // error with overwritting right arises here if err := t.storage.Put(siblingParent.key, siblingParent.node); err != nil { return err } t.dirtyNodes = append(t.dirtyNodes, &commonKey) - } else { + } else if !isProof { t.setRootKey(&commonKey) + } else { + t.dirtyNodes = append(t.dirtyNodes, &commonKey) } if err := t.storage.Put(nodeKey, node); err != nil { @@ -296,7 +316,58 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { return oldValue, err } - nodes, err := t.nodesFromRoot(&nodeKey) + nodes, err := t.nodesFromRoot(&nodeKey) // correct for key,value + if err != nil { + return nil, err + } + defer func() { + for _, n := range nodes { + nodePool.Put(n.node) + } + }() + + // empty trie, make new value root + if len(nodes) == 0 { + return t.handleEmptyTrie(old, nodeKey, node, value) + } else { + // Since we short-circuit in leaf updates, we will only end up here for deletions + // Delete if key already exist + sibling := nodes[len(nodes)-1] + oldValue, err = t.deleteExistingKey(sibling, nodeKey, nodes) + // xor operation, because we don't want to return if the error is nil and the old value is nil + if (err != nil) != (oldValue != nil) { + return oldValue, err + } else if value.IsZero() { + // trying to insert 0 to a key that does not exist + return nil, nil // no-op + } + err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling, false) + if err != nil { + return nil, err + } + return &old, nil + } +} + +// Put updates the corresponding `value` for a `key` +func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []StorageNode) (*felt.Felt, error) { + if key.Cmp(t.maxKey) > 0 { + return nil, fmt.Errorf("key %s exceeds trie height %d", key, t.height) + } + + old := felt.Zero + nodeKey := t.feltToKey(key) + node := &Node{ + Value: value, + } + + oldValue, err := t.updateLeaf(nodeKey, node, value) + // xor operation, because we don't want to return result if the error is nil and the old value is nil + if (err != nil) != (oldValue != nil) { + return oldValue, err + } + + nodes, err := t.nodesFromRoot(&nodeKey) // correct for key,value if err != nil { return nil, err } @@ -322,7 +393,26 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { return nil, nil // no-op } - err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling) + IsProof, found := false, false + for i, proof := range lProofPath { + if proof.key.Equal(sibling.key) { + sibling = lProofPath[i+1] + IsProof = true + found = true + break + } + } + if !found { + for i, proof := range rProofPath { + if proof.key.Equal(sibling.key) { + sibling = rProofPath[i+1] + IsProof = true + break + } + } + } + + err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling, IsProof) if err != nil { return nil, err } @@ -330,6 +420,17 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { } } +// Put updates the corresponding `value` for a `key` +func (t *Trie) PutInner(key *Key, node *Node) (*felt.Felt, error) { + if err := t.storage.Put(key, node); err != nil { + return nil, err + } + if t.rootKey == nil { + t.setRootKey(key) + } + return &felt.Zero, nil +} + func (t *Trie) setRootKey(newRootKey *Key) { t.rootKey = newRootKey t.rootKeyIsDirty = true @@ -379,7 +480,6 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { rightPath := path(node.Right, key) node.Value = t.hash(leftChild.Hash(&leftPath, t.hash), rightChild.Hash(&rightPath, t.hash)) - if err = t.storage.Put(key, node); err != nil { return nil, err } @@ -421,7 +521,7 @@ func (t *Trie) updateChildTriesConcurrently(root *Node) (*Node, *Node, error) { } // deleteLast deletes the last node in the given list -func (t *Trie) deleteLast(nodes []storageNode) error { +func (t *Trie) deleteLast(nodes []StorageNode) error { last := nodes[len(nodes)-1] if err := t.storage.Delete(last.key); err != nil { return err diff --git a/core/trie/trie_test.go b/core/trie/trie_test.go index 1fc03fbd09..c51ba74d60 100644 --- a/core/trie/trie_test.go +++ b/core/trie/trie_test.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebble" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -82,6 +83,37 @@ func TestTriePut(t *testing.T) { }) } +func TestTriePutInner(t *testing.T) { + t.Run("put node to empty trie", func(t *testing.T) { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + + keyFeltBytes := new(felt.Felt).SetUint64(123).Bytes() + key := trie.NewKey(10, keyFeltBytes[:]) + + leftKeyFeltBytes := new(felt.Felt).SetUint64(789).Bytes() + leftKey := trie.NewKey(11, leftKeyFeltBytes[:]) + rightKeyFeltBytes := new(felt.Felt).SetUint64(135).Bytes() + rightKey := trie.NewKey(11, rightKeyFeltBytes[:]) + node := trie.Node{ + Value: new(felt.Felt).SetUint64(456), + Left: &leftKey, + Right: &rightKey, + } + + _, err = tempTrie.PutInner(&key, &node) + require.NoError(t, err) + + gotNode, err := tempTrie.GetNodeFromKey(&key) + require.NoError(t, err) + require.Equal(t, node, *gotNode) + }) +} + func TestTrieDeleteBasic(t *testing.T) { // left branch leftKeyNum, err := strconv.ParseUint("100", 2, 64)