From e5b7e080f80767c66a463e6bc5ec997670ec4b58 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 16 Jan 2025 16:12:41 +0800 Subject: [PATCH 1/6] Implement `starknet_getStorageProof` --- core/state.go | 19 + core/state_snapshot.go | 15 + core/trie/trie.go | 4 + mocks/mock_state.go | 46 +++ rpc/contract.go | 62 ---- rpc/contract_test.go | 174 --------- rpc/handlers.go | 39 +- rpc/init_test.go | 5 + rpc/storage.go | 358 ++++++++++++++++++ rpc/storage_test.go | 763 +++++++++++++++++++++++++++++++++++++++ sync/pending.go | 13 + utils/orderedset.go | 12 + utils/orderedset_test.go | 142 ++++++++ utils/slices.go | 27 +- utils/slices_test.go | 39 ++ 15 files changed, 1468 insertions(+), 250 deletions(-) delete mode 100644 rpc/contract.go delete mode 100644 rpc/contract_test.go create mode 100644 rpc/init_test.go create mode 100644 rpc/storage.go create mode 100644 rpc/storage_test.go create mode 100644 utils/orderedset_test.go diff --git a/core/state.go b/core/state.go index 1a583c482f..b49419cbd3 100644 --- a/core/state.go +++ b/core/state.go @@ -42,6 +42,10 @@ type StateReader interface { ContractNonce(addr *felt.Felt) (*felt.Felt, error) ContractStorage(addr, key *felt.Felt) (*felt.Felt, error) Class(classHash *felt.Felt) (*DeclaredClass, error) + + ClassTrie() (*trie.Trie, error) + ContractTrie() (*trie.Trie, error) + ContractStorageTrie(addr *felt.Felt) (*trie.Trie, error) } type State struct { @@ -124,6 +128,21 @@ func (s *State) Root() (*felt.Felt, error) { return crypto.PoseidonArray(stateVersion, storageRoot, classesRoot), nil } +func (s *State) ClassTrie() (*trie.Trie, error) { + // We don't need to call the closer function here because we are only reading the trie + tr, _, err := s.classesTrie() + return tr, err +} + +func (s *State) ContractTrie() (*trie.Trie, error) { + tr, _, err := s.storage() + return tr, err +} + +func (s *State) ContractStorageTrie(addr *felt.Felt) (*trie.Trie, error) { + return storage(addr, s.txn) +} + // storage returns a [core.Trie] that represents the Starknet global state in the given Txn context. func (s *State) storage() (*trie.Trie, func() error, error) { return s.globalTrie(db.StateTrie, trie.NewTriePedersen) diff --git a/core/state_snapshot.go b/core/state_snapshot.go index a7062d4dae..a0697ddfa6 100644 --- a/core/state_snapshot.go +++ b/core/state_snapshot.go @@ -4,9 +4,12 @@ import ( "errors" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" ) +var ErrHistoricalTrieNotSupported = errors.New("cannot support historical trie") + type stateSnapshot struct { blockNumber uint64 state StateHistoryReader @@ -87,3 +90,15 @@ func (s *stateSnapshot) Class(classHash *felt.Felt) (*DeclaredClass, error) { } return declaredClass, nil } + +func (s *stateSnapshot) ClassTrie() (*trie.Trie, error) { + return nil, ErrHistoricalTrieNotSupported +} + +func (s *stateSnapshot) ContractTrie() (*trie.Trie, error) { + return nil, ErrHistoricalTrieNotSupported +} + +func (s *stateSnapshot) ContractStorageTrie(addr *felt.Felt) (*trie.Trie, error) { + return nil, ErrHistoricalTrieNotSupported +} diff --git a/core/trie/trie.go b/core/trie/trie.go index 5f8a51d9c0..72450646f3 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -718,6 +718,10 @@ func (t *Trie) Dump() { t.dump(0, nil) } +func (t *Trie) HashFn() crypto.HashFn { + return t.hash +} + // Try to print a [Trie] in a somewhat human-readable form /* Todo: create more meaningful representation of trie. In the current format string, storage is being diff --git a/mocks/mock_state.go b/mocks/mock_state.go index 8994085984..2525dad396 100644 --- a/mocks/mock_state.go +++ b/mocks/mock_state.go @@ -14,6 +14,7 @@ import ( core "github.com/NethermindEth/juno/core" felt "github.com/NethermindEth/juno/core/felt" + trie "github.com/NethermindEth/juno/core/trie" gomock "go.uber.org/mock/gomock" ) @@ -55,6 +56,21 @@ func (mr *MockStateHistoryReaderMockRecorder) Class(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Class", reflect.TypeOf((*MockStateHistoryReader)(nil).Class), arg0) } +// ClassTrie mocks base method. +func (m *MockStateHistoryReader) ClassTrie() (*trie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClassTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ClassTrie indicates an expected call of ClassTrie. +func (mr *MockStateHistoryReaderMockRecorder) ClassTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClassTrie", reflect.TypeOf((*MockStateHistoryReader)(nil).ClassTrie)) +} + // ContractClassHash mocks base method. func (m *MockStateHistoryReader) ContractClassHash(arg0 *felt.Felt) (*felt.Felt, error) { m.ctrl.T.Helper() @@ -159,3 +175,33 @@ func (mr *MockStateHistoryReaderMockRecorder) ContractStorageAt(arg0, arg1, arg2 mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractStorageAt", reflect.TypeOf((*MockStateHistoryReader)(nil).ContractStorageAt), arg0, arg1, arg2) } + +// ContractStorageTrie mocks base method. +func (m *MockStateHistoryReader) ContractStorageTrie(arg0 *felt.Felt) (*trie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ContractStorageTrie", arg0) + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ContractStorageTrie indicates an expected call of ContractStorageTrie. +func (mr *MockStateHistoryReaderMockRecorder) ContractStorageTrie(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractStorageTrie", reflect.TypeOf((*MockStateHistoryReader)(nil).ContractStorageTrie), arg0) +} + +// ContractTrie mocks base method. +func (m *MockStateHistoryReader) ContractTrie() (*trie.Trie, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ContractTrie") + ret0, _ := ret[0].(*trie.Trie) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ContractTrie indicates an expected call of ContractTrie. +func (mr *MockStateHistoryReaderMockRecorder) ContractTrie() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ContractTrie", reflect.TypeOf((*MockStateHistoryReader)(nil).ContractTrie)) +} diff --git a/rpc/contract.go b/rpc/contract.go deleted file mode 100644 index a33f84399e..0000000000 --- a/rpc/contract.go +++ /dev/null @@ -1,62 +0,0 @@ -package rpc - -import ( - "errors" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/jsonrpc" -) - -/**************************************************** - Contract Handlers -*****************************************************/ - -// Nonce returns the nonce associated with the given address in the given block number -// -// It follows the specification defined here: -// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L633 -func (h *Handler) Nonce(id BlockID, address felt.Felt) (*felt.Felt, *jsonrpc.Error) { - stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) - if rpcErr != nil { - return nil, rpcErr - } - defer h.callAndLogErr(stateCloser, "Error closing state reader in getNonce") - - nonce, err := stateReader.ContractNonce(&address) - if err != nil { - return nil, ErrContractNotFound - } - - return nonce, nil -} - -// StorageAt gets the value of the storage at the given address and key. -// -// It follows the specification defined here: -// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L110 -func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *jsonrpc.Error) { - stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) - if rpcErr != nil { - return nil, rpcErr - } - defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageAt") - - // This checks if the contract exists because if a key doesn't exist in contract storage, - // the returned value is always zero and error is nil. - _, err := stateReader.ContractClassHash(&address) - if err != nil { - if errors.Is(err, db.ErrKeyNotFound) { - return nil, ErrContractNotFound - } - h.log.Errorw("Failed to get contract nonce", "err", err) - return nil, ErrInternal - } - - value, err := stateReader.ContractStorage(&address, &key) - if err != nil { - return nil, ErrContractNotFound - } - - return value, nil -} diff --git a/rpc/contract_test.go b/rpc/contract_test.go deleted file mode 100644 index c9abb5214e..0000000000 --- a/rpc/contract_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package rpc_test - -import ( - "errors" - "testing" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/mocks" - "github.com/NethermindEth/juno/rpc" - "github.com/NethermindEth/juno/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" -) - -func TestNonce(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockReader := mocks.NewMockReader(mockCtrl) - log := utils.NewNopZapLogger() - handler := rpc.New(mockReader, nil, nil, "", log) - - t.Run("empty blockchain", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) - require.Nil(t, nonce) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) - - nonce, rpcErr := handler.Nonce(rpc.BlockID{Hash: &felt.Zero}, felt.Zero) - require.Nil(t, nonce) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) - - nonce, rpcErr := handler.Nonce(rpc.BlockID{Number: 0}, felt.Zero) - require.Nil(t, nonce) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - mockState := mocks.NewMockStateHistoryReader(mockCtrl) - - t.Run("non-existent contract", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(nil, errors.New("non-existent contract")) - - nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) - require.Nil(t, nonce) - assert.Equal(t, rpc.ErrContractNotFound, rpcErr) - }) - - expectedNonce := new(felt.Felt).SetUint64(1) - - t.Run("blockID - latest", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) - - nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) - require.Nil(t, rpcErr) - assert.Equal(t, expectedNonce, nonce) - }) - - t.Run("blockID - hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) - - nonce, rpcErr := handler.Nonce(rpc.BlockID{Hash: &felt.Zero}, felt.Zero) - require.Nil(t, rpcErr) - assert.Equal(t, expectedNonce, nonce) - }) - - t.Run("blockID - number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) - - nonce, rpcErr := handler.Nonce(rpc.BlockID{Number: 0}, felt.Zero) - require.Nil(t, rpcErr) - assert.Equal(t, expectedNonce, nonce) - }) -} - -func TestStorageAt(t *testing.T) { - mockCtrl := gomock.NewController(t) - t.Cleanup(mockCtrl.Finish) - - mockReader := mocks.NewMockReader(mockCtrl) - log := utils.NewNopZapLogger() - handler := rpc.New(mockReader, nil, nil, "", log) - - t.Run("empty blockchain", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - t.Run("non-existent block number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) - }) - - mockState := mocks.NewMockStateHistoryReader(mockCtrl) - - t.Run("non-existent contract", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrContractNotFound, rpcErr) - }) - - t.Run("non-existent key", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, db.ErrKeyNotFound) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, storage) - assert.Equal(t, rpc.ErrContractNotFound, rpcErr) - }) - - expectedStorage := new(felt.Felt).SetUint64(1) - - t.Run("blockID - latest", func(t *testing.T) { - mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) - - t.Run("blockID - hash", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) - - t.Run("blockID - number", func(t *testing.T) { - mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) - mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) - mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) - - storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) - require.Nil(t, rpcErr) - assert.Equal(t, expectedStorage, storage) - }) -} diff --git a/rpc/handlers.go b/rpc/handlers.go index 4c9fcb3588..fcae4b2323 100644 --- a/rpc/handlers.go +++ b/rpc/handlers.go @@ -38,19 +38,22 @@ type l1Client interface { } var ( - ErrContractNotFound = &jsonrpc.Error{Code: 20, Message: "Contract not found"} - ErrBlockNotFound = &jsonrpc.Error{Code: 24, Message: "Block not found"} - ErrInvalidTxHash = &jsonrpc.Error{Code: 25, Message: "Invalid transaction hash"} - ErrInvalidBlockHash = &jsonrpc.Error{Code: 26, Message: "Invalid block hash"} - ErrInvalidTxIndex = &jsonrpc.Error{Code: 27, Message: "Invalid transaction index in a block"} - ErrClassHashNotFound = &jsonrpc.Error{Code: 28, Message: "Class hash not found"} - ErrTxnHashNotFound = &jsonrpc.Error{Code: 29, Message: "Transaction hash not found"} - ErrPageSizeTooBig = &jsonrpc.Error{Code: 31, Message: "Requested page size is too big"} - ErrNoBlock = &jsonrpc.Error{Code: 32, Message: "There are no blocks"} - ErrInvalidContinuationToken = &jsonrpc.Error{Code: 33, Message: "Invalid continuation token"} - ErrTooManyKeysInFilter = &jsonrpc.Error{Code: 34, Message: "Too many keys provided in a filter"} - ErrContractError = &jsonrpc.Error{Code: 40, Message: "Contract error"} - ErrTransactionExecutionError = &jsonrpc.Error{Code: 41, Message: "Transaction execution error"} + ErrContractNotFound = &jsonrpc.Error{Code: 20, Message: "Contract not found"} + ErrBlockNotFound = &jsonrpc.Error{Code: 24, Message: "Block not found"} + ErrInvalidTxHash = &jsonrpc.Error{Code: 25, Message: "Invalid transaction hash"} + ErrInvalidBlockHash = &jsonrpc.Error{Code: 26, Message: "Invalid block hash"} + ErrInvalidTxIndex = &jsonrpc.Error{Code: 27, Message: "Invalid transaction index in a block"} + ErrClassHashNotFound = &jsonrpc.Error{Code: 28, Message: "Class hash not found"} + ErrTxnHashNotFound = &jsonrpc.Error{Code: 29, Message: "Transaction hash not found"} + ErrPageSizeTooBig = &jsonrpc.Error{Code: 31, Message: "Requested page size is too big"} + ErrNoBlock = &jsonrpc.Error{Code: 32, Message: "There are no blocks"} + ErrInvalidContinuationToken = &jsonrpc.Error{Code: 33, Message: "Invalid continuation token"} + ErrTooManyKeysInFilter = &jsonrpc.Error{Code: 34, Message: "Too many keys provided in a filter"} + ErrContractError = &jsonrpc.Error{Code: 40, Message: "Contract error"} + ErrTransactionExecutionError = &jsonrpc.Error{Code: 41, Message: "Transaction execution error"} + ErrStorageProofNotSupported = &jsonrpc.Error{ + Code: 42, Message: "The node doesn't support storage proofs for blocks that are too far in the past", + } ErrInvalidContractClass = &jsonrpc.Error{Code: 50, Message: "Invalid contract class"} ErrClassAlreadyDeclared = &jsonrpc.Error{Code: 51, Message: "Class already declared"} ErrInternal = &jsonrpc.Error{Code: jsonrpc.InternalError, Message: "Internal error"} @@ -459,6 +462,16 @@ func (h *Handler) MethodsV0_7() ([]jsonrpc.Method, string) { //nolint: funlen Params: []jsonrpc.Parameter{{Name: "contract_address"}, {Name: "key"}, {Name: "block_id"}}, Handler: h.StorageAt, }, + { + Name: "starknet_getStorageProof", + Params: []jsonrpc.Parameter{ + {Name: "block_id"}, + {Name: "class_hashes", Optional: true}, + {Name: "contract_addresses", Optional: true}, + {Name: "contracts_storage_keys", Optional: true}, + }, + Handler: h.StorageProof, + }, { Name: "starknet_getClassHashAt", Params: []jsonrpc.Parameter{{Name: "block_id"}, {Name: "contract_address"}}, diff --git a/rpc/init_test.go b/rpc/init_test.go new file mode 100644 index 0000000000..b5f0f1df2b --- /dev/null +++ b/rpc/init_test.go @@ -0,0 +1,5 @@ +package rpc_test + +import ( + _ "github.com/NethermindEth/juno/encoder/registry" +) diff --git a/rpc/storage.go b/rpc/storage.go new file mode 100644 index 0000000000..62d2c04efa --- /dev/null +++ b/rpc/storage.go @@ -0,0 +1,358 @@ +package rpc + +import ( + "errors" + "fmt" + "strings" + + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/jsonrpc" + "github.com/NethermindEth/juno/utils" +) + +/**************************************************** + Contract Handlers +*****************************************************/ + +// Nonce returns the nonce associated with the given address in the given block number +// +// It follows the specification defined here: +// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L633 +func (h *Handler) Nonce(id BlockID, address felt.Felt) (*felt.Felt, *jsonrpc.Error) { + stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) + if rpcErr != nil { + return nil, rpcErr + } + defer h.callAndLogErr(stateCloser, "Error closing state reader in getNonce") + + nonce, err := stateReader.ContractNonce(&address) + if err != nil { + return nil, ErrContractNotFound + } + + return nonce, nil +} + +// StorageAt gets the value of the storage at the given address and key. +// +// It follows the specification defined here: +// https://github.com/starkware-libs/starknet-specs/blob/a789ccc3432c57777beceaa53a34a7ae2f25fda0/api/starknet_api_openrpc.json#L110 +func (h *Handler) StorageAt(address, key felt.Felt, id BlockID) (*felt.Felt, *jsonrpc.Error) { + stateReader, stateCloser, rpcErr := h.stateByBlockID(&id) + if rpcErr != nil { + return nil, rpcErr + } + defer h.callAndLogErr(stateCloser, "Error closing state reader in getStorageAt") + + // This checks if the contract exists because if a key doesn't exist in contract storage, + // the returned value is always zero and error is nil. + _, err := stateReader.ContractClassHash(&address) + if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { + return nil, ErrContractNotFound + } + h.log.Errorw("Failed to get contract nonce", "err", err) + return nil, ErrInternal + } + + value, err := stateReader.ContractStorage(&address, &key) + if err != nil { + return nil, ErrContractNotFound + } + + return value, nil +} + +type StorageProofResult struct { + ClassesProof []*HashToNode `json:"classes_proof"` + ContractsProof *ContractProof `json:"contracts_proof"` + ContractsStorageProofs [][]*HashToNode `json:"contracts_storage_proofs"` + GlobalRoots *GlobalRoots `json:"global_roots"` +} + +func (s *StorageProofResult) String() string { + var result strings.Builder + + result.WriteString("StorageProofResult{\n") + result.WriteString(fmt.Sprintf(" ClassesProof: %d proofs\n", len(s.ClassesProof))) + + if s.ContractsProof != nil { + result.WriteString(fmt.Sprintf(" ContractsProof: %d nodes, %d leaves\n", + len(s.ContractsProof.Nodes), + len(s.ContractsProof.LeavesData))) + + // Print nodes + result.WriteString(" Nodes:\n") + for i, node := range s.ContractsProof.Nodes { + result.WriteString(fmt.Sprintf(" [%d] Hash: %v\n", i, node.Hash)) + switch n := node.Node.(type) { + case *BinaryNode: + result.WriteString(fmt.Sprintf(" Binary{Left: %v, Right: %v}\n", n.Left, n.Right)) + case *EdgeNode: + result.WriteString(fmt.Sprintf(" Edge{Path: %v, Length: %d, Child: %v}\n", n.Path, n.Length, n.Child)) + } + } + + // Print leaves + result.WriteString(" Leaves:\n") + for i, leaf := range s.ContractsProof.LeavesData { + if leaf != nil { + result.WriteString(fmt.Sprintf(" [%d] Nonce: %v, ClassHash: %v\n", i, leaf.Nonce, leaf.ClassHash)) + } + } + } + + result.WriteString(fmt.Sprintf(" ContractsStorageProofs: %d proofs\n", len(s.ContractsStorageProofs))) + // Print storage proofs + for i, proof := range s.ContractsStorageProofs { + result.WriteString(fmt.Sprintf(" Proof[%d]: %d nodes\n", i, len(proof))) + for j, node := range proof { + result.WriteString(fmt.Sprintf(" [%d] Hash: %v\n", j, node.Hash)) + switch n := node.Node.(type) { + case *BinaryNode: + result.WriteString(fmt.Sprintf(" Binary{Left: %v, Right: %v}\n", n.Left, n.Right)) + case *EdgeNode: + result.WriteString(fmt.Sprintf(" Edge{Path: %v, Length: %d, Child: %v}\n", n.Path, n.Length, n.Child)) + } + } + } + + if s.GlobalRoots != nil { + result.WriteString(" GlobalRoots: {\n") + result.WriteString(fmt.Sprintf(" ContractsTreeRoot: %v\n", s.GlobalRoots.ContractsTreeRoot)) + result.WriteString(fmt.Sprintf(" ClassesTreeRoot: %v\n", s.GlobalRoots.ClassesTreeRoot)) + result.WriteString(fmt.Sprintf(" BlockHash: %v\n", s.GlobalRoots.BlockHash)) + result.WriteString(" }") + } + + result.WriteString("\n}") + return result.String() +} + +func (h *Handler) StorageProof(id BlockID, + classes, contracts []felt.Felt, storageKeys []StorageKeys, +) (*StorageProofResult, *jsonrpc.Error) { + // We do not support historical storage proofs for now + if !id.Latest { + return nil, ErrStorageProofNotSupported + } + + head, err := h.bcReader.Head() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + state, closer, err := h.bcReader.HeadState() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + defer h.callAndLogErr(closer, "Error closing state reader in getStorageProof") + + classTrie, err := state.ClassTrie() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + contractTrie, err := state.ContractTrie() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + // Do a sanity check and remove duplicates from the keys + classes = utils.Unique(classes) + contracts = utils.Unique(contracts) + + classProof, err := getClassProof(classTrie, classes) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + contractProof, err := getContractProof(contractTrie, state, contracts) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + contractStorageProof, err := getContractStorageProof(state, storageKeys) + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + contractTreeRoot, err := contractTrie.Root() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + classTreeRoot, err := classTrie.Root() + if err != nil { + return nil, ErrInternal.CloneWithData(err) + } + + return &StorageProofResult{ + ClassesProof: classProof, + ContractsProof: contractProof, + ContractsStorageProofs: contractStorageProof, + GlobalRoots: &GlobalRoots{ + ContractsTreeRoot: contractTreeRoot, + ClassesTreeRoot: classTreeRoot, + BlockHash: head.Hash, + }, + }, nil +} + +func getClassProof(tr *trie.Trie, classes []felt.Felt) ([]*HashToNode, error) { + classProof := trie.NewProofNodeSet() + for _, class := range classes { + if err := tr.Prove(&class, classProof); err != nil { + return nil, err + } + } + + return adaptProofNodes(classProof), nil +} + +func getContractProof(tr *trie.Trie, state core.StateReader, contracts []felt.Felt) (*ContractProof, error) { + contractProof := trie.NewProofNodeSet() + contractLeavesData := make([]*LeafData, len(contracts)) + for i, contract := range contracts { + if err := tr.Prove(&contract, contractProof); err != nil { + return nil, err + } + + nonce, err := state.ContractNonce(&contract) + if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data + continue + } + return nil, err + } + + classHash, err := state.ContractClassHash(&contract) + if err != nil { + return nil, err + } + + contractLeavesData[i] = &LeafData{ + Nonce: nonce, + ClassHash: classHash, + } + } + + return &ContractProof{ + Nodes: adaptProofNodes(contractProof), + LeavesData: contractLeavesData, + }, nil +} + +func getContractStorageProof(state core.StateReader, storageKeys []StorageKeys) ([][]*HashToNode, error) { + contractStorageRes := make([][]*HashToNode, len(storageKeys)) + for i, storageKey := range storageKeys { + contractStorageTrie, err := state.ContractStorageTrie(&storageKey.Contract) + if err != nil { + return nil, err + } + + contractStorageProof := trie.NewProofNodeSet() + keys := utils.Unique(storageKey.Keys) + for _, key := range keys { + if err := contractStorageTrie.Prove(&key, contractStorageProof); err != nil { + return nil, err + } + } + + contractStorageRes[i] = adaptProofNodes(contractStorageProof) + } + + return contractStorageRes, nil +} + +func adaptProofNodes(proof *trie.ProofNodeSet) []*HashToNode { + nodes := make([]*HashToNode, proof.Size()) + nodeList := proof.List() + for i, hash := range proof.Keys() { + var node Node + + switch n := nodeList[i].(type) { + case *trie.Binary: + node = &BinaryNode{ + Left: n.LeftHash, + Right: n.RightHash, + } + case *trie.Edge: + path := n.Path.Felt() + node = &EdgeNode{ + Path: path.String(), + Length: int(n.Path.Len()), + Child: n.Child, + } + } + + nodes[i] = &HashToNode{ + Hash: &hash, + Node: node, + } + } + + return nodes +} + +type StorageKeys struct { + Contract felt.Felt `json:"contract_address"` + Keys []felt.Felt `json:"storage_keys"` +} + +type Node interface { + AsProofNode() trie.ProofNode +} + +type BinaryNode struct { + Left *felt.Felt `json:"left"` + Right *felt.Felt `json:"right"` +} + +type EdgeNode struct { + Path string `json:"path"` + Length int `json:"length"` + Child *felt.Felt `json:"child"` +} + +func (e *EdgeNode) AsProofNode() trie.ProofNode { + f, _ := new(felt.Felt).SetString(e.Path) + pbs := f.Bytes() + path := trie.NewKey(uint8(e.Length), pbs[:]) + + return &trie.Edge{ + Path: &path, + Child: e.Child, + } +} + +func (b *BinaryNode) AsProofNode() trie.ProofNode { + return &trie.Binary{ + LeftHash: b.Left, + RightHash: b.Right, + } +} + +type HashToNode struct { + Hash *felt.Felt `json:"node_hash"` + Node Node `json:"node"` +} + +type LeafData struct { + Nonce *felt.Felt `json:"nonce"` + ClassHash *felt.Felt `json:"class_hash"` +} + +type ContractProof struct { + Nodes []*HashToNode `json:"nodes"` + LeavesData []*LeafData `json:"contract_leaves_data"` +} + +type GlobalRoots struct { + ContractsTreeRoot *felt.Felt `json:"contracts_tree_root"` + ClassesTreeRoot *felt.Felt `json:"classes_tree_root"` + BlockHash *felt.Felt `json:"block_hash"` +} diff --git a/rpc/storage_test.go b/rpc/storage_test.go new file mode 100644 index 0000000000..bb43ff1e31 --- /dev/null +++ b/rpc/storage_test.go @@ -0,0 +1,763 @@ +package rpc_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/NethermindEth/juno/blockchain" + "github.com/NethermindEth/juno/clients/feeder" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/mocks" + "github.com/NethermindEth/juno/rpc" + adaptfeeder "github.com/NethermindEth/juno/starknetdata/feeder" + "github.com/NethermindEth/juno/sync" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestNonce(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + log := utils.NewNopZapLogger() + handler := rpc.New(mockReader, nil, nil, "", log) + + t.Run("empty blockchain", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) + + nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) + require.Nil(t, nonce) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) + + nonce, rpcErr := handler.Nonce(rpc.BlockID{Hash: &felt.Zero}, felt.Zero) + require.Nil(t, nonce) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) + + nonce, rpcErr := handler.Nonce(rpc.BlockID{Number: 0}, felt.Zero) + require.Nil(t, nonce) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + + t.Run("non-existent contract", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(nil, errors.New("non-existent contract")) + + nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) + require.Nil(t, nonce) + assert.Equal(t, rpc.ErrContractNotFound, rpcErr) + }) + + expectedNonce := new(felt.Felt).SetUint64(1) + + t.Run("blockID - latest", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + + nonce, rpcErr := handler.Nonce(rpc.BlockID{Latest: true}, felt.Zero) + require.Nil(t, rpcErr) + assert.Equal(t, expectedNonce, nonce) + }) + + t.Run("blockID - hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + + nonce, rpcErr := handler.Nonce(rpc.BlockID{Hash: &felt.Zero}, felt.Zero) + require.Nil(t, rpcErr) + assert.Equal(t, expectedNonce, nonce) + }) + + t.Run("blockID - number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractNonce(&felt.Zero).Return(expectedNonce, nil) + + nonce, rpcErr := handler.Nonce(rpc.BlockID{Number: 0}, felt.Zero) + require.Nil(t, rpcErr) + assert.Equal(t, expectedNonce, nonce) + }) +} + +func TestStorageAt(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + mockReader := mocks.NewMockReader(mockCtrl) + log := utils.NewNopZapLogger() + handler := rpc.New(mockReader, nil, nil, "", log) + + t.Run("empty blockchain", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + t.Run("non-existent block number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(nil, nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrBlockNotFound, rpcErr) + }) + + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + + t.Run("non-existent contract", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(gomock.Any()).Return(nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrContractNotFound, rpcErr) + }) + + t.Run("non-existent key", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(nil, db.ErrKeyNotFound) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, storage) + assert.Equal(t, rpc.ErrContractNotFound, rpcErr) + }) + + expectedStorage := new(felt.Felt).SetUint64(1) + + t.Run("blockID - latest", func(t *testing.T) { + mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Latest: true}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) + + t.Run("blockID - hash", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockHash(&felt.Zero).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Hash: &felt.Zero}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) + + t.Run("blockID - number", func(t *testing.T) { + mockReader.EXPECT().StateAtBlockNumber(uint64(0)).Return(mockState, nopCloser, nil) + mockState.EXPECT().ContractClassHash(&felt.Zero).Return(nil, nil) + mockState.EXPECT().ContractStorage(gomock.Any(), gomock.Any()).Return(expectedStorage, nil) + + storage, rpcErr := handler.StorageAt(felt.Zero, felt.Zero, rpc.BlockID{Number: 0}) + require.Nil(t, rpcErr) + assert.Equal(t, expectedStorage, storage) + }) +} + +func TestStorageProof(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + var ( + blkHash = utils.HexToFelt(t, "0x11ead") + root = utils.HexToFelt(t, "0x43f7163af64f9199e7c0bba225c2c3310ee2947be5ec0f03c9fb1551135818b") + key = new(felt.Felt).SetUint64(1) + key2 = new(felt.Felt).SetUint64(8) + noSuchKey = new(felt.Felt).SetUint64(0) + value = new(felt.Felt).SetUint64(51) + value2 = new(felt.Felt).SetUint64(58) + blockLatest = rpc.BlockID{Latest: true} + blockNumber = uint64(1313) + ) + + tempTrie := emptyTrie(t) + _, _ = tempTrie.Put(key, value) + _, _ = tempTrie.Put(key2, value2) + _ = tempTrie.Commit() + trieRoot, _ := tempTrie.Root() + + mockReader := mocks.NewMockReader(mockCtrl) + mockState := mocks.NewMockStateHistoryReader(mockCtrl) + mockReader.EXPECT().HeadState().Return(mockState, func() error { return nil }, nil).AnyTimes() + mockReader.EXPECT().Head().Return(&core.Block{Header: &core.Header{Hash: blkHash, Number: blockNumber}}, nil).AnyTimes() + mockState.EXPECT().ClassTrie().Return(tempTrie, nil).AnyTimes() + mockState.EXPECT().ContractTrie().Return(tempTrie, nil).AnyTimes() + + log := utils.NewNopZapLogger() + handler := rpc.New(mockReader, nil, nil, "", log) + + t.Run("global roots are filled", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) + require.Nil(t, rpcErr) + + require.NotNil(t, proof) + require.NotNil(t, proof.GlobalRoots) + require.Equal(t, blkHash, proof.GlobalRoots.BlockHash) + require.Equal(t, root, proof.GlobalRoots.ClassesTreeRoot) + require.Equal(t, root, proof.GlobalRoots.ContractsTreeRoot) + }) + t.Run("error is returned whenever not latest block is requested", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: 1}, nil, nil, nil) + assert.Equal(t, rpc.ErrStorageProofNotSupported, rpcErr) + require.Nil(t, proof) + }) + t.Run("error is returned even when blknum matches head", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(rpc.BlockID{Number: blockNumber}, nil, nil, nil) + assert.Equal(t, rpc.ErrStorageProofNotSupported, rpcErr) + require.Nil(t, proof) + }) + t.Run("empty request", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 0, 0, 0) + }) + t.Run("class trie hash does not exist in a trie", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*noSuchKey}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 0, 0, 0) + verifyIf(t, trieRoot, noSuchKey, nil, proof.ClassesProof, tempTrie.HashFn()) + }) + t.Run("class trie hash exists in a trie", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 0, 0, 0) + verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + }) + t.Run("only unique proof nodes are returned", func(t *testing.T) { + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key, *key2}, nil, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + + rootNodes := utils.Filter(proof.ClassesProof, func(h *rpc.HashToNode) bool { + return h.Hash.Equal(trieRoot) + }) + require.Len(t, rootNodes, 1) + + // verify we can still prove any of the keys in query + verifyIf(t, trieRoot, key, value, proof.ClassesProof, tempTrie.HashFn()) + verifyIf(t, trieRoot, key2, value2, proof.ClassesProof, tempTrie.HashFn()) + }) + t.Run("storage trie address does not exist in a trie", func(t *testing.T) { + mockState.EXPECT().ContractNonce(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(1) + mockState.EXPECT().ContractClassHash(noSuchKey).Return(nil, db.ErrKeyNotFound).Times(0) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*noSuchKey}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 3, 1, 0) + require.Nil(t, proof.ContractsProof.LeavesData[0]) + + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsProof.Nodes, tempTrie.HashFn()) + }) + t.Run("storage trie address exists in a trie", func(t *testing.T) { + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil).Times(1) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil).Times(1) + + proof, rpcErr := handler.StorageProof(blockLatest, nil, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 0, 3, 1, 0) + + require.NotNil(t, proof.ContractsProof.LeavesData[0]) + ld := proof.ContractsProof.LeavesData[0] + require.Equal(t, nonce, ld.Nonce) + require.Equal(t, classHasah, ld.ClassHash) + + verifyIf(t, trieRoot, key, value, proof.ContractsProof.Nodes, tempTrie.HashFn()) + }) + t.Run("contract storage trie address does not exist in a trie", func(t *testing.T) { + contract := utils.HexToFelt(t, "0xdead") + mockState.EXPECT().ContractStorageTrie(contract).Return(emptyTrie(t), nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 0) + }) + //nolint:dupl + t.Run("contract storage trie key slot does not exist in a trie", func(t *testing.T) { + contract := utils.HexToFelt(t, "0xabcd") + mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*noSuchKey}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 3) + + verifyIf(t, trieRoot, noSuchKey, nil, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + }) + //nolint:dupl + t.Run("contract storage trie address/key exists in a trie", func(t *testing.T) { + contract := utils.HexToFelt(t, "0xadd0") + mockState.EXPECT().ContractStorageTrie(contract).Return(tempTrie, nil).Times(1) + + storageKeys := []rpc.StorageKeys{{Contract: *contract, Keys: []felt.Felt{*key}}} + proof, rpcErr := handler.StorageProof(blockLatest, nil, nil, storageKeys) + require.NotNil(t, proof) + require.Nil(t, rpcErr) + arityTest(t, proof, 0, 0, 0, 1) + require.Len(t, proof.ContractsStorageProofs[0], 3) + + verifyIf(t, trieRoot, key, value, proof.ContractsStorageProofs[0], tempTrie.HashFn()) + }) + t.Run("class & storage tries proofs requested", func(t *testing.T) { + nonce := new(felt.Felt).SetUint64(121) + mockState.EXPECT().ContractNonce(key).Return(nonce, nil) + classHasah := new(felt.Felt).SetUint64(1234) + mockState.EXPECT().ContractClassHash(key).Return(classHasah, nil) + + proof, rpcErr := handler.StorageProof(blockLatest, []felt.Felt{*key}, []felt.Felt{*key}, nil) + require.Nil(t, rpcErr) + require.NotNil(t, proof) + arityTest(t, proof, 3, 3, 1, 0) + }) +} + +func TestStorageProof_VerifyPathfinderResponse(t *testing.T) { + t.Parallel() + + // Pathfinder response for query: + // "method": "starknet_getStorageProof", + // "params": [ + // "latest", + // [], + // [ + // "0x5a03b82d726f9bb31ba41ea3a0c1143f90241e37c9a4a92174d168cda9c716d", + // "0x5fbaa249500be29fee38fdd90a7a2651a8d3935c14167570f6863f563d838f0" + // ] + // ], + // Sepolia, at block 10434 + result := rpc.StorageProofResult{ + ClassesProof: []*rpc.HashToNode{}, + ContractsProof: &rpc.ContractProof{ + LeavesData: []*rpc.LeafData{ + { + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x772164c9d6179a89e7f1167f099219f47d752304b16ed01f081b6e0b45c93c3"), + }, + { + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x78401746828463e2c3f92ebb261fc82f7d4d4c8d9a80a356c44580dab124cb0"), + }, + }, + Nodes: []*rpc.HashToNode{ + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x5c6be09d8faaa42a8525898b1047cebdd3526349b48decc2b767a4fa612263d"), + Right: utils.HexToFelt(t, "0xcd11aa7699c4157a287e5fe574df37e40c8b6a5ed5e1aee658fc2d634398ef"), + }, + Hash: utils.HexToFelt(t, "0x7884784e689e733c1ea2c4ee3b1f790c4ca4992b26d8aee31abb5d9270d4947"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x1cdf395ebbba2f3a6234ad9827b08453a4a0b7745e2d919fe7b07749efa5325"), + Right: utils.HexToFelt(t, "0xcdd37cf6cce8bc373e2c9d8d6754b057275ddd910a9d133b4d31086632d0f4"), + }, + Hash: utils.HexToFelt(t, "0x44fcfce222b7e5a098346615dc838d8ae90ff55da82db7cdce4303f34042ff6"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x2c55bc287a1b31a405c681c2bb720811dd9f33523241561ea4b356f717ff9f6"), + Right: utils.HexToFelt(t, "0x2012025c00174e3eb72baba21e58a56e5114e571f64cb1040f7de0c8daef618"), + }, + Hash: utils.HexToFelt(t, "0x7f2b62cf9713a0b635b967c2e2891282631519eebca6ea0bddaa1a1a804919f"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x211a80e63ac0b12b29279c3d57ea5771b5003ea464b055aeb8ad8618ff3cd69"), + Right: utils.HexToFelt(t, "0x44f55356be17913dcd79e0bb4dbc986d0642bb3f000e540bb54bfa2d4189a74"), + }, + Hash: utils.HexToFelt(t, "0x69e208899d9deeae0732e95ce9d68d123abd9b59f157435fc3554e1fa3a92a8"), + }, + { + Node: &rpc.EdgeNode{ + Child: utils.HexToFelt(t, "0x6b45780618ce075fb4543396b3a6949915c04962b2e411c4f1b2a6813d540da"), + Length: 239, + Path: "0x3b82d726f9bb31ba41ea3a0c1143f90241e37c9a4a92174d168cda9c716d", + }, + Hash: utils.HexToFelt(t, "0x2c55bc287a1b31a405c681c2bb720811dd9f33523241561ea4b356f717ff9f6"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x7be97a0f8a99126208712673c69c292a26273707c884e96e17c761ee7097ae5"), + Right: utils.HexToFelt(t, "0x3ae1731f598d03a9033c6f5d29871cd5a80c4eba36a7a0a73775ea9d8d522f3"), + }, + Hash: utils.HexToFelt(t, "0xcd11aa7699c4157a287e5fe574df37e40c8b6a5ed5e1aee658fc2d634398ef"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x7f2b62cf9713a0b635b967c2e2891282631519eebca6ea0bddaa1a1a804919f"), + Right: utils.HexToFelt(t, "0x77f807a73f0e7ccad122cd946d79d8f4ce9e02f01017467e7cf4ad993cfa482"), + }, + Hash: utils.HexToFelt(t, "0x326e52c7cba85fedb456bb1c25dda2075ebe3367a329eb297144cb7f8d1f7d9"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x35d32a880d122ffc43a46e280c0ff34a9de286c2cb2e3933229f419a6ceed8e"), + Right: utils.HexToFelt(t, "0x14c9f5368ebbe1cc8d1db2dde1f97d18cabf450bbc23f154985c7e15e15bdcf"), + }, + Hash: utils.HexToFelt(t, "0x1159575d44f9b716f2cfbb13da873f8e7d9824e6b7b615dac5ce9c7b0e2bffd"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x1e5dfbcf23a5e942208f5ccfa25db1147dbfb2984df32a692102851757998cd"), + Right: utils.HexToFelt(t, "0x69e208899d9deeae0732e95ce9d68d123abd9b59f157435fc3554e1fa3a92a8"), + }, + Hash: utils.HexToFelt(t, "0x2722e2a47b3f10db016928bcc7451cd2088a1caea2fbb5f08e1b71dfe1db1c2"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x2634833b52e930231b53d58286647d9818a276dd12ace8286dae63b896c3ba1"), + Right: utils.HexToFelt(t, "0x1f248a8796f18bc9d116e5f3c3956c47e091c05f1c9596453b2fefa2b725507"), + }, + Hash: utils.HexToFelt(t, "0x109e30040b25357cc51726d6041ba1f09ec02dd8b3ca2ffa686a858c9293796"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x7884784e689e733c1ea2c4ee3b1f790c4ca4992b26d8aee31abb5d9270d4947"), + Right: utils.HexToFelt(t, "0x4e354efe4fcc718d3454d532b50cd3c73ac84f05df918981433162c84650f6c"), + }, + Hash: utils.HexToFelt(t, "0x88648f7a7b355914ed41bb28101110cff8fb68f1a9b39958823c72992d8675"), + }, + { + Node: &rpc.EdgeNode{ + Child: utils.HexToFelt(t, "0x4169679eea4895011fb8e9029b4591a210b3b9e9aa23f12f25cf45cbcaadfe8"), + Length: 1, + Path: "0x1", + }, + Hash: utils.HexToFelt(t, "0x44f55356be17913dcd79e0bb4dbc986d0642bb3f000e540bb54bfa2d4189a74"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x192804e98b1f3fdad2d8fab79bfb922611edc5fb48dcd1e9db02cd46cfa9763"), + Right: utils.HexToFelt(t, "0x4717a5dd5048d62401bc7db57594d3bdbfd3c7b99788a83c5e77b6db9822149"), + }, + Hash: utils.HexToFelt(t, "0x14c9f5368ebbe1cc8d1db2dde1f97d18cabf450bbc23f154985c7e15e15bdcf"), + }, + { + Node: &rpc.EdgeNode{ + Child: utils.HexToFelt(t, "0x25790175fe1fbeed47cbf510a41fba8676bea20a0c8888d4b9090b8f5cf19b8"), + Length: 238, + Path: "0x2a249500be29fee38fdd90a7a2651a8d3935c14167570f6863f563d838f0", + }, + Hash: utils.HexToFelt(t, "0x331128166378265a07c0be65b242d47d1965e785b6f4f6e1bca3731de5d2d1d"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x331128166378265a07c0be65b242d47d1965e785b6f4f6e1bca3731de5d2d1d"), + Right: utils.HexToFelt(t, "0x12af5e7e95772777d98792be8ade3b18c06ab21aa492a1821d5be3ac291374a"), + }, + Hash: utils.HexToFelt(t, "0x4169679eea4895011fb8e9029b4591a210b3b9e9aa23f12f25cf45cbcaadfe8"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x485b298f33aa076113362f82f4bf64f23e2eb5b84209353a630a46cd20fdde5"), + Right: utils.HexToFelt(t, "0x1159575d44f9b716f2cfbb13da873f8e7d9824e6b7b615dac5ce9c7b0e2bffd"), + }, + Hash: utils.HexToFelt(t, "0x3ae1731f598d03a9033c6f5d29871cd5a80c4eba36a7a0a73775ea9d8d522f3"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x2358473807e0a43a66b918247c0fb0d0649c72a32f19eee8bcc76c090b37951"), + Right: utils.HexToFelt(t, "0x109e30040b25357cc51726d6041ba1f09ec02dd8b3ca2ffa686a858c9293796"), + }, + Hash: utils.HexToFelt(t, "0x485b298f33aa076113362f82f4bf64f23e2eb5b84209353a630a46cd20fdde5"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x326e52c7cba85fedb456bb1c25dda2075ebe3367a329eb297144cb7f8d1f7d9"), + Right: utils.HexToFelt(t, "0x41149879a9d24ba0a2ccfb56415c04bdabb1c51eb0900a17dee2c715d6b1c70"), + }, + Hash: utils.HexToFelt(t, "0x1cdf395ebbba2f3a6234ad9827b08453a4a0b7745e2d919fe7b07749efa5325"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x454a8b3fc492869e79b16e87461d0b5101eb5d25389f492039ef6a380878b39"), + Right: utils.HexToFelt(t, "0x5a99604af4e482d046afe656b6ebe7805c72a1b7979d00608f27b276eb33442"), + }, + Hash: utils.HexToFelt(t, "0x4717a5dd5048d62401bc7db57594d3bdbfd3c7b99788a83c5e77b6db9822149"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x2f6c0e4b8022b48461e54e4f9358c51d5444ae2e2253a31baa68d4cb0c938de"), + Right: utils.HexToFelt(t, "0x88648f7a7b355914ed41bb28101110cff8fb68f1a9b39958823c72992d8675"), + }, + Hash: utils.HexToFelt(t, "0x47182b7d8158a8f80ed15822719aa306af37383a0cf91518d21ba63e73fea13"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x44fcfce222b7e5a098346615dc838d8ae90ff55da82db7cdce4303f34042ff6"), + Right: utils.HexToFelt(t, "0xc3da9c726d244197963a8a7beb4a3aee353b3b663daf2aa1bcf1c087b5e20d"), + }, + Hash: utils.HexToFelt(t, "0x2634833b52e930231b53d58286647d9818a276dd12ace8286dae63b896c3ba1"), + }, + { + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x2722e2a47b3f10db016928bcc7451cd2088a1caea2fbb5f08e1b71dfe1db1c2"), + Right: utils.HexToFelt(t, "0x79c09acd32044c7d455299ca67e2a8fafce25afaf6d5e89ff4632b251dddc8d"), + }, + Hash: utils.HexToFelt(t, "0x5a99604af4e482d046afe656b6ebe7805c72a1b7979d00608f27b276eb33442"), + }, + }, + }, + ContractsStorageProofs: [][]*rpc.HashToNode{}, + GlobalRoots: &rpc.GlobalRoots{ + BlockHash: utils.HexToFelt(t, "0xae4cc763c8b350913e00e12cffd51fb7e3b730e29036864a8afd8ec323ecd6"), + ClassesTreeRoot: utils.HexToFelt(t, "0xea1568e1ca4e5b8c19cdf130dc3194f9cb8e5eee2fa5ec54a338a4dccfd6e3"), + ContractsTreeRoot: utils.HexToFelt(t, "0x47182b7d8158a8f80ed15822719aa306af37383a0cf91518d21ba63e73fea13"), + }, + } + + root := result.GlobalRoots.ContractsTreeRoot + + t.Run("first contract proof verification", func(t *testing.T) { + t.Parallel() + + firstContractAddr := utils.HexToFelt(t, "0x5a03b82d726f9bb31ba41ea3a0c1143f90241e37c9a4a92174d168cda9c716d") + firstContractLeaf := utils.HexToFelt(t, "0x6b45780618ce075fb4543396b3a6949915c04962b2e411c4f1b2a6813d540da") + verifyIf(t, root, firstContractAddr, firstContractLeaf, result.ContractsProof.Nodes, crypto.Pedersen) + }) + + t.Run("second contract proof verification", func(t *testing.T) { + t.Parallel() + + secondContractAddr := utils.HexToFelt(t, "0x5fbaa249500be29fee38fdd90a7a2651a8d3935c14167570f6863f563d838f0") + secondContractLeaf := utils.HexToFelt(t, "0x25790175fe1fbeed47cbf510a41fba8676bea20a0c8888d4b9090b8f5cf19b8") + verifyIf(t, root, secondContractAddr, secondContractLeaf, result.ContractsProof.Nodes, crypto.Pedersen) + }) +} + +func TestStorageProof_StorageRoots(t *testing.T) { + mockCtrl := gomock.NewController(t) + t.Cleanup(mockCtrl.Finish) + + client := feeder.NewTestClient(t, &utils.Mainnet) + gw := adaptfeeder.New(client) + + log := utils.NewNopZapLogger() + testDB := pebble.NewMemTest(t) + bc := blockchain.New(testDB, &utils.Mainnet, nil) + synchronizer := sync.New(bc, gw, log, time.Duration(0), false, testDB) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + + require.NoError(t, synchronizer.Run(ctx)) + cancel() + + var ( + expectedBlockHash = utils.HexToFelt(t, "0x4e1f77f39545afe866ac151ac908bd1a347a2a8a7d58bef1276db4f06fdf2f6") + expectedGlobalRoot = utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9") + expectedClsRoot = utils.HexToFelt(t, "0x0") + expectedStgRoot = utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9") + expectedContractAddress = utils.HexToFelt(t, "0x2d6c9569dea5f18628f1ef7c15978ee3093d2d3eec3b893aac08004e678ead3") + expectedContractLeaf = utils.HexToFelt(t, "0x7036d8dd68dc9539c6db8c88f72b1ab16e76d62b5f09118eca5ae78276b0ee4") + ) + + t.Run("sanity check - mainnet block 2", func(t *testing.T) { + expectedBlockNumber := uint64(2) + + blk, err := bc.Head() + assert.NoError(t, err) + assert.Equal(t, expectedBlockNumber, blk.Number) + assert.Equal(t, expectedBlockHash, blk.Hash, blk.Hash.String()) + assert.Equal(t, expectedGlobalRoot, blk.GlobalStateRoot, blk.GlobalStateRoot.String()) + }) + + t.Run("check class and storage roots matches the global", func(t *testing.T) { + reader, closer, err := bc.HeadState() + assert.NoError(t, err) + defer func() { _ = closer() }() + + classTrie, err := reader.ClassTrie() + assert.NoError(t, err) + + contractTrie, err := reader.ContractTrie() + assert.NoError(t, err) + + clsRoot, err := classTrie.Root() + assert.NoError(t, err) + + stgRoot, err := contractTrie.Root() + assert.NoError(t, err) + + assert.Equal(t, expectedClsRoot, clsRoot, clsRoot.String()) + assert.Equal(t, expectedStgRoot, stgRoot, stgRoot.String()) + + verifyGlobalStateRoot(t, expectedGlobalRoot, clsRoot, stgRoot) + }) + + t.Run("check requested contract and storage slot exists", func(t *testing.T) { + stateReader, stCloser, err := bc.HeadState() + assert.NoError(t, err) + defer func() { _ = stCloser() }() + + contractTrie, err := stateReader.ContractTrie() + assert.NoError(t, err) + + leaf, err := contractTrie.Get(expectedContractAddress) + assert.NoError(t, err) + assert.Equal(t, leaf, expectedContractLeaf, leaf.String()) + + clsHash, err := stateReader.ContractClassHash(expectedContractAddress) + assert.NoError(t, err) + assert.Equal(t, clsHash, utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), clsHash.String()) + }) + + t.Run("get contract proof", func(t *testing.T) { + handler := rpc.New(bc, nil, nil, "", log) + result, rpcErr := handler.StorageProof( + rpc.BlockID{Latest: true}, nil, []felt.Felt{*expectedContractAddress}, nil) + require.Nil(t, rpcErr) + + expectedResult := rpc.StorageProofResult{ + ClassesProof: []*rpc.HashToNode{}, + ContractsStorageProofs: [][]*rpc.HashToNode{}, + ContractsProof: &rpc.ContractProof{ + LeavesData: []*rpc.LeafData{ + { + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), + }, + }, + Nodes: []*rpc.HashToNode{ + { + Hash: utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9"), + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x4e1f289e55ac8a821fd463478e6f5543256beb934a871be91d00a0d3f2e7964"), + Right: utils.HexToFelt(t, "0x67d9833b51e7bf1cab0e71e68477bf7f0b704391d753f9d793008e4f6587c53"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x4e1f289e55ac8a821fd463478e6f5543256beb934a871be91d00a0d3f2e7964"), + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x1ef87d62309ff1cad58d39e8f5480f9caa9acd78a43f139d87220a1babe38a4"), + Right: utils.HexToFelt(t, "0x9a258d24b3aeb7e263e910d68a18d85305703a2f20df2e806ecbb1fb28760f"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x9a258d24b3aeb7e263e910d68a18d85305703a2f20df2e806ecbb1fb28760f"), + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x53f61d0cb8099e2e7ffc214c4ef7ac8520abb5327510f84affe90b1890d314c"), + Right: utils.HexToFelt(t, "0x45ca67f381dcd01fec774743a4aaed6b36e1bda979185cf5dce538ad0007914"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x53f61d0cb8099e2e7ffc214c4ef7ac8520abb5327510f84affe90b1890d314c"), + Node: &rpc.BinaryNode{ + Left: utils.HexToFelt(t, "0x17d6fc8431c48e41222a3ede441d1e2d91c31eb67a8aa9c030c99c510e9f34c"), + Right: utils.HexToFelt(t, "0x1cf95259ae39c038e87224fa5fdb7c7eeba6dd4263e05e80c9a8e27c3240f2c"), + }, + }, + { + Hash: utils.HexToFelt(t, "0x1cf95259ae39c038e87224fa5fdb7c7eeba6dd4263e05e80c9a8e27c3240f2c"), + Node: &rpc.EdgeNode{ + Path: "0x56c9569dea5f18628f1ef7c15978ee3093d2d3eec3b893aac08004e678ead3", + Length: 247, + Child: expectedContractLeaf, + }, + }, + }, + }, + GlobalRoots: &rpc.GlobalRoots{ + BlockHash: expectedBlockHash, + ClassesTreeRoot: expectedClsRoot, + ContractsTreeRoot: expectedStgRoot, + }, + } + + assert.Equal(t, expectedResult, *result) + }) +} + +func arityTest(t *testing.T, + proof *rpc.StorageProofResult, + classesProofArity int, + contractsProofNodesArity int, + contractsProofLeavesArity int, + contractStorageArity int, +) { + require.Len(t, proof.ClassesProof, classesProofArity) + require.Len(t, proof.ContractsStorageProofs, contractStorageArity) + require.NotNil(t, proof.ContractsProof) + require.Len(t, proof.ContractsProof.Nodes, contractsProofNodesArity) + require.Len(t, proof.ContractsProof.LeavesData, contractsProofLeavesArity) +} + +func verifyIf( + t *testing.T, + root, key, value *felt.Felt, + proof []*rpc.HashToNode, + hashF crypto.HashFn, +) { + t.Helper() + + proofSet := trie.NewProofNodeSet() + for _, hn := range proof { + proofSet.Put(*hn.Hash, hn.Node.AsProofNode()) + } + + leaf, err := trie.VerifyProof(root, key, proofSet, hashF) + require.NoError(t, err) + + // non-membership test + if value == nil { + value = felt.Zero.Clone() + } + require.Equal(t, leaf, value) +} + +func emptyTrie(t *testing.T) *trie.Trie { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + return tempTrie +} + +func verifyGlobalStateRoot(t *testing.T, globalStateRoot, classRoot, storageRoot *felt.Felt) { + stateVersion := new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) + if classRoot.IsZero() { + assert.Equal(t, globalStateRoot, storageRoot) + } else { + assert.Equal(t, globalStateRoot, crypto.PoseidonArray(stateVersion, storageRoot, classRoot)) + } +} diff --git a/sync/pending.go b/sync/pending.go index 03e12fcf3a..830dc6792f 100644 --- a/sync/pending.go +++ b/sync/pending.go @@ -3,6 +3,7 @@ package sync import ( "github.com/NethermindEth/juno/core" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie" ) type Pending struct { @@ -65,3 +66,15 @@ func (p *PendingState) Class(classHash *felt.Felt) (*core.DeclaredClass, error) return p.head.Class(classHash) } + +func (p *PendingState) ClassTrie() (*trie.Trie, error) { + return nil, core.ErrHistoricalTrieNotSupported +} + +func (p *PendingState) ContractTrie() (*trie.Trie, error) { + return nil, core.ErrHistoricalTrieNotSupported +} + +func (p *PendingState) ContractStorageTrie(addr *felt.Felt) (*trie.Trie, error) { + return nil, core.ErrHistoricalTrieNotSupported +} diff --git a/utils/orderedset.go b/utils/orderedset.go index 1fd9ed3aef..9d78809731 100644 --- a/utils/orderedset.go +++ b/utils/orderedset.go @@ -63,3 +63,15 @@ func (o *OrderedSet[K, V]) List() []V { copy(values, o.items) return values } + +// Keys returns a slice of keys in their insertion order +func (o *OrderedSet[K, V]) Keys() []K { + o.lock.RLock() + defer o.lock.RUnlock() + + keys := make([]K, len(o.items)) + for k, pos := range o.itemPos { + keys[pos] = k + } + return keys +} diff --git a/utils/orderedset_test.go b/utils/orderedset_test.go new file mode 100644 index 0000000000..99b426f8b5 --- /dev/null +++ b/utils/orderedset_test.go @@ -0,0 +1,142 @@ +package utils + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOrderedSet(t *testing.T) { + t.Run("basic operations", func(t *testing.T) { + set := NewOrderedSet[string, int]() + + // Test initial state + assert.Equal(t, 0, set.Size()) + assert.Empty(t, set.List()) + assert.Empty(t, set.Keys()) + + // Test Put and Get + set.Put("a", 1) + set.Put("b", 2) + set.Put("c", 3) + + val, exists := set.Get("b") + assert.True(t, exists) + assert.Equal(t, 2, val) + + // Test size + assert.Equal(t, 3, set.Size()) + + // Test order preservation + assert.Equal(t, []int{1, 2, 3}, set.List()) + assert.Equal(t, []string{"a", "b", "c"}, set.Keys()) + }) + + t.Run("updating existing keys", func(t *testing.T) { + set := NewOrderedSet[string, int]() + + set.Put("a", 1) + set.Put("b", 2) + set.Put("a", 10) // Update existing key + + // Check value was updated but order preserved + assert.Equal(t, []int{10, 2}, set.List()) + assert.Equal(t, []string{"a", "b"}, set.Keys()) + + val, exists := set.Get("a") + assert.True(t, exists) + assert.Equal(t, 10, val) + }) + + t.Run("non-existent keys", func(t *testing.T) { + set := NewOrderedSet[string, int]() + + val, exists := set.Get("nonexistent") + assert.False(t, exists) + assert.Zero(t, val) + }) + + t.Run("concurrent access", func(t *testing.T) { + set := NewOrderedSet[int, string]() + var wg sync.WaitGroup + + // Concurrent writes + for i := 0; i < 100; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + set.Put(n, string(rune(n))) + }(i) + } + + // Concurrent reads + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = set.Size() + _ = set.List() + _ = set.Keys() + }() + } + + wg.Wait() + assert.Equal(t, 100, set.Size()) + }) + + t.Run("different types", func(t *testing.T) { + // Test with different key-value type combinations + set1 := NewOrderedSet[int, string]() + set1.Put(1, "one") + set1.Put(2, "two") + assert.Equal(t, []string{"one", "two"}, set1.List()) + assert.Equal(t, []int{1, 2}, set1.Keys()) + + set2 := NewOrderedSet[string, bool]() + set2.Put("true", true) + set2.Put("false", false) + assert.Equal(t, []bool{true, false}, set2.List()) + assert.Equal(t, []string{"true", "false"}, set2.Keys()) + }) + + t.Run("large number of elements", func(t *testing.T) { + set := NewOrderedSet[int, int]() + n := 10000 + + // Insert elements + for i := 0; i < n; i++ { + set.Put(i, i*2) + } + + assert.Equal(t, n, set.Size()) + + // Verify all elements + for i := 0; i < n; i++ { + val, exists := set.Get(i) + assert.True(t, exists) + assert.Equal(t, i*2, val) + } + + // Verify order + keys := set.Keys() + values := set.List() + for i := 0; i < n; i++ { + assert.Equal(t, i, keys[i]) + assert.Equal(t, i*2, values[i]) + } + }) + + t.Run("zero values", func(t *testing.T) { + set := NewOrderedSet[string, int]() + + // Put zero values + set.Put("zero", 0) + set.Put("empty", 0) + + assert.Equal(t, 2, set.Size()) + val, exists := set.Get("zero") + assert.True(t, exists) + assert.Zero(t, val) + }) +} diff --git a/utils/slices.go b/utils/slices.go index 6020c96452..8e178b8926 100644 --- a/utils/slices.go +++ b/utils/slices.go @@ -1,6 +1,9 @@ package utils -import "slices" +import ( + "reflect" + "slices" +) func Map[T1, T2 any](slice []T1, f func(T1) T2) []T2 { if slice == nil { @@ -34,3 +37,25 @@ func All[T any](slice []T, f func(T) bool) bool { func AnyOf[T comparable](e T, values ...T) bool { return slices.Contains(values, e) } + +// Unique returns a new slice with duplicates removed +func Unique[T comparable](slice []T) []T { + // do not support unique on pointer types, just return the slice as it is + if len(slice) > 0 { + elt := slice[0] + if reflect.TypeOf(elt).Kind() == reflect.Ptr { + return slice + } + } + + result := make([]T, 0, len(slice)) + seen := make(map[T]struct{}, len(slice)) + for _, e := range slice { + if _, ok := seen[e]; !ok { + result = append(result, e) + seen[e] = struct{}{} + } + } + + return result +} diff --git a/utils/slices_test.go b/utils/slices_test.go index 9ef6fcff66..c0f84319e9 100644 --- a/utils/slices_test.go +++ b/utils/slices_test.go @@ -73,3 +73,42 @@ func TestAnyOf(t *testing.T) { assert.False(t, AnyOf("9", "1", "2", "3", "4", "5", "6")) }) } + +func TestUnique(t *testing.T) { + t.Run("nil slice", func(t *testing.T) { + var input []int + actual := Unique(input) + assert.Nil(t, actual) + }) + + t.Run("empty slice", func(t *testing.T) { + input := []int{} + actual := Unique(input) + assert.Empty(t, actual) + }) + + t.Run("slice with no duplicates", func(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + actual := Unique(input) + assert.Equal(t, []int{1, 2, 3, 4, 5}, actual) + }) + + t.Run("slice with duplicates", func(t *testing.T) { + input := []int{1, 2, 2, 3, 3, 3, 4, 5, 5} + actual := Unique(input) + assert.Equal(t, []int{1, 2, 3, 4, 5}, actual) + }) + + t.Run("slice of strings with duplicates", func(t *testing.T) { + input := []string{"a", "b", "b", "c", "c", "c"} + actual := Unique(input) + assert.Equal(t, []string{"a", "b", "c"}, actual) + }) + + t.Run("slice of pointers returns unchanged", func(t *testing.T) { + a, b := new(int), new(int) + input := []*int{a, b, a, b} + actual := Unique(input) + assert.Equal(t, input, actual) + }) +} From 8660d14922ee033d8f0a467173087bef16ddda8d Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 16 Jan 2025 16:23:15 +0800 Subject: [PATCH 2/6] fix Unique() on nil slice --- utils/slices.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/utils/slices.go b/utils/slices.go index 8e178b8926..14eb84093a 100644 --- a/utils/slices.go +++ b/utils/slices.go @@ -40,12 +40,13 @@ func AnyOf[T comparable](e T, values ...T) bool { // Unique returns a new slice with duplicates removed func Unique[T comparable](slice []T) []T { + if len(slice) == 0 { + return slice + } + // do not support unique on pointer types, just return the slice as it is - if len(slice) > 0 { - elt := slice[0] - if reflect.TypeOf(elt).Kind() == reflect.Ptr { - return slice - } + if reflect.TypeOf(slice[0]).Kind() == reflect.Ptr { + return slice } result := make([]T, 0, len(slice)) From 7e5ff456e3b424841e09f3dd5f0e061a6d8360d4 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 17 Jan 2025 00:12:51 +0800 Subject: [PATCH 3/6] add `storage_root` to leaf data --- rpc/storage.go | 17 ++++++++++++----- rpc/storage_test.go | 7 +++++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/rpc/storage.go b/rpc/storage.go index 62d2c04efa..9f850b32cf 100644 --- a/rpc/storage.go +++ b/rpc/storage.go @@ -100,7 +100,7 @@ func (s *StorageProofResult) String() string { result.WriteString(" Leaves:\n") for i, leaf := range s.ContractsProof.LeavesData { if leaf != nil { - result.WriteString(fmt.Sprintf(" [%d] Nonce: %v, ClassHash: %v\n", i, leaf.Nonce, leaf.ClassHash)) + result.WriteString(fmt.Sprintf(" [%d] Nonce: %v, ClassHash: %v Root: %v\n", i, leaf.Nonce, leaf.ClassHash, leaf.StorageRoot)) } } } @@ -221,6 +221,11 @@ func getContractProof(tr *trie.Trie, state core.StateReader, contracts []felt.Fe return nil, err } + root, err := tr.Root() + if err != nil { + return nil, err + } + nonce, err := state.ContractNonce(&contract) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { // contract does not exist, skip getting leaf data @@ -235,8 +240,9 @@ func getContractProof(tr *trie.Trie, state core.StateReader, contracts []felt.Fe } contractLeavesData[i] = &LeafData{ - Nonce: nonce, - ClassHash: classHash, + Nonce: nonce, + ClassHash: classHash, + StorageRoot: root, } } @@ -342,8 +348,9 @@ type HashToNode struct { } type LeafData struct { - Nonce *felt.Felt `json:"nonce"` - ClassHash *felt.Felt `json:"class_hash"` + Nonce *felt.Felt `json:"nonce"` + ClassHash *felt.Felt `json:"class_hash"` + StorageRoot *felt.Felt `json:"storage_root"` } type ContractProof struct { diff --git a/rpc/storage_test.go b/rpc/storage_test.go index bb43ff1e31..842af8c3ac 100644 --- a/rpc/storage_test.go +++ b/rpc/storage_test.go @@ -372,10 +372,12 @@ func TestStorageProof_VerifyPathfinderResponse(t *testing.T) { { Nonce: utils.HexToFelt(t, "0x0"), ClassHash: utils.HexToFelt(t, "0x772164c9d6179a89e7f1167f099219f47d752304b16ed01f081b6e0b45c93c3"), + // TODO: get the storage root }, { Nonce: utils.HexToFelt(t, "0x0"), ClassHash: utils.HexToFelt(t, "0x78401746828463e2c3f92ebb261fc82f7d4d4c8d9a80a356c44580dab124cb0"), + // TODO: get the storage root }, }, Nodes: []*rpc.HashToNode{ @@ -652,8 +654,9 @@ func TestStorageProof_StorageRoots(t *testing.T) { ContractsProof: &rpc.ContractProof{ LeavesData: []*rpc.LeafData{ { - Nonce: utils.HexToFelt(t, "0x0"), - ClassHash: utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), + Nonce: utils.HexToFelt(t, "0x0"), + ClassHash: utils.HexToFelt(t, "0x10455c752b86932ce552f2b0fe81a880746649b9aee7e0d842bf3f52378f9f8"), + StorageRoot: utils.HexToFelt(t, "0x3ceee867d50b5926bb88c0ec7e0b9c20ae6b537e74aac44b8fcf6bb6da138d9"), }, }, Nodes: []*rpc.HashToNode{ From df00491c775c56a88a7ab9ba439c0575add70084 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 12:53:08 +0800 Subject: [PATCH 4/6] panic if unique takes in pointer --- utils/slices.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/utils/slices.go b/utils/slices.go index 14eb84093a..1c4a6937b4 100644 --- a/utils/slices.go +++ b/utils/slices.go @@ -38,7 +38,8 @@ func AnyOf[T comparable](e T, values ...T) bool { return slices.Contains(values, e) } -// Unique returns a new slice with duplicates removed +// Unique returns a new slice with duplicates removed. +// Panics if the slice contains pointer types. func Unique[T comparable](slice []T) []T { if len(slice) == 0 { return slice @@ -46,7 +47,7 @@ func Unique[T comparable](slice []T) []T { // do not support unique on pointer types, just return the slice as it is if reflect.TypeOf(slice[0]).Kind() == reflect.Ptr { - return slice + panic("Unique does not support pointer types") } result := make([]T, 0, len(slice)) From 6daf85ed830a43a446125bd27df3636f0a42a6dc Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 12:53:21 +0800 Subject: [PATCH 5/6] remove dups from storage keys --- rpc/storage.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/rpc/storage.go b/rpc/storage.go index 9f850b32cf..c2856cb521 100644 --- a/rpc/storage.go +++ b/rpc/storage.go @@ -165,6 +165,21 @@ func (h *Handler) StorageProof(id BlockID, classes = utils.Unique(classes) contracts = utils.Unique(contracts) + // Remove duplicates from the storage keys + mergedStorageKeys := make(map[felt.Felt][]felt.Felt) + for _, storageKey := range storageKeys { + if existing, ok := mergedStorageKeys[storageKey.Contract]; ok { + mergedStorageKeys[storageKey.Contract] = append(existing, storageKey.Keys...) + } else { + mergedStorageKeys[storageKey.Contract] = storageKey.Keys + } + } + + uniqueStorageKeys := make([]StorageKeys, 0, len(mergedStorageKeys)) + for contract, keys := range mergedStorageKeys { + uniqueStorageKeys = append(uniqueStorageKeys, StorageKeys{Contract: contract, Keys: utils.Unique(keys)}) + } + classProof, err := getClassProof(classTrie, classes) if err != nil { return nil, ErrInternal.CloneWithData(err) @@ -175,7 +190,7 @@ func (h *Handler) StorageProof(id BlockID, return nil, ErrInternal.CloneWithData(err) } - contractStorageProof, err := getContractStorageProof(state, storageKeys) + contractStorageProof, err := getContractStorageProof(state, uniqueStorageKeys) if err != nil { return nil, ErrInternal.CloneWithData(err) } @@ -261,8 +276,7 @@ func getContractStorageProof(state core.StateReader, storageKeys []StorageKeys) } contractStorageProof := trie.NewProofNodeSet() - keys := utils.Unique(storageKey.Keys) - for _, key := range keys { + for _, key := range storageKey.Keys { if err := contractStorageTrie.Prove(&key, contractStorageProof); err != nil { return nil, err } From e8adad87910e6ab75b10aaf83c8b758341c913e2 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 13:00:51 +0800 Subject: [PATCH 6/6] remove pointer slice --- utils/slices_test.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/utils/slices_test.go b/utils/slices_test.go index c0f84319e9..1a219566b6 100644 --- a/utils/slices_test.go +++ b/utils/slices_test.go @@ -104,11 +104,4 @@ func TestUnique(t *testing.T) { actual := Unique(input) assert.Equal(t, []string{"a", "b", "c"}, actual) }) - - t.Run("slice of pointers returns unchanged", func(t *testing.T) { - a, b := new(int), new(int) - input := []*int{a, b, a, b} - actual := Unique(input) - assert.Equal(t, input, actual) - }) }