Skip to content

Commit

Permalink
Refactor trie.Put to remove gocyclo lint error (#1601)
Browse files Browse the repository at this point in the history
  • Loading branch information
aminsato authored Apr 2, 2024
1 parent b6fb092 commit 3a0294e
Showing 1 changed file with 102 additions and 69 deletions.
171 changes: 102 additions & 69 deletions core/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,98 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) {
return &leafValue, nil
}

// check if we are updating an existing leaf, if yes avoid traversing the trie
func (t *Trie) updateLeaf(nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) {
// Check if we are updating an existing leaf
if !value.IsZero() {
if existingLeaf, err := t.storage.Get(&nodeKey); err == nil {
old := *existingLeaf.Value // record old value to return to caller
if err = t.storage.Put(&nodeKey, node); err != nil {
return nil, err
}
t.dirtyNodes = append(t.dirtyNodes, &nodeKey)
return &old, nil
} else if !errors.Is(err, db.ErrKeyNotFound) {
return nil, err
}
}
return nil, nil
}

func (t *Trie) handleEmptyTrie(old felt.Felt, nodeKey Key, node *Node, value *felt.Felt) (*felt.Felt, error) {
if value.IsZero() {
return nil, nil // no-op
}

if err := t.storage.Put(&nodeKey, node); err != nil {
return nil, err
}
t.setRootKey(&nodeKey)
return &old, nil
}

func (t *Trie) deleteExistingKey(sibling storageNode, nodeKey Key, nodes []storageNode) (*felt.Felt, error) {
if nodeKey.Equal(sibling.key) {
// we have to deference the Value, since the Node can released back
// to the NodePool and be reused anytime
old := *sibling.node.Value // record old value to return to caller
if err := t.deleteLast(nodes); err != nil {
return nil, err
}
return &old, nil
}
return nil, nil
}

func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent storageNode) {
if siblingParent.node.Left.Equal(key) {
*siblingParent.node.Left = commonKey
} else {
*siblingParent.node.Right = commonKey
}
}

func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []storageNode, sibling storageNode) error {
commonKey, _ := findCommonKey(nodeKey, sibling.key)

newParent := &Node{}
var leftChild, rightChild *Node

if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) {
newParent.Left, newParent.Right = sibling.key, nodeKey
leftChild, rightChild = sibling.node, node
} else {
newParent.Left, newParent.Right = nodeKey, sibling.key
leftChild, rightChild = node, sibling.node
}

leftPath := path(newParent.Left, &commonKey)
rightPath := path(newParent.Right, &commonKey)

newParent.Value = t.hash(leftChild.Hash(&leftPath, t.hash), rightChild.Hash(&rightPath, t.hash))
if err := t.storage.Put(&commonKey, newParent); err != nil {
return err
}

if len(nodes) > 1 { // sibling has a parent
siblingParent := (nodes)[len(nodes)-2]

t.replaceLinkWithNewParent(sibling.key, commonKey, siblingParent)
if err := t.storage.Put(siblingParent.key, siblingParent.node); err != nil {
return err
}
t.dirtyNodes = append(t.dirtyNodes, &commonKey)
} else {
t.setRootKey(&commonKey)
}

if err := t.storage.Put(nodeKey, node); err != nil {
return err
}
return nil
}

// Put updates the corresponding `value` for a `key`
//
//nolint:gocyclo
func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) {
if key.Cmp(t.maxKey) > 0 {
return nil, fmt.Errorf("key %s exceeds trie height %d", key, t.height)
Expand All @@ -196,18 +285,10 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) {
Value: value,
}

// check if we are updating an existing leaf, if yes avoid traversing the trie
if !value.IsZero() {
if existingLeaf, err := t.storage.Get(&nodeKey); err == nil {
old = *existingLeaf.Value // record old value to return to caller
if err = t.storage.Put(&nodeKey, node); err != nil {
return nil, err
}
t.dirtyNodes = append(t.dirtyNodes, &nodeKey)
return &old, nil
} else if !errors.Is(err, db.ErrKeyNotFound) {
return nil, err
}
oldValue, err := t.updateLeaf(nodeKey, node, value)
// xor operation, because we don't want to return result if the error is nil and the old value is nil
if (err != nil) != (oldValue != nil) {
return oldValue, err
}

nodes, err := t.nodesFromRoot(&nodeKey)
Expand All @@ -222,70 +303,22 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) {

// empty trie, make new value root
if len(nodes) == 0 {
if value.IsZero() {
return nil, nil // no-op
}

if err = t.storage.Put(&nodeKey, node); err != nil {
return nil, err
}
t.setRootKey(&nodeKey)
return &old, nil
return t.handleEmptyTrie(old, nodeKey, node, value)
} else {
// Since we short-circuit in leaf updates, we will only end up here for deletions
// Delete if key already exist
sibling := nodes[len(nodes)-1]
if nodeKey.Equal(sibling.key) {
// we have to deference the Value, since the Node can released back
// to the NodePool and be reused anytime
old = *sibling.node.Value // record old value to return to caller
if err = t.deleteLast(nodes); err != nil {
return nil, err
}
return &old, nil
oldValue, err = t.deleteExistingKey(sibling, nodeKey, nodes)
// xor operation, because we don't want to return if the error is nil and the old value is nil
if (err != nil) != (oldValue != nil) {
return oldValue, err
} else if value.IsZero() {
// trying to insert 0 to a key that does not exist
return nil, nil // no-op
}

commonKey, _ := findCommonKey(&nodeKey, sibling.key)
newParent := &Node{}
var leftChild, rightChild *Node
if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) {
newParent.Left, newParent.Right = sibling.key, &nodeKey
leftChild, rightChild = sibling.node, node
} else {
newParent.Left, newParent.Right = &nodeKey, sibling.key
leftChild, rightChild = node, sibling.node
}

leftPath := path(newParent.Left, &commonKey)
rightPath := path(newParent.Right, &commonKey)

newParent.Value = t.hash(leftChild.Hash(&leftPath, t.hash), rightChild.Hash(&rightPath, t.hash))
if err = t.storage.Put(&commonKey, newParent); err != nil {
return nil, err
}

if len(nodes) > 1 { // sibling has a parent
siblingParent := nodes[len(nodes)-2]

// replace the link to our sibling with the new parent
if siblingParent.node.Left.Equal(sibling.key) {
*siblingParent.node.Left = commonKey
} else {
*siblingParent.node.Right = commonKey
}

if err = t.storage.Put(siblingParent.key, siblingParent.node); err != nil {
return nil, err
}
t.dirtyNodes = append(t.dirtyNodes, &commonKey)
} else {
t.setRootKey(&commonKey)
}

if err = t.storage.Put(&nodeKey, node); err != nil {
err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling)
if err != nil {
return nil, err
}
return &old, nil
Expand Down

0 comments on commit 3a0294e

Please sign in to comment.