diff --git a/l1/eth_subscriber.go b/l1/eth_subscriber.go index 8ee42b1e46..38e7d377db 100644 --- a/l1/eth_subscriber.go +++ b/l1/eth_subscriber.go @@ -2,15 +2,15 @@ package l1 import ( "context" + "encoding/json" "fmt" "math/big" - "strconv" - "strings" "time" "github.com/NethermindEth/juno/l1/contract" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rpc" @@ -54,28 +54,21 @@ func (s *EthSubscriber) ChainID(ctx context.Context) (*big.Int, error) { } func (s *EthSubscriber) FinalisedHeight(ctx context.Context) (uint64, error) { - finalisedBlock := make(map[string]any, 0) - if err := s.client.CallContext(ctx, &finalisedBlock, "eth_getBlockByNumber", "finalized", false); err != nil { //nolint:misspell + var raw json.RawMessage + if err := s.client.CallContext(ctx, &raw, "eth_getBlockByNumber", "finalized", false); err != nil { //nolint:misspell return 0, fmt.Errorf("get finalised Ethereum block: %w", err) } - number, ok := finalisedBlock["number"] - if !ok { - return 0, fmt.Errorf("number field not present in Ethereum block") + var head *types.Header + if err := json.Unmarshal(raw, &head); err != nil { + return 0, err } - numberString, ok := number.(string) - if !ok { - return 0, fmt.Errorf("block number is not a string: %v", number) + if head == nil { + return 0, fmt.Errorf("finalised block not found") } - numberString = strings.TrimPrefix(numberString, "0x") - numberUint, err := strconv.ParseUint(numberString, 16, 64) - if err != nil { - return 0, fmt.Errorf("parse block number: %s", numberString) - } - - return numberUint, nil + return head.Number.Uint64(), nil } func (s *EthSubscriber) Close() { diff --git a/l1/l1.go b/l1/l1.go index 1cc169e4f8..abad317057 100644 --- a/l1/l1.go +++ b/l1/l1.go @@ -178,6 +178,7 @@ func (c *Client) finalisedHeight(ctx context.Context) uint64 { return finalisedHeight } c.log.Debugw("Failed to retrieve L1 finalised height, retrying...", "error", err) + time.Sleep(c.resubscribeDelay) } } } diff --git a/l1/l1_test.go b/l1/l1_test.go index 84a4d208ea..44174b7d3c 100644 --- a/l1/l1_test.go +++ b/l1/l1_test.go @@ -4,6 +4,8 @@ import ( "context" "errors" "math/big" + "net" + "net/http" "testing" "time" @@ -15,6 +17,9 @@ import ( "github.com/NethermindEth/juno/l1/contract" "github.com/NethermindEth/juno/mocks" "github.com/NethermindEth/juno/utils" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rpc" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -154,3 +159,101 @@ func TestEventListener(t *testing.T) { StateRoot: new(felt.Felt), }, got) } + +func newTestL1Client(service service) *rpc.Server { + server := rpc.NewServer() + if err := server.RegisterName("eth", service); err != nil { + panic(err) + } + return server +} + +type service interface { + GetBlockByNumber(ctx context.Context, number string, fullTx bool) (interface{}, error) +} + +type testService struct{} + +func (testService) GetBlockByNumber(ctx context.Context, number string, fullTx bool) (interface{}, error) { + blockHeight := big.NewInt(100) + return types.Header{ + ParentHash: common.Hash{}, + UncleHash: common.Hash{}, + Root: common.Hash{}, + TxHash: common.Hash{}, + ReceiptHash: common.Hash{}, + Bloom: types.Bloom{}, + Difficulty: big.NewInt(0), + Number: blockHeight, + GasLimit: 0, + GasUsed: 0, + Time: 0, + Extra: []byte{}, + }, nil +} + +type testEmptyService struct{} + +func (testEmptyService) GetBlockByNumber(ctx context.Context, number string, fullTx bool) (interface{}, error) { + return nil, nil +} + +type testFaultyService struct{} + +func (testFaultyService) GetBlockByNumber(ctx context.Context, number string, fullTx bool) (interface{}, error) { + return uint(0), nil +} + +func TestEthSubscriber_FinalisedHeight(t *testing.T) { + tests := map[string]struct { + service service + expectedHeight uint64 + expectedError bool + }{ + "testService": { + service: testService{}, + expectedHeight: 100, + expectedError: false, + }, + "testEmptyService": { + service: testEmptyService{}, + expectedHeight: 0, + expectedError: true, + }, + "testFaultyService": { + service: testFaultyService{}, + expectedHeight: 0, + expectedError: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + startServer := func(addr string, service service) (*rpc.Server, net.Listener) { + srv := newTestL1Client(service) + l, err := net.Listen("tcp", addr) + if err != nil { + t.Fatal("can't listen:", err) + } + go func() { + _ = http.Serve(l, srv.WebsocketHandler([]string{"*"})) + }() + return srv, l + } + + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + defer cancel() + + server, listener := startServer("127.0.0.1:0", test.service) + defer server.Stop() + + subscriber, err := l1.NewEthSubscriber("ws://"+listener.Addr().String(), common.Address{}) + require.NoError(t, err) + defer subscriber.Close() + + height, err := subscriber.FinalisedHeight(ctx) + require.Equal(t, test.expectedHeight, height) + require.Equal(t, test.expectedError, err != nil) + }) + } +}