diff --git a/core/trie/node.go b/core/trie/node.go index 7f4fb746e8..f2a5d92488 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -9,9 +9,11 @@ import ( // A Node represents a node in the [Trie] type Node struct { - Value *felt.Felt - Left *Key - Right *Key + Value *felt.Felt + Left *Key + Right *Key + LeftHash *felt.Felt + RightHash *felt.Felt } // Hash calculates the hash of a [Node] @@ -30,6 +32,12 @@ func (n *Node) Hash(path *Key, hashFunc hashFunc) *felt.Felt { return hash.Add(hash, &pathFelt) } +// Hash calculates the hash of a [Node] +func (n *Node) HashFromParent(parnetKey, nodeKey *Key, hashFunc hashFunc) *felt.Felt { + path := path(nodeKey, parnetKey) + return n.Hash(&path, hashFunc) +} + func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { if n.Value == nil { return 0, errors.New("cannot marshal node with nil value") @@ -45,18 +53,38 @@ func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) { } if n.Left != nil { - wrote, err := n.Left.WriteTo(buf) + wrote, errInner := n.Left.WriteTo(buf) totalBytes += wrote if err != nil { - return totalBytes, err + return totalBytes, errInner } - wrote, err = n.Right.WriteTo(buf) // n.Right is non-nil by design + wrote, errInner = n.Right.WriteTo(buf) // n.Right is non-nil by design totalBytes += wrote if err != nil { - return totalBytes, err + return totalBytes, errInner } } + if n.LeftHash == nil && n.RightHash == nil { + return totalBytes, nil + } else if (n.LeftHash != nil && n.RightHash == nil) || (n.LeftHash == nil && n.RightHash != nil) { + return totalBytes, errors.New("cannot store only one lefthash or righthash") + } + + leftHashB := n.LeftHash.Bytes() + wrote, err = buf.Write(leftHashB[:]) + totalBytes += int64(wrote) + if err != nil { + return totalBytes, err + } + + rightHashB := n.RightHash.Bytes() + wrote, err = buf.Write(rightHashB[:]) + totalBytes += int64(wrote) + if err != nil { + return totalBytes, err + } + return totalBytes, nil } @@ -74,6 +102,8 @@ func (n *Node) UnmarshalBinary(data []byte) error { if len(data) == 0 { n.Left = nil n.Right = nil + n.LeftHash = nil + n.RightHash = nil return nil } @@ -85,5 +115,26 @@ func (n *Node) UnmarshalBinary(data []byte) error { if err := n.Left.UnmarshalBinary(data); err != nil { return err } - return n.Right.UnmarshalBinary(data[n.Left.EncodedLen():]) + data = data[n.Left.EncodedLen():] + if err := n.Right.UnmarshalBinary(data); err != nil { + return err + } + data = data[n.Right.EncodedLen():] + + if n.LeftHash == nil { + n.LeftHash = new(felt.Felt) + } + if n.RightHash == nil { + n.RightHash = new(felt.Felt) + } + if len(data) == 0 { + return nil + } + if len(data) != 2*felt.Bytes { + return errors.New("the node does not contain both left and right hash") + } + n.LeftHash.SetBytes(data[:felt.Bytes]) + data = data[felt.Bytes:] + n.RightHash.SetBytes(data[:felt.Bytes]) + return nil } diff --git a/core/trie/proof.go b/core/trie/proof.go index 5da9fee7f8..6a9d7eb099 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -1,6 +1,7 @@ package trie import ( + "errors" "fmt" "github.com/NethermindEth/juno/core/felt" @@ -28,6 +29,13 @@ func (pn *ProofNode) Hash(hash hashFunc) *felt.Felt { } } +func (pn *ProofNode) Len() uint8 { + if pn.Binary != nil { + return 1 + } + return pn.Edge.Path.len +} + func (pn *ProofNode) PrettyPrint() { if pn.Binary != nil { fmt.Printf(" Binary:\n") @@ -38,7 +46,6 @@ func (pn *ProofNode) PrettyPrint() { fmt.Printf(" Edge:\n") fmt.Printf(" Child: %v\n", pn.Edge.Child) fmt.Printf(" Path: %v\n", pn.Edge.Path) - fmt.Printf(" Value: %v\n", pn.Edge.Value) } } @@ -48,12 +55,26 @@ type Binary struct { } type Edge struct { - Child *felt.Felt - Path *Key - Value *felt.Felt + Child *felt.Felt // child hash + Path *Key // path from parent to child +} + +func GetBoundaryProofs(leftBoundary, rightBoundary *Key, tri *Trie) ([2][]ProofNode, error) { + proofs := [2][]ProofNode{} + leftProof, err := GetProof(leftBoundary, tri) + if err != nil { + return proofs, err + } + rightProof, err := GetProof(rightBoundary, tri) + if err != nil { + return proofs, err + } + proofs[0] = leftProof + proofs[1] = rightProof + return proofs, nil } -func isEdge(parentKey *Key, sNode storageNode) bool { +func isEdge(parentKey *Key, sNode StorageNode) bool { sNodeLen := sNode.key.len if parentKey == nil { // Root return sNodeLen != 0 @@ -63,7 +84,7 @@ func isEdge(parentKey *Key, sNode storageNode) bool { // Note: we need to account for the fact that Junos Trie has nodes that are Binary AND Edge, // whereas the protocol requires nodes that are Binary XOR Edge -func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary, error) { +func transformNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { isEdgeBool := isEdge(parentKey, sNode) var edge *Edge @@ -87,7 +108,7 @@ func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary } rightHash := rNode.Value - if isEdge(sNode.key, storageNode{node: rNode, key: sNode.node.Right}) { + if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) { edgePath := path(sNode.node.Right, sNode.key) rEdge := ProofNode{Edge: &Edge{ Path: &edgePath, @@ -96,7 +117,7 @@ func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary rightHash = rEdge.Hash(tri.hash) } leftHash := lNode.Value - if isEdge(sNode.key, storageNode{node: lNode, key: sNode.node.Left}) { + if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) { edgePath := path(sNode.node.Left, sNode.key) lEdge := ProofNode{Edge: &Edge{ Path: &edgePath, @@ -113,6 +134,8 @@ func transformNode(tri *Trie, parentKey *Key, sNode storageNode) (*Edge, *Binary } // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L514 +// GetProof generates a set of proof nodes from the root to the leaf. +// The proof never contains the leaf node if it is set, as we already know it's hash. func GetProof(key *Key, tri *Trie) ([]ProofNode, error) { nodesFromRoot, err := tri.nodesFromRoot(key) if err != nil { @@ -147,11 +170,12 @@ func GetProof(key *Key, tri *Trie) ([]ProofNode, error) { // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool { expectedHash := root - remainingPath := key - for _, proofNode := range proofs { + remainingPath := NewKey(key.len, key.bitset[:]) + for i, proofNode := range proofs { if !proofNode.Hash(hash).Equal(expectedHash) { return false } + switch { case proofNode.Binary != nil: if remainingPath.Test(remainingPath.Len() - 1) { @@ -165,6 +189,14 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode if err != nil { return false } + + // Todo: + // If we are verifying the key doesn't exist, then we should + // update subKey to point in the other direction + if value == nil && i == len(proofs)-1 { + return true + } + if !proofNode.Edge.Path.Equal(subKey) { return false } @@ -172,5 +204,314 @@ func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode remainingPath.Truncate(251 - proofNode.Edge.Path.Len()) //nolint:gomnd } } + return expectedHash.Equal(value) } + +// VerifyRangeProof verifies the range proof for the given range of keys. +// This is achieved by constructing a trie from the boundary proofs, and the supplied key-values. +// If the root of the reconstructed trie matches the supplied root, then the verification passes. +// If the trie is constructed incorrectly then the root will have an incorrect key(len,path), and value, +// and therefore it's hash won't match the expected root. +// ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484 +func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*Key, proofValues [2]*felt.Felt, + proofs [2][]ProofNode, hash hashFunc, +) (bool, error) { + // Step 0: checks + if len(keys) != len(values) { + return false, fmt.Errorf("inconsistent proof data, number of keys: %d, number of values: %d", len(keys), len(values)) + } + + // Ensure all keys are monotonic increasing + if err := ensureMonotonicIncreasing(proofKeys, keys); err != nil { + return false, err + } + + // Ensure the inner values contain no deletions + for _, value := range values { + if value.Equal(&felt.Zero) { + return false, errors.New("range contains deletion") + } + } + + // Step 1: Verify proofs, and get proof paths + var proofPaths [2][]StorageNode + var err error + for i := 0; i < 2; i++ { + if proofs[i] != nil { + if !VerifyProof(root, proofKeys[i], proofValues[i], proofs[i], hash) { + return false, fmt.Errorf("invalid proof for key %x", proofKeys[i].String()) + } + + proofPaths[i], err = ProofToPath(proofs[i], proofKeys[i], hash) + if err != nil { + return false, err + } + } + } + + // Step 2: Build trie from proofPaths and keys + tmpTrie, err := BuildTrie(proofPaths[0], proofPaths[1], keys, values) + if err != nil { + return false, err + } + + // Verify that the recomputed root hash matches the provided root hash + recomputedRoot, err := tmpTrie.Root() + if err != nil { + return false, err + } + if !recomputedRoot.Equal(root) { + return false, errors.New("root hash mismatch") + } + + return true, nil +} + +func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error { + if proofKeys[0] != nil { + leftProofFelt := proofKeys[0].Felt() + if leftProofFelt.Cmp(keys[0]) >= 0 { + return errors.New("range is not monotonically increasing") + } + } + if proofKeys[1] != nil { + rightProofFelt := proofKeys[1].Felt() + if keys[len(keys)-1].Cmp(&rightProofFelt) >= 0 { + return errors.New("range is not monotonically increasing") + } + } + if len(keys) >= 2 { + for i := 0; i < len(keys)-1; i++ { + if keys[i].Cmp(keys[i+1]) >= 0 { + return errors.New("range is not monotonically increasing") + } + } + } + return nil +} + +// compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key +func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) { + parent := &proofNodes[idx] + + if idx == len(proofNodes)-1 { + if parent.Edge != nil { + return 1, parent.Len(), nil + } + return 0, parent.Len(), nil + } + + child := &proofNodes[idx+1] + + switch { + case parent.Edge != nil && child.Binary != nil: + return 1, parent.Edge.Path.len, nil + case parent.Binary != nil && child.Edge != nil: + childHash := child.Hash(hashF) + if parent.Binary.LeftHash.Equal(childHash) || parent.Binary.RightHash.Equal(childHash) { + return 1, child.Edge.Path.len, nil + } else { + return 0, 0, errors.New("can't determine the child hash from the parent and child") + } + } + + return 0, 1, nil +} + +func assignChild(i, compressedParent int, parentNode *Node, + nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF hashFunc, +) (*Key, error) { + childInd := i + compressedParent + 1 + childKey, err := getChildKey(childInd, parentKey, leafKey, nilKey, proofNodes, hashF) + if err != nil { + return nil, err + } + if leafKey.Test(leafKey.len - parentKey.len - 1) { + parentNode.Right = childKey + parentNode.Left = nilKey + } else { + parentNode.Right = nilKey + parentNode.Left = childKey + } + return childKey, nil +} + +// ProofToPath returns a set of storage nodes from the root to the end of the proof path. +// The storage nodes will have the hashes of the children, but only the key of the child +// along the path outlined by the proof. +func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]StorageNode, error) { + pathNodes := []StorageNode{} + + // Child keys that can't be derived are set to nilKey, so that we can store the node + zeroFeltBytes := new(felt.Felt).Bytes() + nilKey := NewKey(0, zeroFeltBytes[:]) + + for i, pNode := range proofNodes { + // Keep moving along the path (may need to skip nodes that were compressed into the last path node) + if i != 0 { + if skipNode(pNode, pathNodes, hashF) { + continue + } + } + + var parentKey *Key + parentNode := Node{} + + // Set the key of the current node + compressParent, compressParentOffset, err := compressNode(i, proofNodes, hashF) + if err != nil { + return nil, err + } + parentKey, err = getParentKey(i, compressParentOffset, leafKey, pNode, pathNodes, proofNodes) + if err != nil { + return nil, err + } + + // Don't store leafs along proof paths + if parentKey.len == 251 { //nolint:gomnd + break + } + + // Set the value of the current node + parentNode.Value = pNode.Hash(hashF) + + // Set the child key of the current node. + childKey, err := assignChild(i, compressParent, &parentNode, &nilKey, leafKey, parentKey, proofNodes, hashF) + if err != nil { + return nil, err + } + + // Set the LeftHash and RightHash values + parentNode.LeftHash, parentNode.RightHash, err = getLeftRightHash(i, proofNodes) + if err != nil { + return nil, err + } + pathNodes = append(pathNodes, StorageNode{key: parentKey, node: &parentNode}) + + // break early since we don't store leafs along proof paths, or if no more nodes exist along the proof paths + if childKey.len == 0 || childKey.len == 251 { + break + } + } + return pathNodes, nil +} + +func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF hashFunc) bool { + lastNode := pathNodes[len(pathNodes)-1].node + noLeftMatch, noRightMatch := false, false + if lastNode.LeftHash != nil && !pNode.Hash(hashF).Equal(lastNode.LeftHash) { + noLeftMatch = true + } + if lastNode.RightHash != nil && !pNode.Hash(hashF).Equal(lastNode.RightHash) { + noRightMatch = true + } + if noLeftMatch && noRightMatch { + return true + } + return false +} + +func getLeftRightHash(parentInd int, proofNodes []ProofNode) (*felt.Felt, *felt.Felt, error) { + parent := &proofNodes[parentInd] + if parent.Binary == nil { + if parentInd+1 > len(proofNodes)-1 { + return nil, nil, errors.New("cant get hash of children from proof node, out of range") + } + parent = &proofNodes[parentInd+1] + } + return parent.Binary.LeftHash, parent.Binary.RightHash, nil +} + +func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key, + pNode ProofNode, pathNodes []StorageNode, proofNodes []ProofNode, +) (*Key, error) { + var crntKey *Key + var err error + + var height uint8 + if len(pathNodes) > 0 { + if proofNodes[idx].Edge != nil { + height = pathNodes[len(pathNodes)-1].key.len + proofNodes[idx].Edge.Path.len + } else { + height = pathNodes[len(pathNodes)-1].key.len + 1 + } + } else { + height = 0 + } + + if pNode.Binary != nil { + crntKey, err = leafKey.SubKey(height) + } else { + crntKey, err = leafKey.SubKey(height + compressedParentOffset) + } + return crntKey, err +} + +func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF hashFunc) (*Key, error) { + if childIdx > len(proofNodes)-1 { + return nilKey, nil + } + + compressChild, compressChildOffset, err := compressNode(childIdx, proofNodes, hashF) + if err != nil { + return nil, err + } + + if crntKey.len+uint8(compressChild)+compressChildOffset == 251 { //nolint:gomnd + return nilKey, nil + } + + return leafKey.SubKey(crntKey.len + uint8(compressChild) + compressChildOffset) +} + +// BuildTrie builds a trie using the proof paths (including inner nodes), and then sets all the keys-values (leaves) +func BuildTrie(leftProofPath, rightProofPath []StorageNode, keys, values []*felt.Felt) (*Trie, error) { //nolint:gocyclo + tempTrie, err := NewTriePedersen(newMemStorage(), 251) //nolint:gomnd + if err != nil { + return nil, err + } + + // merge proof paths + for i := range min(len(leftProofPath), len(rightProofPath)) { + // Can't store nil keys so stop merging + if leftProofPath[i].node.Left == nil || leftProofPath[i].node.Right == nil || + rightProofPath[i].node.Left == nil || rightProofPath[i].node.Right == nil { + break + } + if leftProofPath[i].key.Equal(rightProofPath[i].key) { + leftProofPath[i].node.Right = rightProofPath[i].node.Right + rightProofPath[i].node.Left = leftProofPath[i].node.Left + } else { + break + } + } + + for _, sNode := range leftProofPath { + if sNode.node.Left == nil || sNode.node.Right == nil { + break + } + _, err := tempTrie.PutInner(sNode.key, sNode.node) + if err != nil { + return nil, err + } + } + + for _, sNode := range rightProofPath { + if sNode.node.Left == nil || sNode.node.Right == nil { + break + } + _, err := tempTrie.PutInner(sNode.key, sNode.node) + if err != nil { + return nil, err + } + } + + for i := range len(keys) { + _, err := tempTrie.PutWithProof(keys[i], values[i], leftProofPath, rightProofPath) + if err != nil { + return nil, err + } + } + return tempTrie, nil +} diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 9356db86a4..a566ac2d02 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -1,7 +1,6 @@ package trie_test import ( - "fmt" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -14,9 +13,9 @@ import ( ) func buildSimpleTrie(t *testing.T) *trie.Trie { - // (250, 0, x1) + // (250, 0, x1) edge // | - // (0,0,x1) + // (0,0,x1) binary // / \ // (2) (3) // Build trie @@ -24,7 +23,7 @@ func buildSimpleTrie(t *testing.T) *trie.Trie { txn, err := memdb.NewTransaction(true) require.NoError(t, err) - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) require.NoError(t, err) // Update trie @@ -44,12 +43,18 @@ func buildSimpleTrie(t *testing.T) *trie.Trie { } func buildSimpleBinaryRootTrie(t *testing.T) *trie.Trie { + // PF // (0, 0, x) // / \ // (250, 0, cc) (250, 11111.., dd) // | | // (cc) (dd) - // Build trie + + // JUNO + // (0, 0, x) + // / \ + // (251, 0, cc) (251, 11111.., dd) + memdb := pebble.NewMemTest(t) txn, err := memdb.NewTransaction(true) require.NoError(t, err) @@ -135,8 +140,120 @@ func buildSimpleDoubleBinaryTrie(t *testing.T) (*trie.Trie, []trie.ProofNode) { return tempTrie, expectedProofNodes } +func build3KeyTrie(t *testing.T) *trie.Trie { + // Starknet + // -------- + // + // Edge + // | + // Binary with len 249 parent + // / \ + // Binary (250) Edge with len 250 + // / \ / + // 0x4 0x5 0x6 child + + // Juno + // ---- + // + // Node (path 249) + // / \ + // Node (binary) \ + // / \ / + // 0x4 0x5 0x6 + + // Build trie + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + + // Update trie + key1 := new(felt.Felt).SetUint64(0) + key2 := new(felt.Felt).SetUint64(1) + key3 := new(felt.Felt).SetUint64(2) + value1 := new(felt.Felt).SetUint64(4) + value2 := new(felt.Felt).SetUint64(5) + value3 := new(felt.Felt).SetUint64(6) + + _, err = tempTrie.Put(key1, value1) + require.NoError(t, err) + + _, err = tempTrie.Put(key3, value3) + require.NoError(t, err) + _, err = tempTrie.Put(key2, value2) + require.NoError(t, err) + + require.NoError(t, tempTrie.Commit()) + return tempTrie +} + +func build4KeyTrie(t *testing.T) *trie.Trie { + // Juno + // 248 + // / \ + // 249 \ + // / \ \ + // 250 \ \ + // / \ /\ /\ + // 0 1 2 4 + + // Juno - should be able to reconstruct this from proofs + // 248 + // / \ + // 249 // Note we cant derive the right key, but need to store it's hash + // / \ + // 250 \ + // / \ / (Left hash set, no key) + // 0 + + // Pathfinder (???) + // 0 Edge + // | + // 248 Binary + // / \ + // 249 \ Binary Edge ?? + // / \ \ + // 250 250 250 Binary Edge ?? + // / \ / / + // 0 1 2 4 + + // Build trie + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + + // Update trie + key1 := new(felt.Felt).SetUint64(0) + key2 := new(felt.Felt).SetUint64(1) + key3 := new(felt.Felt).SetUint64(2) + key5 := new(felt.Felt).SetUint64(4) + value1 := new(felt.Felt).SetUint64(4) + value2 := new(felt.Felt).SetUint64(5) + value3 := new(felt.Felt).SetUint64(6) + value5 := new(felt.Felt).SetUint64(7) + + _, err = tempTrie.Put(key1, value1) + require.NoError(t, err) + + _, err = tempTrie.Put(key3, value3) + require.NoError(t, err) + _, err = tempTrie.Put(key2, value2) + require.NoError(t, err) + _, err = tempTrie.Put(key5, value5) + require.NoError(t, err) + + require.NoError(t, tempTrie.Commit()) + + return tempTrie +} + func TestGetProof(t *testing.T) { - t.Run("Simple Trie - simple binary", func(t *testing.T) { + t.Run("GP Simple Trie - simple binary", func(t *testing.T) { tempTrie := buildSimpleTrie(t) zero := trie.NewKey(250, []byte{0}) @@ -160,13 +277,13 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) - t.Run("Simple Trie - simple double binary", func(t *testing.T) { + t.Run("GP Simple Trie - simple double binary", func(t *testing.T) { tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) expectedProofNodes[2] = trie.ProofNode{ @@ -182,13 +299,13 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) - t.Run("Simple Trie - simple double binary edge", func(t *testing.T) { + t.Run("GP Simple Trie - simple double binary edge", func(t *testing.T) { tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) leafFelt := new(felt.Felt).SetUint64(3).Bytes() leafKey := trie.NewKey(251, leafFelt[:]) @@ -196,13 +313,13 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) - t.Run("Simple Trie - simple binary root", func(t *testing.T) { + t.Run("GP Simple Trie - simple binary root", func(t *testing.T) { tempTrie := buildSimpleBinaryRootTrie(t) key1Bytes := new(felt.Felt).SetUint64(0).Bytes() @@ -228,13 +345,13 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) - t.Run("Simple Trie - left-right edge", func(t *testing.T) { + t.Run("GP Simple Trie - left-right edge", func(t *testing.T) { // (251,0xff,0xaa) // / // \ @@ -272,14 +389,13 @@ func TestGetProof(t *testing.T) { require.NoError(t, err) // Better inspection - for i, pNode := range proofNodes { - fmt.Println(i) - pNode.PrettyPrint() - } + // for _, pNode := range proofNodes { + // pNode.PrettyPrint() + // } require.Equal(t, expectedProofNodes, proofNodes) }) - t.Run("Simple Trie - proof for non-set key", func(t *testing.T) { + t.Run("GP Simple Trie - proof for non-set key", func(t *testing.T) { tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) leafFelt := new(felt.Felt).SetUint64(123).Bytes() // The (root) edge node would have a shorter len if this key was set @@ -294,7 +410,7 @@ func TestGetProof(t *testing.T) { require.Equal(t, expectedProofNodes[0:2], proofNodes) }) - t.Run("Simple Trie - proof for inner key", func(t *testing.T) { + t.Run("GP Simple Trie - proof for inner key", func(t *testing.T) { tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) innerFelt := new(felt.Felt).SetUint64(2).Bytes() @@ -309,7 +425,7 @@ func TestGetProof(t *testing.T) { require.Equal(t, expectedProofNodes[0:2], proofNodes) }) - t.Run("Simple Trie - proof for non-set key, with leafs set to right and left", func(t *testing.T) { + t.Run("GP Simple Trie - proof for non-set key, with leafs set to right and left", func(t *testing.T) { tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) leafFelt := new(felt.Felt).SetUint64(2).Bytes() @@ -327,7 +443,7 @@ func TestGetProof(t *testing.T) { func TestVerifyProof(t *testing.T) { // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2137 - t.Run("Simple binary trie", func(t *testing.T) { + t.Run("VP Simple binary trie", func(t *testing.T) { tempTrie := buildSimpleTrie(t) zero := trie.NewKey(250, []byte{0}) expectedProofNodes := []trie.ProofNode{ @@ -347,14 +463,15 @@ func TestVerifyProof(t *testing.T) { root, err := tempTrie.Root() require.NoError(t, err) - key1Bytes := new(felt.Felt).SetUint64(0).Bytes() - key1 := trie.NewKey(251, key1Bytes[:]) val1 := new(felt.Felt).SetUint64(2) - assert.True(t, trie.VerifyProof(root, &key1, val1, expectedProofNodes, crypto.Pedersen)) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) }) // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2167 - t.Run("Simple double binary trie", func(t *testing.T) { + t.Run("VP Simple double binary trie", func(t *testing.T) { tempTrie, _ := buildSimpleDoubleBinaryTrie(t) zero := trie.NewKey(249, []byte{0}) expectedProofNodes := []trie.ProofNode{ @@ -380,13 +497,52 @@ func TestVerifyProof(t *testing.T) { root, err := tempTrie.Root() require.NoError(t, err) - key1Bytes := new(felt.Felt).SetUint64(0).Bytes() - key1 := trie.NewKey(251, key1Bytes[:]) val1 := new(felt.Felt).SetUint64(2) - require.True(t, trie.VerifyProof(root, &key1, val1, expectedProofNodes, crypto.Pedersen)) + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) + }) + + t.Run("VP three key trie", func(t *testing.T) { + tempTrie := build3KeyTrie(t) + zero := trie.NewKey(249, []byte{0}) + felt2 := new(felt.Felt).SetUint64(0).Bytes() + lastPath := trie.NewKey(1, felt2[:]) + expectedProofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x0768DEB8D0795D80AAAC2E5E326141F33044759F97A1BF092D8EB9C4E4BE9234"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x057166F9476D0A2D6875124251841EB85A9AE37462FAE3CBF7304BCD593938E7"), + RightHash: utils.HexToFelt(t, "0x060FBDE29F96F706498EFD132DC7F312A4C99A9AE051BF152C2AF2B3CAF31E5B"), + }, + }, + { + Edge: &trie.Edge{ + Path: &lastPath, + Child: utils.HexToFelt(t, "0x6"), + }, + }, + } + + root, err := tempTrie.Root() + require.NoError(t, err) + val6 := new(felt.Felt).SetUint64(6) + + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + leafkey := trie.NewKey(251, twoFeltBytes[:]) + gotProof, err := trie.GetProof(&leafkey, tempTrie) + require.NoError(t, err) + require.Equal(t, expectedProofNodes, gotProof) + + assert.True(t, trie.VerifyProof(root, &leafkey, val6, expectedProofNodes, crypto.Pedersen)) }) - t.Run("non existent key - less than root edge", func(t *testing.T) { + t.Run("VP non existent key - less than root edge", func(t *testing.T) { tempTrie, _ := buildSimpleDoubleBinaryTrie(t) nonExistentKey := trie.NewKey(123, []byte{0}) // Diverges before the root node (len root node = 249) @@ -400,7 +556,7 @@ func TestVerifyProof(t *testing.T) { require.False(t, trie.VerifyProof(root, &nonExistentKey, nonExistentKeyValue, proofNodes, crypto.Pedersen)) }) - t.Run("non existent leaf key", func(t *testing.T) { + t.Run("VP non existent leaf key", func(t *testing.T) { tempTrie, _ := buildSimpleDoubleBinaryTrie(t) nonExistentKeyByte := new(felt.Felt).SetUint64(2).Bytes() // Key not set @@ -415,3 +571,301 @@ func TestVerifyProof(t *testing.T) { require.False(t, trie.VerifyProof(root, &nonExistentKey, nonExistentKeyValue, proofNodes, crypto.Pedersen)) }) } + +func TestProofToPath(t *testing.T) { + t.Run("PTP Proof To Path Simple binary trie proof to path", func(t *testing.T) { + tempTrie := buildSimpleTrie(t) + zeroFeltByte := new(felt.Felt).Bytes() + zero := trie.NewKey(250, zeroFeltByte[:]) + leafValue := utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002") + siblingValue := utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003") + proofNodes := []trie.ProofNode{ + { + Edge: &trie.Edge{ + Path: &zero, + Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), + }, + }, + { + Binary: &trie.Binary{ + LeftHash: leafValue, + RightHash: siblingValue, + }, + }, + } + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) + require.NoError(t, err) + + rootKey := tempTrie.RootKey() + + require.Equal(t, 1, len(sns)) + require.Equal(t, rootKey.Len(), sns[0].Key().Len()) + require.Equal(t, leafValue.String(), sns[0].Node().LeftHash.String()) + require.Equal(t, siblingValue.String(), sns[0].Node().RightHash.String()) + }) + + t.Run("PTP Simple double binary trie proof to path", func(t *testing.T) { + tempTrie := buildSimpleBinaryRootTrie(t) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + leafkey := trie.NewKey(251, zeroFeltBytes[:]) + path1 := trie.NewKey(250, zeroFeltBytes[:]) + proofNodes := []trie.ProofNode{ + { + Binary: &trie.Binary{ + LeftHash: utils.HexToFelt(t, "0x06E08BF82793229338CE60B65D1845F836C8E2FBFE2BC59FF24AEDBD8BA219C4"), + RightHash: utils.HexToFelt(t, "0x04F9B8E66212FB528C0C1BD02F43309C53B895AA7D9DC91180001BDD28A588FA"), + }, + }, + { + Edge: &trie.Edge{ + Path: &path1, + Child: utils.HexToFelt(t, "0xcc"), + }, + }, + } + + siblingValue := utils.HexToFelt(t, "0xdd") + sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) + require.NoError(t, err) + rootKey := tempTrie.RootKey() + rootNode, err := tempTrie.GetNodeFromKey(rootKey) + require.NoError(t, err) + leftNode, err := tempTrie.GetNodeFromKey(rootNode.Left) + require.NoError(t, err) + require.Equal(t, 1, len(sns)) + require.Equal(t, rootKey.Len(), sns[0].Key().Len()) + require.Equal(t, leftNode.HashFromParent(rootKey, rootNode.Left, crypto.Pedersen).String(), sns[0].Node().LeftHash.String()) + require.NotEqual(t, siblingValue.String(), sns[0].Node().RightHash.String()) + }) + + t.Run("PTP boundary proofs with three key trie", func(t *testing.T) { + tri := build3KeyTrie(t) + rootKey := tri.RootKey() + rootNode, err := tri.GetNodeFromKey(rootKey) + require.NoError(t, err) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + zeroLeafValue := new(felt.Felt).SetUint64(4) + oneLeafValue := new(felt.Felt).SetUint64(5) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri) + require.NoError(t, err) + + // Test 1 + leftProofPath, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen) + require.Equal(t, 2, len(leftProofPath)) + require.NoError(t, err) + left, err := tri.GetNodeFromKey(rootNode.Left) + require.NoError(t, err) + right, err := tri.GetNodeFromKey(rootNode.Right) + require.NoError(t, err) + require.Equal(t, rootKey, leftProofPath[0].Key()) + require.Equal(t, left.HashFromParent(rootKey, rootNode.Left, crypto.Pedersen).String(), leftProofPath[0].Node().LeftHash.String()) + require.Equal(t, right.HashFromParent(rootKey, rootNode.Right, crypto.Pedersen).String(), leftProofPath[0].Node().RightHash.String()) + require.Equal(t, rootNode.Left, leftProofPath[1].Key()) + require.Equal(t, zeroLeafValue.String(), leftProofPath[1].Node().LeftHash.String()) + require.Equal(t, oneLeafValue.String(), leftProofPath[1].Node().RightHash.String()) + + // Test 2 + rightProofPath, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen) + require.Equal(t, 1, len(rightProofPath)) + require.NoError(t, err) + require.Equal(t, rootKey, rightProofPath[0].Key()) + require.NotEqual(t, rootNode.Right, rightProofPath[0].Node().Right) + require.NotEqual(t, uint8(0), rightProofPath[0].Node().Right) + require.Equal(t, right.HashFromParent(rootKey, rootNode.Right, crypto.Pedersen).String(), rightProofPath[0].Node().RightHash.String()) + }) +} + +func TestBuildTrie(t *testing.T) { + t.Run("Simple binary trie proof to path", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + + tri := build3KeyTrie(t) + rootKey := tri.RootKey() + rootCommitment, err := tri.Root() + require.NoError(t, err) + rootNode, err := tri.GetNodeFromKey(rootKey) + require.NoError(t, err) + leftNode, err := tri.GetNodeFromKey(rootNode.Left) + require.NoError(t, err) + leftleftNode, err := tri.GetNodeFromKey(leftNode.Left) + require.NoError(t, err) + leftrightNode, err := tri.GetNodeFromKey(leftNode.Right) + require.NoError(t, err) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri) + require.NoError(t, err) + + leftProof, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen) + require.NoError(t, err) + + rightProof, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen) + require.NoError(t, err) + + keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} + values := []*felt.Felt{new(felt.Felt).SetUint64(5)} + builtTrie, err := trie.BuildTrie(leftProof, rightProof, keys, values) + require.NoError(t, err) + + builtRootKey := builtTrie.RootKey() + builtRootNode, err := builtTrie.GetNodeFromKey(builtRootKey) + require.NoError(t, err) + builtLeftNode, err := builtTrie.GetNodeFromKey(builtRootNode.Left) + require.NoError(t, err) + builtLeftRightNode, err := builtTrie.GetNodeFromKey(builtLeftNode.Right) + require.NoError(t, err) + + // Assert the structure / keys correct + require.Equal(t, rootKey, builtRootKey) + require.Equal(t, rootNode.Left, builtRootNode.Left, "left fail") + require.Equal(t, leftrightNode.Right, builtLeftRightNode.Right, "right fail") + require.Equal(t, uint8(0), builtRootNode.Right.Len(), "right fail") + require.Equal(t, uint8(0), builtLeftNode.Left.Len(), "left left fail") + + // Assert the leaf nodes have the correct values + require.Equal(t, leftleftNode.Value.String(), builtLeftNode.LeftHash.String(), "should be 0x4") + require.Equal(t, leftrightNode.Value.String(), builtLeftRightNode.Value.String(), "should be 0x5") + + // Given the above two asserts pass, we should be able to reconstruct the correct commitment + reconstructedRootCommitment, err := builtTrie.Root() + require.NoError(t, err) + require.Equal(t, rootCommitment.String(), reconstructedRootCommitment.String(), "root commitment not equal") + }) +} + +func TestVerifyRangeProof(t *testing.T) { + t.Run("VPR two proofs, single key trie", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} + values := []*felt.Felt{new(felt.Felt).SetUint64(5)} + proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(6)} + rootCommitment, err := tri.Root() + require.NoError(t, err) + proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) + + t.Run("VPR all keys provided, no proofs needed", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} + values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} + proofKeys := [2]*trie.Key{} + proofValues := [2]*felt.Felt{} + proofs := [2][]trie.ProofNode{} + rootCommitment, err := tri.Root() + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) + + t.Run("VPR left proof, all right keys", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} + values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} + proofKeys := [2]*trie.Key{&zeroLeafkey} + proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4)} + leftProof, err := trie.GetProof(proofKeys[0], tri) + require.NoError(t, err) + proofs := [2][]trie.ProofNode{leftProof} + rootCommitment, err := tri.Root() + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) + + t.Run("VPR right proof, all left keys", func(t *testing.T) { + // Node (edge path 249) + // / \ + // Node (binary) 0x6 (leaf) + // / \ + // 0x4 0x5 (leaf, leaf) + twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() + twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + + tri := build3KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1)} + values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5)} + proofKeys := [2]*trie.Key{nil, &twoLeafkey} + proofValues := [2]*felt.Felt{nil, new(felt.Felt).SetUint64(6)} + rightProof, err := trie.GetProof(proofKeys[1], tri) + require.NoError(t, err) + proofs := [2][]trie.ProofNode{nil, rightProof} + rootCommitment, err := tri.Root() + require.NoError(t, err) + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) + + t.Run("VPR left proof, all inner keys, right proof with non-set key", func(t *testing.T) { + zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + + threeFeltBytes := new(felt.Felt).SetUint64(3).Bytes() + threeLeafkey := trie.NewKey(251, threeFeltBytes[:]) + + tri := build4KeyTrie(t) + keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} + values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} + proofKeys := [2]*trie.Key{&zeroLeafkey, &threeLeafkey} + proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), nil} + leftProof, err := trie.GetProof(proofKeys[0], tri) + require.NoError(t, err) + rightProof, err := trie.GetProof(proofKeys[1], tri) + require.NoError(t, err) + + proofs := [2][]trie.ProofNode{leftProof, rightProof} + rootCommitment, err := tri.Root() + require.NoError(t, err) + + verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + require.NoError(t, err) + require.True(t, verif) + }) +} diff --git a/core/trie/trie.go b/core/trie/trie.go index 2496b724d4..4339f39cd1 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -130,24 +130,41 @@ func path(key, parentKey *Key) Key { // storageNode is the on-disk representation of a [Node], // where key is the storage key and node is the value. -type storageNode struct { +type StorageNode struct { key *Key node *Node } +func (sn *StorageNode) Key() *Key { + return sn.key +} + +func (sn *StorageNode) Node() *Node { + return sn.node +} + +func NewStorageNode(key *Key, node *Node) *StorageNode { + return &StorageNode{key: key, node: node} +} + // nodesFromRoot enumerates the set of [Node] objects that need to be traversed from the root // of the Trie to the node which is given by the key. // The [storageNode]s are returned in descending order beginning with the root. -func (t *Trie) nodesFromRoot(key *Key) ([]storageNode, error) { - var nodes []storageNode +func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { + var nodes []StorageNode cur := t.rootKey for cur != nil { + // proof nodes set "nil" nodes to zero + if len(nodes) > 0 && cur.len == 0 { + return nodes, nil + } + node, err := t.storage.Get(cur) if err != nil { return nil, err } - nodes = append(nodes, storageNode{ + nodes = append(nodes, StorageNode{ key: cur, node: node, }) @@ -217,7 +234,7 @@ func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *fe return &old, nil } -func (t *Trie) deleteExistingKey(sibling storageNode, nodeKey Key, nodes []storageNode) (*felt.Felt, error) { +func (t *Trie) deleteExistingKey(sibling StorageNode, nodeKey Key, nodes []StorageNode) (*felt.Felt, error) { if nodeKey.Equal(sibling.key) { // we have to deference the Value, since the Node can released back // to the NodePool and be reused anytime @@ -230,7 +247,7 @@ func (t *Trie) deleteExistingKey(sibling storageNode, nodeKey Key, nodes []stora return nil, nil } -func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent storageNode) { +func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent StorageNode) { if siblingParent.node.Left.Equal(key) { *siblingParent.node.Left = commonKey } else { @@ -238,38 +255,58 @@ func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent s } } -func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []storageNode, sibling storageNode) error { +func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { commonKey, _ := findCommonKey(nodeKey, sibling.key) newParent := &Node{} var leftChild, rightChild *Node + var err error - if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) { - newParent.Left, newParent.Right = sibling.key, nodeKey - leftChild, rightChild = sibling.node, node + // Update the (proof) parents child and hash + if siblingIsParentProof { + newParent, err = t.GetNodeFromKey(&commonKey) + if err != nil { + return nil + } + if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) { + newParent.Right = nodeKey + newParent.RightHash = node.Hash(nodeKey, t.hash) + } else { + newParent.Left = nodeKey + newParent.LeftHash = node.Hash(nodeKey, t.hash) + } + if err := t.storage.Put(&commonKey, newParent); err != nil { + return err + } + t.dirtyNodes = append(t.dirtyNodes, &commonKey) } else { - newParent.Left, newParent.Right = nodeKey, sibling.key - leftChild, rightChild = node, sibling.node - } + if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) { + newParent.Left, newParent.Right = sibling.key, nodeKey + leftChild, rightChild = sibling.node, node + } else { + newParent.Left, newParent.Right = nodeKey, sibling.key + leftChild, rightChild = node, sibling.node + } - leftPath := path(newParent.Left, &commonKey) - rightPath := path(newParent.Right, &commonKey) + leftPath := path(newParent.Left, &commonKey) + rightPath := path(newParent.Right, &commonKey) - newParent.Value = t.hash(leftChild.Hash(&leftPath, t.hash), rightChild.Hash(&rightPath, t.hash)) - if err := t.storage.Put(&commonKey, newParent); err != nil { - return err - } + newParent.Value = t.hash(leftChild.Hash(&leftPath, t.hash), rightChild.Hash(&rightPath, t.hash)) + if err := t.storage.Put(&commonKey, newParent); err != nil { + return err + } - if len(nodes) > 1 { // sibling has a parent - siblingParent := (nodes)[len(nodes)-2] + if len(nodes) > 1 { // sibling has a parent + siblingParent := (nodes)[len(nodes)-2] - t.replaceLinkWithNewParent(sibling.key, commonKey, siblingParent) - if err := t.storage.Put(siblingParent.key, siblingParent.node); err != nil { - return err + t.replaceLinkWithNewParent(sibling.key, commonKey, siblingParent) + if err := t.storage.Put(siblingParent.key, siblingParent.node); err != nil { + return err + } + t.dirtyNodes = append(t.dirtyNodes, &commonKey) + } else { + t.setRootKey(&commonKey) } - t.dirtyNodes = append(t.dirtyNodes, &commonKey) - } else { - t.setRootKey(&commonKey) } if err := t.storage.Put(nodeKey, node); err != nil { @@ -296,7 +333,7 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { return oldValue, err } - nodes, err := t.nodesFromRoot(&nodeKey) + nodes, err := t.nodesFromRoot(&nodeKey) // correct for key,value if err != nil { return nil, err } @@ -321,8 +358,79 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { // trying to insert 0 to a key that does not exist return nil, nil // no-op } + err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling, false) + if err != nil { + return nil, err + } + return &old, nil + } +} - err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling) +// Put updates the corresponding `value` for a `key` +func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []StorageNode) (*felt.Felt, error) { + if key.Cmp(t.maxKey) > 0 { + return nil, fmt.Errorf("key %s exceeds trie height %d", key, t.height) + } + + old := felt.Zero + nodeKey := t.feltToKey(key) + node := &Node{ + Value: value, + } + + oldValue, err := t.updateLeaf(nodeKey, node, value) + // xor operation, because we don't want to return result if the error is nil and the old value is nil + if (err != nil) != (oldValue != nil) { + return oldValue, err + } + + nodes, err := t.nodesFromRoot(&nodeKey) // correct for key,value + if err != nil { + return nil, err + } + defer func() { + for _, n := range nodes { + nodePool.Put(n.node) + } + }() + + // empty trie, make new value root + if len(nodes) == 0 { + return t.handleEmptyTrie(old, nodeKey, node, value) + } else { + // Since we short-circuit in leaf updates, we will only end up here for deletions + // Delete if key already exist + sibling := nodes[len(nodes)-1] + oldValue, err = t.deleteExistingKey(sibling, nodeKey, nodes) + // xor operation, because we don't want to return if the error is nil and the old value is nil + if (err != nil) != (oldValue != nil) { + return oldValue, err + } else if value.IsZero() { + // trying to insert 0 to a key that does not exist + return nil, nil // no-op + } + + // override the sibling to be the parent if it's a proof + parentIsProof, found := false, false + for _, proof := range lProofPath { + if proof.key.Equal(sibling.key) { + sibling = proof + parentIsProof = true + found = true + break + } + } + if !found { + for _, proof := range rProofPath { + if proof.key.Equal(sibling.key) { + sibling = proof + parentIsProof = true + break + } + } + } + + err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling, parentIsProof) if err != nil { return nil, err } @@ -330,12 +438,26 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { } } +// Put updates the corresponding `value` for a `key` +func (t *Trie) PutInner(key *Key, node *Node) (*felt.Felt, error) { + if err := t.storage.Put(key, node); err != nil { + return nil, err + } + if t.rootKey == nil { + t.setRootKey(key) + } + return &felt.Zero, nil +} + func (t *Trie) setRootKey(newRootKey *Key) { t.rootKey = newRootKey t.rootKeyIsDirty = true } -func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { +func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo + zeroFeltBytes := new(felt.Felt).Bytes() + nilKey := NewKey(0, zeroFeltBytes[:]) + node, err := t.storage.Get(key) if err != nil { return nil, err @@ -356,49 +478,76 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { } } + // Update inner proof nodes + if node.Left.Equal(&nilKey) && node.Right.Equal(&nilKey) { // leaf + shouldUpdate = false + } else if node.Left.Equal(&nilKey) || node.Right.Equal(&nilKey) { // inner + shouldUpdate = true + } if !shouldUpdate { return node, nil } + var leftIsProof, rightIsProof bool + var leftHash, rightHash *felt.Felt + if node.Left.Equal(&nilKey) { + leftIsProof = true + leftHash = node.LeftHash + } + if node.Right.Equal(&nilKey) { + rightIsProof = true + rightHash = node.RightHash + } + // To avoid over-extending, only use concurrent updates when we are not too // deep in to traversing the trie. const concurrencyMaxDepth = 8 // heuristically selected value var leftChild, rightChild *Node if key.len <= concurrencyMaxDepth { - leftChild, rightChild, err = t.updateChildTriesConcurrently(node) + leftChild, rightChild, err = t.updateChildTriesConcurrently(node, leftIsProof, rightIsProof) } else { - leftChild, rightChild, err = t.updateChildTriesSerially(node) + leftChild, rightChild, err = t.updateChildTriesSerially(node, leftIsProof, rightIsProof) } if err != nil { return nil, err } - defer nodePool.Put(leftChild) - defer nodePool.Put(rightChild) - - leftPath := path(node.Left, key) - rightPath := path(node.Right, key) - - node.Value = t.hash(leftChild.Hash(&leftPath, t.hash), rightChild.Hash(&rightPath, t.hash)) + if !leftIsProof { + leftPath := path(node.Left, key) + defer nodePool.Put(leftChild) + leftHash = leftChild.Hash(&leftPath, t.hash) + } + if !rightIsProof { + rightPath := path(node.Right, key) + defer nodePool.Put(rightChild) + rightHash = rightChild.Hash(&rightPath, t.hash) + } + node.Value = t.hash(leftHash, rightHash) if err = t.storage.Put(key, node); err != nil { return nil, err } return node, nil } -func (t *Trie) updateChildTriesSerially(root *Node) (*Node, *Node, error) { - leftChild, err := t.updateValueIfDirty(root.Left) - if err != nil { - return nil, nil, err +func (t *Trie) updateChildTriesSerially(root *Node, leftIsProof, rightIsProof bool) (*Node, *Node, error) { + var leftChild, rightChild *Node + var err error + if !leftIsProof { + leftChild, err = t.updateValueIfDirty(root.Left) + if err != nil { + return nil, nil, err + } } - rightChild, err := t.updateValueIfDirty(root.Right) - if err != nil { - return nil, nil, err + if !rightIsProof { + rightChild, err = t.updateValueIfDirty(root.Right) + if err != nil { + return nil, nil, err + } } return leftChild, rightChild, nil } -func (t *Trie) updateChildTriesConcurrently(root *Node) (*Node, *Node, error) { +func (t *Trie) updateChildTriesConcurrently(root *Node, leftIsProof, rightIsProof bool) (*Node, *Node, error) { var leftChild, rightChild *Node var lErr, rErr error @@ -406,9 +555,13 @@ func (t *Trie) updateChildTriesConcurrently(root *Node) (*Node, *Node, error) { wg.Add(1) go func() { defer wg.Done() - leftChild, lErr = t.updateValueIfDirty(root.Left) + if !leftIsProof { + leftChild, lErr = t.updateValueIfDirty(root.Left) + } }() - rightChild, rErr = t.updateValueIfDirty(root.Right) + if !rightIsProof { + rightChild, rErr = t.updateValueIfDirty(root.Right) + } wg.Wait() if lErr != nil { @@ -421,7 +574,7 @@ func (t *Trie) updateChildTriesConcurrently(root *Node) (*Node, *Node, error) { } // deleteLast deletes the last node in the given list -func (t *Trie) deleteLast(nodes []storageNode) error { +func (t *Trie) deleteLast(nodes []StorageNode) error { last := nodes[len(nodes)-1] if err := t.storage.Delete(last.key); err != nil { return err