diff --git a/core/trie/key.go b/core/trie/key.go index db60754fb7..28a56d882a 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -153,3 +153,26 @@ func (k *Key) RemoveLastBit() { inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount } } + +// CmpAligned is Cmp as if the value is bigendian bytes of key of the same length +func (k Key) CmpAligned(other *Key) int { + // No its not aligned, so need to convert to bigint then left shift it so that the MSB is of the same index + height := k.len + if other.len > height { + height = other.len + } + + b1i := k.alignedBitInt(height) + b2i := other.alignedBitInt(height) + return b1i.Cmp(b2i) +} + +func (k Key) alignedBitInt(height uint8) *big.Int { + theint := &big.Int{} + theint = theint.SetBytes(k.bitset[:]) + if k.len < height { + theint = theint.Lsh(theint, uint(height-k.len)) + } + + return theint +} diff --git a/core/trie/key_test.go b/core/trie/key_test.go index 8d56a31e0c..0f1b02c6b5 100644 --- a/core/trie/key_test.go +++ b/core/trie/key_test.go @@ -2,6 +2,7 @@ package trie_test import ( "bytes" + "fmt" "testing" "github.com/NethermindEth/juno/core/felt" @@ -153,3 +154,67 @@ func TestTruncate(t *testing.T) { }) } } + +func Test_cmp(t *testing.T) { + tests := []struct { + n1 int + n2 int + isHigher bool + }{ + { + n1: 10, + n2: 0, + isHigher: true, + }, + { + n1: 5, + n2: 0, + isHigher: true, + }, + { + n1: 5, + n2: 4, + isHigher: true, + }, + { + n1: 5, + n2: 5, + isHigher: false, + }, + { + n1: 4, + n2: 5, + isHigher: false, + }, + { + n1: 0, + n2: 5, + isHigher: false, + }, + { + n1: 300, + n2: 1, + isHigher: true, + }, + { + n1: 1, + n2: 300, + isHigher: false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%d %d %v", test.n1, test.n2, test.isHigher), func(t *testing.T) { + k1 := numToKey(test.n1) + k2 := numToKey(test.n2) + + assert.Equal(t, + k1.CmpAligned(&k2) > 0, + test.isHigher) + }) + } +} + +func numToKey(num int) trie.Key { + return trie.NewKey(8, []byte{byte(num)}) +} diff --git a/core/trie/trie.go b/core/trie/trie.go index 2496b724d4..bc28ded416 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -556,3 +556,46 @@ func (t *Trie) dump(level int, parentP *Key) { storage: t.storage, }).dump(level+1, t.rootKey) } + +// Iterate the trie from startValue in ascending order until the consumer returned false or an error occur or end of +// trie was reached. Return true if end of trie is reached. +// TODO: its much more efficient to iterate from the txn level. But even without that, if the leaf are ordered correctly, +// block cache should have a pretty good hit rate. +func (t *Trie) Iterate(startValue *felt.Felt, consumer func(key, value *felt.Felt) (bool, error)) (bool, error) { + startKey := t.feltToKey(startValue) + + return t.doIterate(&startKey, t.rootKey, consumer) +} + +func (t *Trie) doIterate(startKey, key *Key, consumer func(key, value *felt.Felt) (bool, error)) (bool, error) { + if key == nil { + return false, nil + } + + node, err := t.storage.Get(key) + if err != nil { + return false, err + } + + if key.Len() == t.height { + if startKey.CmpAligned(key) > 0 { + return true, nil + } + keyAsFelt := key.Felt() + return consumer(&keyAsFelt, node.Value) + } + + // If the startKey is higher than the right node, no point in going to left at all + if startKey.CmpAligned(node.Right) < 0 { + next, err := t.doIterate(startKey, node.Left, consumer) + if err != nil { + return false, err + } + + if !next { + return false, nil + } + } + + return t.doIterate(startKey, node.Right, consumer) +} diff --git a/core/trie/trie_test.go b/core/trie/trie_test.go index 1fc03fbd09..59ecb70c71 100644 --- a/core/trie/trie_test.go +++ b/core/trie/trie_test.go @@ -1,6 +1,7 @@ package trie_test import ( + "math/big" "strconv" "testing" @@ -375,3 +376,118 @@ func BenchmarkTriePut(b *testing.B) { return t.Commit() })) } + +func numToFelt(num int) *felt.Felt { + return numToFeltBigInt(big.NewInt(int64(num))) +} + +func numToFeltBigInt(num *big.Int) *felt.Felt { + f := felt.Zero + return f.SetBigInt(num) +} + +func TestTrie_Iterate(t *testing.T) { + tr, err := trie.NewTriePedersen(trie.NewStorage(db.NewMemTransaction(), []byte{1}), 251) + assert.Nil(t, err) + + for i := 0; i < 10; i++ { + _, err = tr.Put(numToFelt(i*10), numToFelt(i+10)) + assert.Nil(t, err) + } + err = tr.Commit() + assert.Nil(t, err) + + tests := []struct { + name string + startKey *felt.Felt + count int + expectedKeys []*felt.Felt + expectedValues []*felt.Felt + }{ + { + name: "all", + startKey: numToFelt(0), + count: 10, + expectedKeys: []*felt.Felt{ + numToFelt(0), + numToFelt(10), + numToFelt(20), + numToFelt(30), + numToFelt(40), + numToFelt(50), + numToFelt(60), + numToFelt(70), + numToFelt(80), + numToFelt(90), + }, + expectedValues: []*felt.Felt{ + numToFelt(10), + numToFelt(11), + numToFelt(12), + numToFelt(13), + numToFelt(14), + numToFelt(15), + numToFelt(16), + numToFelt(17), + numToFelt(18), + numToFelt(19), + }, + }, + { + name: "limited", + startKey: numToFelt(0), + count: 2, + expectedKeys: []*felt.Felt{ + numToFelt(0), + numToFelt(10), + }, + expectedValues: []*felt.Felt{ + numToFelt(10), + numToFelt(11), + }, + }, + { + name: "limited with offset", + startKey: numToFelt(30), + count: 2, + expectedKeys: []*felt.Felt{ + numToFelt(30), + numToFelt(40), + }, + expectedValues: []*felt.Felt{ + numToFelt(13), + numToFelt(14), + }, + }, + { + name: "limited with offset that does not match a leaf", + startKey: numToFelt(25), + count: 2, + expectedKeys: []*felt.Felt{ + numToFelt(30), + numToFelt(40), + }, + expectedValues: []*felt.Felt{ + numToFelt(13), + numToFelt(14), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + keys := make([]*felt.Felt, 0) + values := make([]*felt.Felt, 0) + + _, err := tr.Iterate(test.startKey, func(key *felt.Felt, value *felt.Felt) (bool, error) { + keys = append(keys, key) + values = append(values, value) + return len(keys) < test.count, nil + }) + assert.Nil(t, err) + + assert.Equal(t, test.expectedKeys, keys) + assert.Equal(t, test.expectedValues, values) + }) + } +}