diff --git a/jsonrpc/websocket.go b/jsonrpc/websocket.go
index 2c1e207a54..75e4bb77b6 100644
--- a/jsonrpc/websocket.go
+++ b/jsonrpc/websocket.go
@@ -5,21 +5,28 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync/atomic"
 	"time"
 
 	"github.com/NethermindEth/juno/utils"
 	"github.com/coder/websocket"
 )
 
-const closeReasonMaxBytes = 125
+const (
+	closeReasonMaxBytes = 125
+	maxConns            = 2048 // TODO: an arbitrary default number, should be revisited after monitoring
+)
 
 type Websocket struct {
 	rpc        *Server
 	log        utils.SimpleLogger
 	connParams *WebsocketConnParams
 	listener   NewRequestListener
+	shutdown   <-chan struct{}
 
-	shutdown <-chan struct{}
+	// Connection tracking
+	connCount atomic.Int32
+	maxConns  int32
 }
 
 func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger) *Websocket {
@@ -29,11 +36,18 @@ func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger)
 		connParams: DefaultWebsocketConnParams(),
 		listener:   &SelectiveListener{},
 		shutdown:   shutdown,
+		maxConns:   maxConns,
 	}
 	return ws
 }
 
+// WithMaxConnections sets the maximum number of concurrent websocket connections.
+func (ws *Websocket) WithMaxConnections(maxConns int32) *Websocket {
+	ws.maxConns = maxConns
+	return ws
+}
+
 // WithConnParams sanity checks and applies the provided params.
 func (ws *Websocket) WithConnParams(p *WebsocketConnParams) *Websocket {
 	ws.connParams = p
@@ -49,12 +63,23 @@ func (ws *Websocket) WithListener(listener NewRequestListener) *Websocket {
 // ServeHTTP processes an HTTP request and upgrades it to a websocket connection.
 // The connection's entire "lifetime" is spent in this function.
 func (ws *Websocket) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	// Check connection limit
+	if ws.connCount.Load() >= ws.maxConns {
+		ws.log.Warnw("Max websocket connections reached", "maxConns", ws.maxConns)
+		http.Error(w, "Too many connections", http.StatusServiceUnavailable)
+		return
+	}
+	ws.connCount.Add(1)
+	// Decrement the connection count on every exit path, including upgrade
+	// failures, so that failed upgrades do not leak connection slots.
+	defer ws.connCount.Add(-1)
+
 	conn, err := websocket.Accept(w, r, nil /* TODO: options */)
 	if err != nil {
 		ws.log.Errorw("Failed to upgrade connection", "err", err)
 		return
 	}
+
 	// TODO include connection information, such as the remote address, in the logs.
 	ctx, cancel := context.WithCancel(r.Context())
diff --git a/jsonrpc/websocket_test.go b/jsonrpc/websocket_test.go
index 4f60377c02..aaad4707d6 100644
--- a/jsonrpc/websocket_test.go
+++ b/jsonrpc/websocket_test.go
@@ -5,6 +5,7 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"testing"
+	"time"
 
 	"github.com/NethermindEth/juno/jsonrpc"
 	"github.com/NethermindEth/juno/utils"
@@ -93,3 +94,36 @@ func TestSendFromHandler(t *testing.T) {
 
 	require.NoError(t, conn.Close(websocket.StatusNormalClosure, ""))
 }
+
+func TestWebsocketConnectionLimit(t *testing.T) {
+	rpc := jsonrpc.NewServer(1, utils.NewNopZapLogger())
+	ws := jsonrpc.NewWebsocket(rpc, nil, utils.NewNopZapLogger()).WithMaxConnections(2)
+	httpSrv := httptest.NewServer(ws)
+	defer httpSrv.Close()
+
+	// First connection should succeed
+	conn1, resp1, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
+	require.NoError(t, err)
+	require.Equal(t, http.StatusSwitchingProtocols, resp1.StatusCode)
+	defer conn1.Close(websocket.StatusNormalClosure, "")
+
+	// Second connection should succeed
+	conn2, resp2, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
+	require.NoError(t, err)
+	require.Equal(t, http.StatusSwitchingProtocols, resp2.StatusCode)
+	defer conn2.Close(websocket.StatusNormalClosure, "")
+
+	// Third connection should fail with 503 Service Unavailable
+	_, resp3, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
+	require.Error(t, err)
+	require.Equal(t, http.StatusServiceUnavailable, resp3.StatusCode)
+
+	// Close one connection and try again - should succeed
+	require.NoError(t, conn1.Close(websocket.StatusNormalClosure, ""))
+	time.Sleep(10 * time.Millisecond) // Give the server time to clean up
+
+	conn4, resp4, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
+	require.NoError(t, err)
+	require.Equal(t, http.StatusSwitchingProtocols, resp4.StatusCode)
+	require.NoError(t, conn4.Close(websocket.StatusNormalClosure, ""))
+}
diff --git a/mocks/mock_synchronizer.go b/mocks/mock_synchronizer.go
index d04a733db0..970107be8c 100644
--- a/mocks/mock_synchronizer.go
+++ b/mocks/mock_synchronizer.go
@@ -128,18 +128,18 @@ func (mr *MockSyncReaderMockRecorder) SubscribeNewHeads() *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeNewHeads", reflect.TypeOf((*MockSyncReader)(nil).SubscribeNewHeads))
 }
 
-// SubscribePendingTxs mocks base method.
-func (m *MockSyncReader) SubscribePendingTxs() sync.PendingTxSubscription {
+// SubscribePending mocks base method.
+func (m *MockSyncReader) SubscribePending() sync.PendingSubscription {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "SubscribePendingTxs")
-	ret0, _ := ret[0].(sync.PendingTxSubscription)
+	ret := m.ctrl.Call(m, "SubscribePending")
+	ret0, _ := ret[0].(sync.PendingSubscription)
 	return ret0
 }
 
-// SubscribePendingTxs indicates an expected call of SubscribePendingTxs.
-func (mr *MockSyncReaderMockRecorder) SubscribePendingTxs() *gomock.Call {
+// SubscribePending indicates an expected call of SubscribePending.
+func (mr *MockSyncReaderMockRecorder) SubscribePending() *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribePendingTxs", reflect.TypeOf((*MockSyncReader)(nil).SubscribePendingTxs))
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribePending", reflect.TypeOf((*MockSyncReader)(nil).SubscribePending))
 }
 
 // SubscribeReorg mocks base method.
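Note: the connection gate in jsonrpc/websocket.go above reads the counter and then increments it in two separate atomic operations, so a burst of simultaneous upgrades can briefly exceed maxConns. A minimal sketch of a stricter variant using a compare-and-swap loop; the connLimiter type is hypothetical and not part of this change:

package main

import (
	"fmt"
	"sync/atomic"
)

type connLimiter struct {
	count atomic.Int32
	max   int32
}

// tryAcquire reserves a connection slot; callers must call release when done.
func (l *connLimiter) tryAcquire() bool {
	for {
		cur := l.count.Load()
		if cur >= l.max {
			return false
		}
		// CompareAndSwap retries if another goroutine took the slot first,
		// so the limit can never be overshot.
		if l.count.CompareAndSwap(cur, cur+1) {
			return true
		}
	}
}

func (l *connLimiter) release() { l.count.Add(-1) }

func main() {
	l := &connLimiter{max: 2}
	fmt.Println(l.tryAcquire(), l.tryAcquire(), l.tryAcquire()) // true true false
	l.release()
	fmt.Println(l.tryAcquire()) // true
}

The PR's Load-then-Add approach is simpler, and a small transient overshoot is usually acceptable for a soft limit like this one.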
diff --git a/rpc/block.go b/rpc/block.go
index 936f3f8a53..a5804c28b2 100644
--- a/rpc/block.go
+++ b/rpc/block.go
@@ -67,6 +67,22 @@ type BlockID struct {
 	Number  uint64
 }
 
+func (b *BlockID) IsLatest() bool {
+	return b.Latest
+}
+
+func (b *BlockID) IsPending() bool {
+	return b.Pending
+}
+
+func (b *BlockID) GetHash() *felt.Felt {
+	return b.Hash
+}
+
+func (b *BlockID) GetNumber() uint64 {
+	return b.Number
+}
+
 func (b *BlockID) UnmarshalJSON(data []byte) error {
 	if string(data) == `"latest"` {
 		b.Latest = true
diff --git a/rpc/handlers.go b/rpc/handlers.go
index a3db98a346..a4c7b45b36 100644
--- a/rpc/handlers.go
+++ b/rpc/handlers.go
@@ -70,7 +70,6 @@ var (
 	ErrInvalidSubscriptionID    = &jsonrpc.Error{Code: 66, Message: "Invalid subscription id"}
 	ErrTooManyAddressesInFilter = &jsonrpc.Error{Code: 67, Message: "Too many addresses in filter sender_address filter"}
 	ErrTooManyBlocksBack        = &jsonrpc.Error{Code: 68, Message: fmt.Sprintf("Cannot go back more than %v blocks", maxBlocksBack)}
-	ErrCallOnPending            = &jsonrpc.Error{Code: 69, Message: "This method does not support being called on the pending block"}
 )
 
@@ -93,11 +92,11 @@ type Handler struct {
 	vm  vm.VM
 	log utils.Logger
 
-	version    string
-	newHeads   *feed.Feed[*core.Header]
-	reorgs     *feed.Feed[*sync.ReorgBlockRange]
-	pendingTxs *feed.Feed[[]core.Transaction]
-	l1Heads    *feed.Feed[*core.L1Head]
+	version      string
+	newHeads     *feed.Feed[*core.Header]
+	reorgs       *feed.Feed[*sync.ReorgBlockRange]
+	pendingBlock *feed.Feed[*core.Block]
+	l1Heads      *feed.Feed[*core.L1Head]
 
 	idgen         func() uint64
 	subscriptions stdsync.Map // map[uint64]*subscription
@@ -135,11 +134,11 @@ func New(bcReader blockchain.Reader, syncReader sync.Reader, virtualMachine vm.VM,
 			}
 			return n
 		},
-		version:    version,
-		newHeads:   feed.New[*core.Header](),
-		reorgs:     feed.New[*sync.ReorgBlockRange](),
-		pendingTxs: feed.New[[]core.Transaction](),
-		l1Heads:    feed.New[*core.L1Head](),
+		version:      version,
+		newHeads:     feed.New[*core.Header](),
+		reorgs:       feed.New[*sync.ReorgBlockRange](),
+		pendingBlock: feed.New[*core.Block](),
+		l1Heads:      feed.New[*core.L1Head](),
 
 		blockTraceCache: lru.NewCache[traceCacheKey, []TracedBlockTransaction](traceCacheSize),
 		filterLimit:     math.MaxUint,
@@ -181,16 +180,17 @@ func (h *Handler) WithGateway(gatewayClient Gateway) *Handler {
 func (h *Handler) Run(ctx context.Context) error {
 	newHeadsSub := h.syncReader.SubscribeNewHeads().Subscription
 	reorgsSub := h.syncReader.SubscribeReorg().Subscription
-	pendingTxsSub := h.syncReader.SubscribePendingTxs().Subscription
 	l1HeadsSub := h.bcReader.SubscribeL1Head().Subscription
+	pendingBlock := h.syncReader.SubscribePending().Subscription
 
 	defer newHeadsSub.Unsubscribe()
 	defer reorgsSub.Unsubscribe()
-	defer pendingTxsSub.Unsubscribe()
 	defer l1HeadsSub.Unsubscribe()
+	defer pendingBlock.Unsubscribe()
+
 	feed.Tee(newHeadsSub, h.newHeads)
 	feed.Tee(reorgsSub, h.reorgs)
-	feed.Tee(pendingTxsSub, h.pendingTxs)
 	feed.Tee(l1HeadsSub, h.l1Heads)
+	feed.Tee(pendingBlock, h.pendingBlock)
 
 	<-ctx.Done()
 	h.subscriptions.Range(func(key, value any) bool {
diff --git a/rpc/helpers.go b/rpc/helpers.go
index 4408ef00dc..59f4728ba0 100644
--- a/rpc/helpers.go
+++ b/rpc/helpers.go
@@ -60,22 +60,23 @@ func (h *Handler) blockByID(id *BlockID) (*core.Block, *jsonrpc.Error) {
 	return block, nil
 }
 
-func (h *Handler) blockHeaderByID(id *BlockID) (*core.Header, *jsonrpc.Error) {
+func (h *Handler) blockHeaderByID(id BlockIdentifier) (*core.Header, *jsonrpc.Error) {
 	var header *core.Header
 	var err error
+
 	switch {
-	case id.Latest:
+	case id.IsLatest():
 		header, err = h.bcReader.HeadsHeader()
-	case id.Hash != nil:
-		header, err = h.bcReader.BlockHeaderByHash(id.Hash)
-	case id.Pending:
+	case id.GetHash() != nil:
+		header, err = h.bcReader.BlockHeaderByHash(id.GetHash())
+	case id.IsPending():
 		var pending *sync.Pending
 		pending, err = h.syncReader.Pending()
 		if err == nil {
 			header = pending.Block.Header
 		}
 	default:
-		header, err = h.bcReader.BlockHeaderByNumber(id.Number)
+		header, err = h.bcReader.BlockHeaderByNumber(id.GetNumber())
 	}
 
 	if err != nil {
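Note: the getters added to BlockID above exist so that BlockID and the new SubscriptionBlockID (introduced in rpc/subscriptions.go below) can both be passed to helpers such as blockHeaderByID through the shared BlockIdentifier interface. A simplified, self-contained sketch of that dispatch; the types are stubbed and GetHash/UnmarshalJSON are omitted for brevity:

package main

import "fmt"

type BlockIdentifier interface {
	IsLatest() bool
	IsPending() bool
	GetNumber() uint64
}

type BlockID struct {
	Latest, Pending bool
	Number          uint64
}

func (b *BlockID) IsLatest() bool    { return b.Latest }
func (b *BlockID) IsPending() bool   { return b.Pending }
func (b *BlockID) GetNumber() uint64 { return b.Number }

type SubscriptionBlockID struct {
	Latest bool
	Number uint64
}

func (b *SubscriptionBlockID) IsLatest() bool    { return b.Latest }
func (b *SubscriptionBlockID) IsPending() bool   { return false } // never pending, by construction
func (b *SubscriptionBlockID) GetNumber() uint64 { return b.Number }

// headerByID serves both ID types without knowing which one it received.
func headerByID(id BlockIdentifier) string {
	switch {
	case id.IsLatest():
		return "head header"
	case id.IsPending():
		return "pending header"
	default:
		return fmt.Sprintf("header #%d", id.GetNumber())
	}
}

func main() {
	fmt.Println(headerByID(&BlockID{Pending: true}))          // pending header
	fmt.Println(headerByID(&SubscriptionBlockID{Number: 42})) // header #42
}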
diff --git a/rpc/subscriptions.go b/rpc/subscriptions.go
index 8ee8425604..804deeed65 100644
--- a/rpc/subscriptions.go
+++ b/rpc/subscriptions.go
@@ -3,6 +3,7 @@ package rpc
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"time"
 
 	"github.com/NethermindEth/juno/blockchain"
@@ -10,7 +11,8 @@ import (
 	"github.com/NethermindEth/juno/core/felt"
 	"github.com/NethermindEth/juno/feed"
 	"github.com/NethermindEth/juno/jsonrpc"
-	"github.com/NethermindEth/juno/sync"
+	junoSync "github.com/NethermindEth/juno/sync"
+	"github.com/NethermindEth/juno/utils"
 	"github.com/sourcegraph/conc"
 )
 
@@ -24,15 +26,75 @@ var (
 	subscribeTxStatusTickerDuration = 5 * time.Second
 )
 
+var (
+	_ BlockIdentifier = (*SubscriptionBlockID)(nil)
+	_ BlockIdentifier = (*BlockID)(nil)
+)
+
 type SubscriptionResponse struct {
 	Version string `json:"jsonrpc"`
 	Method  string `json:"method"`
 	Params  any    `json:"params"`
 }
 
+type BlockIdentifier interface {
+	IsLatest() bool
+	IsPending() bool
+	GetHash() *felt.Felt
+	GetNumber() uint64
+	UnmarshalJSON(data []byte) error
+}
+
+// As per the spec, this is the same as BlockID, but without `pending`
+type SubscriptionBlockID struct {
+	Latest bool
+	Hash   *felt.Felt
+	Number uint64
+}
+
+func (b *SubscriptionBlockID) IsLatest() bool {
+	return b.Latest
+}
+
+func (b *SubscriptionBlockID) IsPending() bool {
+	return false // Subscription blocks can't be pending
+}
+
+func (b *SubscriptionBlockID) GetHash() *felt.Felt {
+	return b.Hash
+}
+
+func (b *SubscriptionBlockID) GetNumber() uint64 {
+	return b.Number
+}
+
+func (b *SubscriptionBlockID) UnmarshalJSON(data []byte) error {
+	if string(data) == `"latest"` {
+		b.Latest = true
+	} else {
+		jsonObject := make(map[string]json.RawMessage)
+		if err := json.Unmarshal(data, &jsonObject); err != nil {
+			return err
+		}
+		hash, ok := jsonObject["block_hash"]
+		if ok {
+			b.Hash = new(felt.Felt)
+			return json.Unmarshal(hash, b.Hash)
+		}
+
+		number, ok := jsonObject["block_number"]
+		if ok {
+			return json.Unmarshal(number, &b.Number)
+		}
+
+		return errors.New("cannot unmarshal block id")
+	}
+	return nil
+}
+
 // SubscribeEvents creates a WebSocket stream which will fire events for new Starknet events with applied filters
 func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys [][]felt.Felt,
-	blockID *BlockID,
+	blockID *SubscriptionBlockID,
 ) (*SubscriptionID, *jsonrpc.Error) {
 	w, ok := jsonrpc.ConnFromContext(ctx)
 	if !ok {
@@ -47,10 +109,6 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys
 		return nil, ErrTooManyKeysInFilter
 	}
 
-	if blockID != nil && blockID.Pending {
-		return nil, ErrCallOnPending
-	}
-
 	requestedHeader, headHeader, rpcErr := h.resolveBlockRange(blockID)
 	if rpcErr != nil {
 		return nil, rpcErr
@@ -66,23 +124,54 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys
 	headerSub := h.newHeads.Subscribe()
 	reorgSub := h.reorgs.Subscribe() // as per the spec, reorgs are also sent in the events subscription
+	pendingSub := h.pendingBlock.Subscribe()
 
 	sub.wg.Go(func() {
 		defer func() {
 			h.unsubscribe(sub, id)
 			headerSub.Unsubscribe()
 			reorgSub.Unsubscribe()
+			pendingSub.Unsubscribe()
 		}()
 
-		// The specification doesn't enforce ordering of events therefore events from new blocks can be sent before
+		// The specification doesn't enforce ordering of events, therefore events from new blocks can be sent before
 		// old blocks.
 		var wg conc.WaitGroup
 		wg.Go(func() {
+			// Set of events that have already been sent, used to skip duplicates when the same
+			// event arrives from both the pending block and the finalised head.
+			eventsPreviouslySent := make(map[blockchain.FilteredEvent]struct{})
+
 			for {
 				select {
 				case <-subscriptionCtx.Done():
 					return
 				case header := <-headerSub.Recv():
-					h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys)
+					// While the node is syncing there is no pending block, so events from each new head
+					// still need to be sent. Since it is not easy to tell whether the node is syncing,
+					// we send them unconditionally: once syncing is done, the latest header's events will
+					// already have been sent when the pending block was updated, so the resend is filtered
+					// out by eventsPreviouslySent, whose entries can then be safely pruned.
+					h.processEvents(subscriptionCtx, w, id, header.Number, header.Number, fromAddr, keys, eventsPreviouslySent)
+
+					block, err := h.bcReader.BlockByNumber(header.Number)
+					if err != nil {
+						h.log.Warnw("Error retrieving block", "block number", header.Number, "err", err)
+						return
+					}
+
+					for i, r := range block.Receipts {
+						for _, e := range r.Events {
+							fe := blockchain.FilteredEvent{
+								Event:           e,
+								BlockNumber:     header.Number,
+								BlockHash:       header.Hash,
+								TransactionHash: block.Transactions[i].Hash(),
+							}
+
+							delete(eventsPreviouslySent, fe)
+						}
+					}
+				case pending := <-pendingSub.Recv():
+					h.processEvents(subscriptionCtx, w, id, pending.Number, pending.Number, fromAddr, keys, eventsPreviouslySent)
 				}
 			}
 		})
@@ -92,7 +181,7 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys
 		})
 
 		wg.Go(func() {
-			h.processEvents(subscriptionCtx, w, id, requestedHeader.Number, headHeader.Number, fromAddr, keys)
+			h.processEvents(subscriptionCtx, w, id, requestedHeader.Number, headHeader.Number, fromAddr, keys, nil)
		})
 
 		wg.Wait()
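Note: the eventsPreviouslySent set above lets the subscription process both pending-block updates and finalised heads without delivering the same event twice. A toy sketch of the send-then-prune cycle; "event" stands in for blockchain.FilteredEvent, and since Go compares struct map keys field by field (pointer fields by identity), the real code relies on repeated reads producing equal field values:

package main

import "fmt"

type event struct {
	block uint64
	index int
}

// send delivers an event at most once, tracking what has been seen.
func send(e event, seen map[event]struct{}) {
	if _, ok := seen[e]; ok {
		return // already delivered from the pending block
	}
	seen[e] = struct{}{}
	fmt.Println("deliver", e)
}

func main() {
	seen := make(map[event]struct{})
	pending := []event{{block: 10, index: 0}, {block: 10, index: 1}}
	for _, e := range pending {
		send(e, seen) // delivered
	}
	// The same events arrive again when block 10 becomes the new head:
	for _, e := range pending {
		send(e, seen)   // skipped as a duplicate
		delete(seen, e) // prune once the block is finalised
	}
	fmt.Println("tracked:", len(seen)) // 0
}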
@@ -155,14 +244,14 @@ func (h *Handler) SubscribeTransactionStatus(ctx context.Context, txHash felt.Felt,
 	}
 	h.subscriptions.Store(id, sub)
 
-	l2HeadSub := h.newHeads.Subscribe()
+	pendingSub := h.pendingBlock.Subscribe()
 	l1HeadSub := h.l1Heads.Subscribe()
 	reorgSub := h.reorgs.Subscribe()
 
 	sub.wg.Go(func() {
 		defer func() {
 			h.unsubscribe(sub, id)
-			l2HeadSub.Unsubscribe()
+			pendingSub.Unsubscribe()
 			l1HeadSub.Unsubscribe()
 			reorgSub.Unsubscribe()
 		}()
@@ -187,11 +276,13 @@ func (h *Handler) SubscribeTransactionStatus(ctx context.Context, txHash felt.Felt,
 			select {
 			case <-subscriptionCtx.Done():
 				return
-			case <-l2HeadSub.Recv():
-				// A new block has been added to the DB, hence, check if transaction has reached l2 finality,
-				// if not, check feeder.
-				// We could use a separate timer to periodically check for the transaction status at feeder
-				// gateway, however, for the time being new l2 head update is sufficient.
+			case <-pendingSub.Recv():
+				// The pending block has been updated, so check whether the transaction has reached l2
+				// finality; if not, fall back to the feeder.
+				// TransactionStatus calls TransactionReceiptByHash, which already checks whether the
+				// pending block contains the transaction and, if it does, returns the appropriate status.
+				// Therefore, we don't need to explicitly look the transaction up in the pending block
+				// received from pendingSub.
 				if curStatus.Finality < TxnStatusAcceptedOnL2 {
 					prevStatus := curStatus
 					curStatus, rpcErr = h.TransactionStatus(subscriptionCtx, txHash)
@@ -246,7 +337,9 @@ func (h *Handler) SubscribeTransactionStatus(ctx context.Context, txHash felt.Felt,
 	return &SubscriptionID{ID: id}, nil
 }
 
-func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64, fromAddr *felt.Felt, keys [][]felt.Felt) {
+func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64, fromAddr *felt.Felt,
+	keys [][]felt.Felt, eventsPreviouslySent map[blockchain.FilteredEvent]struct{},
+) {
 	filter, err := h.bcReader.EventFilter(fromAddr, keys)
 	if err != nil {
 		h.log.Warnw("Error creating event filter", "err", err)
 		return
@@ -266,7 +359,7 @@ func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64,
 		return
 	}
 
-	err = sendEvents(ctx, w, filteredEvents, id)
+	err = sendEvents(ctx, w, filteredEvents, eventsPreviouslySent, id)
 	if err != nil {
 		h.log.Warnw("Error sending events", "err", err)
 		return
@@ -279,7 +372,7 @@ func (h *Handler) processEvents(ctx context.Context, w jsonrpc.Conn, id, from, to uint64,
 		return
 	}
 
-	err = sendEvents(ctx, w, filteredEvents, id)
+	err = sendEvents(ctx, w, filteredEvents, eventsPreviouslySent, id)
 	if err != nil {
 		h.log.Warnw("Error sending events", "err", err)
 		return
 	}
 }
 
@@ -287,12 +380,21 @@
-func sendEvents(ctx context.Context, w jsonrpc.Conn, events []*blockchain.FilteredEvent, id uint64) error {
+func sendEvents(ctx context.Context, w jsonrpc.Conn, events []*blockchain.FilteredEvent,
+	eventsPreviouslySent map[blockchain.FilteredEvent]struct{}, id uint64,
+) error {
 	for _, event := range events {
 		select {
 		case <-ctx.Done():
 			return ctx.Err()
 		default:
+			if eventsPreviouslySent != nil {
+				if _, ok := eventsPreviouslySent[*event]; ok {
+					continue
+				}
+				eventsPreviouslySent[*event] = struct{}{}
+			}
+
 			emittedEvent := &EmittedEvent{
 				BlockNumber: &event.BlockNumber, // This will always be filled as subscribeEvents cannot be called on a pending block
 				BlockHash:   event.BlockHash,
@@ -326,16 +428,12 @@ func sendEvents(ctx context.Context, w jsonrpc.Conn, events []*blockchain.FilteredEvent,
 }
 
 // SubscribeNewHeads creates a WebSocket stream which will fire events when a new block header is added.
-func (h *Handler) SubscribeNewHeads(ctx context.Context, blockID *BlockID) (*SubscriptionID, *jsonrpc.Error) {
+func (h *Handler) SubscribeNewHeads(ctx context.Context, blockID *SubscriptionBlockID) (*SubscriptionID, *jsonrpc.Error) {
 	w, ok := jsonrpc.ConnFromContext(ctx)
 	if !ok {
 		return nil, jsonrpc.Err(jsonrpc.MethodNotFound, nil)
 	}
 
-	if blockID != nil && blockID.Pending {
-		return nil, ErrCallOnPending
-	}
-
 	startHeader, latestHeader, rpcErr := h.resolveBlockRange(blockID)
 	if rpcErr != nil {
 		return nil, rpcErr
@@ -402,30 +500,28 @@ func (h *Handler) SubscribePendingTxs(ctx context.Context, getDetails *bool, senderAddr []felt.Felt,
 	}
 	h.subscriptions.Store(id, sub)
 
-	pendingTxsSub := h.pendingTxs.Subscribe()
+	pendingSub := h.pendingBlock.Subscribe()
 	sub.wg.Go(func() {
 		defer func() {
 			h.unsubscribe(sub, id)
-			pendingTxsSub.Unsubscribe()
+			pendingSub.Unsubscribe()
 		}()
 
-		h.processPendingTxs(subscriptionCtx, getDetails != nil && *getDetails, senderAddr, pendingTxsSub, w, id)
+		h.processPendingTxs(subscriptionCtx, getDetails != nil && *getDetails, senderAddr, pendingSub, w, id)
 	})
 
 	return &SubscriptionID{ID: id}, nil
 }
 
 func (h *Handler) processPendingTxs(ctx context.Context, getDetails bool, senderAddr []felt.Felt,
-	pendingTxsSub *feed.Subscription[[]core.Transaction],
-	w jsonrpc.Conn,
-	id uint64,
+	pendingSub *feed.Subscription[*core.Block], w jsonrpc.Conn, id uint64,
 ) {
 	for {
 		select {
 		case <-ctx.Done():
 			return
-		case pendingTxs := <-pendingTxsSub.Recv():
-			filteredTxs := h.filterTxs(pendingTxs, getDetails, senderAddr)
+		case pendingBlock := <-pendingSub.Recv():
			filteredTxs := h.filterTxs(pendingBlock.Transactions, getDetails, senderAddr)
 			if err := h.sendPendingTxs(w, filteredTxs, id); err != nil {
 				h.log.Warnw("Error sending pending transactions", "err", err)
 				return
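Note: resolveBlockRange (next hunk) now takes the BlockIdentifier interface, so a plain `id == nil` check would no longer be enough: an interface holding a typed nil pointer is itself non-nil. A sketch of that pitfall; isNil here is a hypothetical reflection-based stand-in for the utils.IsNil helper used below:

package main

import (
	"fmt"
	"reflect"
)

type BlockIdentifier interface{ IsLatest() bool }

type SubscriptionBlockID struct{ Latest bool }

func (b *SubscriptionBlockID) IsLatest() bool { return b.Latest }

// isNil reports whether v is nil or wraps a nil pointer.
func isNil(v any) bool {
	if v == nil {
		return true
	}
	rv := reflect.ValueOf(v)
	return rv.Kind() == reflect.Pointer && rv.IsNil()
}

func main() {
	var p *SubscriptionBlockID // nil pointer
	var id BlockIdentifier = p // interface carrying (type, nil)

	fmt.Println(id == nil) // false: the interface has a concrete type
	fmt.Println(isNil(id)) // true: the pointer inside is nil
	// Calling id.IsLatest() here would dereference nil and panic,
	// which is why the nil check must come first.
}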
@@ -510,17 +606,21 @@ func (h *Handler) sendPendingTxs(w jsonrpc.Conn, result any, id uint64) error {
 
 // resolveBlockRange returns the start and latest headers based on the blockID.
 // It will also do some sanity checks and return errors if the blockID is invalid.
-func (h *Handler) resolveBlockRange(blockID *BlockID) (*core.Header, *core.Header, *jsonrpc.Error) {
+func (h *Handler) resolveBlockRange(id BlockIdentifier) (*core.Header, *core.Header, *jsonrpc.Error) {
 	latestHeader, err := h.bcReader.HeadsHeader()
 	if err != nil {
 		return nil, nil, ErrInternal.CloneWithData(err.Error())
 	}
 
-	if blockID == nil || blockID.Latest {
+	if utils.IsNil(id) {
+		return latestHeader, latestHeader, nil
+	}
+
+	if id.IsLatest() {
 		return latestHeader, latestHeader, nil
 	}
 
-	startHeader, rpcErr := h.blockHeaderByID(blockID)
+	startHeader, rpcErr := h.blockHeaderByID(id)
 	if rpcErr != nil {
 		return nil, nil, rpcErr
 	}
@@ -596,7 +696,7 @@ func (h *Handler) sendHeader(w jsonrpc.Conn, header *core.Header, id uint64) error {
 	return err
 }
 
-func (h *Handler) processReorgs(ctx context.Context, reorgSub *feed.Subscription[*sync.ReorgBlockRange], w jsonrpc.Conn, id uint64) {
+func (h *Handler) processReorgs(ctx context.Context, reorgSub *feed.Subscription[*junoSync.ReorgBlockRange], w jsonrpc.Conn, id uint64) {
 	for {
 		select {
 		case <-ctx.Done():
@@ -617,7 +717,7 @@ type ReorgEvent struct {
 	EndBlockNum   uint64 `json:"ending_block_number"`
 }
 
-func (h *Handler) sendReorg(w jsonrpc.Conn, reorg *sync.ReorgBlockRange, id uint64) error {
+func (h *Handler) sendReorg(w jsonrpc.Conn, reorg *junoSync.ReorgBlockRange, id uint64) error {
 	resp, err := json.Marshal(jsonrpc.Request{
 		Version: "2.0",
 		Method:  "starknet_subscriptionReorg",
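Note: SubscriptionBlockID accepts the same wire format as BlockID except the `"pending"` tag. A usage sketch of the UnmarshalJSON logic added above; the felt.Felt hash field is replaced by a plain string so the example stays self-contained:

package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

type SubscriptionBlockID struct {
	Latest bool
	Hash   string
	Number uint64
}

// UnmarshalJSON mirrors the PR's three accepted forms: "latest",
// {"block_hash": ...}, or {"block_number": ...}.
func (b *SubscriptionBlockID) UnmarshalJSON(data []byte) error {
	if string(data) == `"latest"` {
		b.Latest = true
		return nil
	}
	obj := make(map[string]json.RawMessage)
	if err := json.Unmarshal(data, &obj); err != nil {
		return err
	}
	if hash, ok := obj["block_hash"]; ok {
		return json.Unmarshal(hash, &b.Hash)
	}
	if number, ok := obj["block_number"]; ok {
		return json.Unmarshal(number, &b.Number)
	}
	return errors.New("cannot unmarshal block id")
}

func main() {
	for _, in := range []string{`"latest"`, `{"block_hash":"0xabc"}`, `{"block_number":7}`, `"pending"`} {
		var id SubscriptionBlockID
		if err := json.Unmarshal([]byte(in), &id); err != nil {
			fmt.Println(in, "->", err) // "pending" is rejected, unlike BlockID
			continue
		}
		fmt.Printf("%s -> %+v\n", in, id)
	}
}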
diff --git a/rpc/subscriptions_test.go b/rpc/subscriptions_test.go
index ba15b4a777..eb9d05c5d0 100644
--- a/rpc/subscriptions_test.go
+++ b/rpc/subscriptions_test.go
@@ -72,30 +72,6 @@ func TestSubscribeEvents(t *testing.T) {
 		assert.Equal(t, ErrTooManyKeysInFilter, rpcErr)
 	})
 
-	t.Run("Return error if called on pending block", func(t *testing.T) {
-		mockCtrl := gomock.NewController(t)
-		t.Cleanup(mockCtrl.Finish)
-
-		mockChain := mocks.NewMockReader(mockCtrl)
-		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
-		handler := New(mockChain, mockSyncer, nil, "", log)
-
-		keys := make([][]felt.Felt, 1)
-		fromAddr := new(felt.Felt).SetBytes([]byte("from_address"))
-		blockID := &BlockID{Pending: true}
-
-		serverConn, _ := net.Pipe()
-		t.Cleanup(func() {
-			require.NoError(t, serverConn.Close())
-		})
-
-		subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
-
-		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, blockID)
-		assert.Zero(t, id)
-		assert.Equal(t, ErrCallOnPending, rpcErr)
-	})
-
 	t.Run("Return error if block is too far back", func(t *testing.T) {
 		mockCtrl := gomock.NewController(t)
 		t.Cleanup(mockCtrl.Finish)
@@ -106,7 +82,7 @@ func TestSubscribeEvents(t *testing.T) {
 		keys := make([][]felt.Felt, 1)
 		fromAddr := new(felt.Felt).SetBytes([]byte("from_address"))
-		blockID := &BlockID{Number: 0}
+		blockID := &SubscriptionBlockID{Number: 0}
 
 		serverConn, _ := net.Pipe()
 		t.Cleanup(func() {
@@ -176,6 +152,50 @@ func TestSubscribeEvents(t *testing.T) {
 		})
 	}
 
+	t.Run("Events from new blocks", func(t *testing.T) {
+		mockCtrl := gomock.NewController(t)
+		t.Cleanup(mockCtrl.Finish)
+
+		mockChain := mocks.NewMockReader(mockCtrl)
+		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
+		mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl)
+
+		handler := New(mockChain, mockSyncer, nil, "", log)
+		newHeadFeed := feed.New[*core.Header]()
+		handler.newHeads = newHeadFeed
+
+		mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil)
+		mockChain.EXPECT().EventFilter(gomock.Any(), gomock.Any()).Return(mockEventFilterer, nil).AnyTimes()
+		mockChain.EXPECT().BlockByNumber(gomock.Any()).Return(b1, nil).AnyTimes()
+		mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
+		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return([]*blockchain.FilteredEvent{filteredEvents[0]}, nil, nil)
+		mockEventFilterer.EXPECT().Close().AnyTimes()
+
+		serverConn, clientConn := net.Pipe()
+		t.Cleanup(func() {
+			require.NoError(t, serverConn.Close())
+			require.NoError(t, clientConn.Close())
+		})
+
+		ctx, cancel := context.WithCancel(context.Background())
+		subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, nil)
+		require.Nil(t, rpcErr)
+
+		newHeadFeed.Send(&core.Header{Number: b1.Number})
+
+		resp, err := marshalSubEventsResp(emittedEvents[0], id.ID)
+		require.NoError(t, err)
+
+		got := make([]byte, len(resp))
+		_, err = clientConn.Read(got)
+		require.NoError(t, err)
+		assert.Equal(t, string(resp), string(got))
+
+		cancel()
+		time.Sleep(100 * time.Millisecond)
+	})
+
 	t.Run("Events from old blocks", func(t *testing.T) {
 		mockCtrl := gomock.NewController(t)
 		t.Cleanup(mockCtrl.Finish)
@@ -201,7 +221,7 @@ func TestSubscribeEvents(t *testing.T) {
 		ctx, cancel := context.WithCancel(context.Background())
 		subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
-		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &BlockID{Number: b1.Number})
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &SubscriptionBlockID{Number: b1.Number})
 		require.Nil(t, rpcErr)
 
 		var marshalledResponses [][]byte
@@ -249,7 +269,7 @@ func TestSubscribeEvents(t *testing.T) {
 		ctx, cancel := context.WithCancel(context.Background())
 		subCtx := context.WithValue(ctx, jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
-		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &BlockID{Number: b1.Number})
+		id, rpcErr := handler.SubscribeEvents(subCtx, fromAddr, keys, &SubscriptionBlockID{Number: b1.Number})
 		require.Nil(t, rpcErr)
 
 		var marshalledResponses [][]byte
@@ -268,7 +288,7 @@ func TestSubscribeEvents(t *testing.T) {
 		cancel()
 	})
 
-	t.Run("Events from new blocks", func(t *testing.T) {
+	t.Run("Events from pending block without duplicates", func(t *testing.T) {
 		mockCtrl := gomock.NewController(t)
 		t.Cleanup(mockCtrl.Finish)
 
@@ -277,8 +297,8 @@ func TestSubscribeEvents(t *testing.T) {
 		mockEventFilterer := mocks.NewMockEventFilterer(mockCtrl)
 
 		handler := New(mockChain, mockSyncer, nil, "", log)
-		headerFeed := feed.New[*core.Header]()
-		handler.newHeads = headerFeed
+		pendingFeed := feed.New[*core.Block]()
+		handler.pendingBlock = pendingFeed
 
 		mockChain.EXPECT().HeadsHeader().Return(&core.Header{Number: b1.Number}, nil)
 		mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil)
@@ -306,12 +326,13 @@ func TestSubscribeEvents(t *testing.T) {
 		require.NoError(t, err)
 		assert.Equal(t, string(resp), string(got))
 
+		// Pending block events. Due to the use of mocks, events that were sent before are returned again.
 		mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil)
 		mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2)
 		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return([]*blockchain.FilteredEvent{filteredEvents[1]}, nil, nil)
 
-		headerFeed.Send(&core.Header{Number: b1.Number + 1})
+		pendingFeed.Send(&core.Block{Header: &core.Header{Number: b1.Number + 1}})
 
 		resp, err = marshalSubEventsResp(emittedEvents[1], id.ID)
 		require.NoError(t, err)
 
 		got = make([]byte, len(resp))
 		_, err = clientConn.Read(got)
 		require.NoError(t, err)
 		assert.Equal(t, string(resp), string(got))
 
+		mockChain.EXPECT().EventFilter(fromAddr, keys).Return(mockEventFilterer, nil)
+
+		mockEventFilterer.EXPECT().SetRangeEndBlockByNumber(gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2)
+		mockEventFilterer.EXPECT().Events(gomock.Any(), gomock.Any()).Return([]*blockchain.
+			FilteredEvent{filteredEvents[1], filteredEvents[0]}, nil, nil)
+
+		pendingFeed.Send(&core.Block{Header: &core.Header{Number: b1.Number + 1}})
+
+		resp, err = marshalSubEventsResp(emittedEvents[0], id.ID)
+		require.NoError(t, err)
+
+		got = make([]byte, len(resp))
+		_, err = clientConn.Read(got)
+		require.NoError(t, err)
+		assert.Equal(t, string(resp), string(got))
+
 		cancel()
 		time.Sleep(100 * time.Millisecond)
 	})
@@ -436,9 +473,9 @@ func TestSubscribeTxnStatus(t *testing.T) {
 		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
 		handler := New(mockChain, mockSyncer, nil, "", log)
 		handler.WithFeeder(client)
-		l2Feed := feed.New[*core.Header]()
+		pendingFeed := feed.New[*core.Block]()
 		l1Feed := feed.New[*core.L1Head]()
-		handler.newHeads = l2Feed
+		handler.pendingBlock = pendingFeed
 		handler.l1Heads = l1Feed
 
 		block, err := gw.BlockByNumber(context.Background(), 38748)
@@ -475,7 +512,7 @@ func TestSubscribeTxnStatus(t *testing.T) {
 		mockChain.EXPECT().Receipt(txHash).Return(block.Receipts[0], block.Hash, block.Number, nil)
 		mockChain.EXPECT().L1Head().Return(nil, db.ErrKeyNotFound)
 
-		l2Feed.Send(&core.Header{Number: block.Number + 1})
+		pendingFeed.Send(&core.Block{Header: &core.Header{Number: block.Number + 1}})
 
 		b, err = TxnStatusAcceptedOnL2.MarshalText()
 		require.NoError(t, err)
@@ -507,16 +544,16 @@ func TestSubscribeTxnStatus(t *testing.T) {
 }
 
 type fakeSyncer struct {
-	newHeads   *feed.Feed[*core.Header]
-	reorgs     *feed.Feed[*sync.ReorgBlockRange]
-	pendingTxs *feed.Feed[[]core.Transaction]
+	newHeads *feed.Feed[*core.Header]
+	reorgs   *feed.Feed[*sync.ReorgBlockRange]
+	pending  *feed.Feed[*core.Block]
 }
 
 func newFakeSyncer() *fakeSyncer {
 	return &fakeSyncer{
-		newHeads:   feed.New[*core.Header](),
-		reorgs:     feed.New[*sync.ReorgBlockRange](),
-		pendingTxs: feed.New[[]core.Transaction](),
+		newHeads: feed.New[*core.Header](),
+		reorgs:   feed.New[*sync.ReorgBlockRange](),
+		pending:  feed.New[*core.Block](),
 	}
 }
 
@@ -528,8 +565,8 @@ func (fs *fakeSyncer) SubscribeReorg() sync.ReorgSubscription {
 	return sync.ReorgSubscription{Subscription: fs.reorgs.Subscribe()}
 }
 
-func (fs *fakeSyncer) SubscribePendingTxs() sync.PendingTxSubscription {
-	return sync.PendingTxSubscription{Subscription: fs.pendingTxs.Subscribe()}
+func (fs *fakeSyncer) SubscribePending() sync.PendingSubscription {
+	return sync.PendingSubscription{Subscription: fs.pending.Subscribe()}
 }
 
 func (fs *fakeSyncer) StartingBlockNumber() (uint64, error) {
@@ -547,26 +584,6 @@ func (fs *fakeSyncer) PendingState() (core.StateReader, func() error, error) { r
 
 func TestSubscribeNewHeads(t *testing.T) {
 	log := utils.NewNopZapLogger()
 
-	t.Run("Return error if called on pending block", func(t *testing.T) {
-		mockCtrl := gomock.NewController(t)
-		t.Cleanup(mockCtrl.Finish)
-
-		mockChain := mocks.NewMockReader(mockCtrl)
-		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
-		handler := New(mockChain, mockSyncer, nil, "", log)
-
-		serverConn, _ := net.Pipe()
-		t.Cleanup(func() {
-			require.NoError(t, serverConn.Close())
-		})
-
-		subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
-
-		id, rpcErr := handler.SubscribeNewHeads(subCtx, &BlockID{Pending: true})
-		assert.Zero(t, id)
-		assert.Equal(t, ErrCallOnPending, rpcErr)
-	})
-
 	t.Run("Return error if block is too far back", func(t *testing.T) {
 		mockCtrl := gomock.NewController(t)
 		t.Cleanup(mockCtrl.Finish)
@@ -575,7 +592,7 @@ func TestSubscribeNewHeads(t *testing.T) {
 		mockSyncer := mocks.NewMockSyncReader(mockCtrl)
 		handler := New(mockChain, mockSyncer, nil, "", log)
 
-		blockID := &BlockID{Number: 0}
+		blockID := &SubscriptionBlockID{Number: 0}
 
 		serverConn, _ := net.Pipe()
 		t.Cleanup(func() {
@@ -856,12 +873,14 @@ func TestSubscribePendingTxs(t *testing.T) {
 		hash4 := new(felt.Felt).SetUint64(4)
 		hash5 := new(felt.Felt).SetUint64(5)
 
-		syncer.pendingTxs.Send([]core.Transaction{
-			&core.InvokeTransaction{TransactionHash: hash1, SenderAddress: addr1},
-			&core.DeclareTransaction{TransactionHash: hash2, SenderAddress: addr2},
-			&core.DeployTransaction{TransactionHash: hash3},
-			&core.DeployAccountTransaction{DeployTransaction: core.DeployTransaction{TransactionHash: hash4}},
-			&core.L1HandlerTransaction{TransactionHash: hash5},
+		syncer.pending.Send(&core.Block{
+			Transactions: []core.Transaction{
+				&core.InvokeTransaction{TransactionHash: hash1, SenderAddress: addr1},
+				&core.DeclareTransaction{TransactionHash: hash2, SenderAddress: addr2},
+				&core.DeployTransaction{TransactionHash: hash3},
+				&core.DeployAccountTransaction{DeployTransaction: core.DeployTransaction{TransactionHash: hash4}},
+				&core.L1HandlerTransaction{TransactionHash: hash5},
+			},
 		})
 
 		want := `{"jsonrpc":"2.0","method":"starknet_subscriptionPendingTransactions","params":{"result":["0x1","0x2","0x3","0x4","0x5"],"subscription_id":%d}}`
@@ -896,14 +915,16 @@ func TestSubscribePendingTxs(t *testing.T) {
 		hash7 := new(felt.Felt).SetUint64(7)
 		addr7 := new(felt.Felt).SetUint64(77)
 
-		syncer.pendingTxs.Send([]core.Transaction{
-			&core.InvokeTransaction{TransactionHash: hash1, SenderAddress: addr1},
-			&core.DeclareTransaction{TransactionHash: hash2, SenderAddress: addr2},
-			&core.DeployTransaction{TransactionHash: hash3},
-			&core.DeployAccountTransaction{DeployTransaction: core.DeployTransaction{TransactionHash: hash4}},
-			&core.L1HandlerTransaction{TransactionHash: hash5},
-			&core.InvokeTransaction{TransactionHash: hash6, SenderAddress: addr6},
-			&core.DeclareTransaction{TransactionHash: hash7, SenderAddress: addr7},
+		syncer.pending.Send(&core.Block{
+			Transactions: []core.Transaction{
+				&core.InvokeTransaction{TransactionHash: hash1, SenderAddress: addr1},
+				&core.DeclareTransaction{TransactionHash: hash2, SenderAddress: addr2},
+				&core.DeployTransaction{TransactionHash: hash3},
+				&core.DeployAccountTransaction{DeployTransaction: core.DeployTransaction{TransactionHash: hash4}},
+				&core.L1HandlerTransaction{TransactionHash: hash5},
+				&core.InvokeTransaction{TransactionHash: hash6, SenderAddress: addr6},
+				&core.DeclareTransaction{TransactionHash: hash7, SenderAddress: addr7},
+			},
 		})
 
 		want := `{"jsonrpc":"2.0","method":"starknet_subscriptionPendingTransactions","params":{"result":["0x1","0x2"],"subscription_id":%d}}`
@@ -922,21 +943,23 @@ func TestSubscribePendingTxs(t *testing.T) {
 		got := sendWsMessage(t, ctx, conn, subMsg)
 		require.Equal(t, subResp(id), got)
 
-		syncer.pendingTxs.Send([]core.Transaction{
-			&core.InvokeTransaction{
-				TransactionHash:       new(felt.Felt).SetUint64(1),
-				CallData:              []*felt.Felt{new(felt.Felt).SetUint64(2)},
-				TransactionSignature:  []*felt.Felt{new(felt.Felt).SetUint64(3)},
-				MaxFee:                new(felt.Felt).SetUint64(4),
-				ContractAddress:       new(felt.Felt).SetUint64(5),
-				Version:               new(core.TransactionVersion).SetUint64(3),
-				EntryPointSelector:    new(felt.Felt).SetUint64(6),
-				Nonce:                 new(felt.Felt).SetUint64(7),
-				SenderAddress:         new(felt.Felt).SetUint64(8),
-				ResourceBounds:        map[core.Resource]core.ResourceBounds{},
-				Tip:                   9,
-				PaymasterData:         []*felt.Felt{new(felt.Felt).SetUint64(10)},
-				AccountDeploymentData: []*felt.Felt{new(felt.Felt).SetUint64(11)},
+		syncer.pending.Send(&core.Block{
+			Transactions: []core.Transaction{
+				&core.InvokeTransaction{
+					TransactionHash:       new(felt.Felt).SetUint64(1),
+					CallData:              []*felt.Felt{new(felt.Felt).SetUint64(2)},
+					TransactionSignature:  []*felt.Felt{new(felt.Felt).SetUint64(3)},
+					MaxFee:                new(felt.Felt).SetUint64(4),
+					ContractAddress:       new(felt.Felt).SetUint64(5),
+					Version:               new(core.TransactionVersion).SetUint64(3),
+					EntryPointSelector:    new(felt.Felt).SetUint64(6),
+					Nonce:                 new(felt.Felt).SetUint64(7),
+					SenderAddress:         new(felt.Felt).SetUint64(8),
+					ResourceBounds:        map[core.Resource]core.ResourceBounds{},
+					Tip:                   9,
+					PaymasterData:         []*felt.Felt{new(felt.Felt).SetUint64(10)},
+					AccountDeploymentData: []*felt.Felt{new(felt.Felt).SetUint64(11)},
+				},
 			},
 		})
diff --git a/sync/sync.go b/sync/sync.go
index aa2ddab6b0..9ec490d002 100644
--- a/sync/sync.go
+++ b/sync/sync.go
@@ -46,6 +46,10 @@ type PendingTxSubscription struct {
 	*feed.Subscription[[]core.Transaction]
 }
 
+type PendingSubscription struct {
+	*feed.Subscription[*core.Block]
+}
+
 // ReorgBlockRange represents data about reorganised blocks, starting and ending block number and hash
 type ReorgBlockRange struct {
 	// StartBlockHash is the hash of the first known block of the orphaned chain
@@ -66,7 +70,7 @@ type Reader interface {
 	HighestBlockHeader() *core.Header
 	SubscribeNewHeads() HeaderSubscription
 	SubscribeReorg() ReorgSubscription
-	SubscribePendingTxs() PendingTxSubscription
+	SubscribePending() PendingSubscription
 
 	Pending() (*Pending, error)
 	PendingBlock() *core.Block
@@ -92,8 +96,8 @@ func (n *NoopSynchronizer) SubscribeReorg() ReorgSubscription {
 	return ReorgSubscription{feed.New[*ReorgBlockRange]().Subscribe()}
 }
 
-func (n *NoopSynchronizer) SubscribePendingTxs() PendingTxSubscription {
-	return PendingTxSubscription{feed.New[[]core.Transaction]().Subscribe()}
+func (n *NoopSynchronizer) SubscribePending() PendingSubscription {
+	return PendingSubscription{feed.New[*core.Block]().Subscribe()}
 }
 
 func (n *NoopSynchronizer) PendingBlock() *core.Block {
@@ -118,7 +122,7 @@ type Synchronizer struct {
 	highestBlockHeader atomic.Pointer[core.Header]
 	newHeads           *feed.Feed[*core.Header]
 	reorgFeed          *feed.Feed[*ReorgBlockRange]
-	pendingTxsFeed     *feed.Feed[[]core.Transaction]
+	pendingFeed        *feed.Feed[*core.Block]
 
 	log      utils.SimpleLogger
 	listener EventListener
@@ -141,7 +145,7 @@ func New(bc *blockchain.Blockchain, starkNetData starknetdata.StarknetData, log
 		log:                 log,
 		newHeads:            feed.New[*core.Header](),
 		reorgFeed:           feed.New[*ReorgBlockRange](),
-		pendingTxsFeed:      feed.New[[]core.Transaction](),
+		pendingFeed:         feed.New[*core.Block](),
 		pendingPollInterval: pendingPollInterval,
 		listener:            &SelectiveListener{},
 		readOnlyBlockchain:  readOnlyBlockchain,
@@ -593,8 +597,8 @@ func (s *Synchronizer) SubscribeReorg() ReorgSubscription {
 	return ReorgSubscription{s.reorgFeed.Subscribe()}
 }
 
-func (s *Synchronizer) SubscribePendingTxs() PendingTxSubscription {
-	return PendingTxSubscription{s.pendingTxsFeed.Subscribe()}
+func (s *Synchronizer) SubscribePending() PendingSubscription {
+	return PendingSubscription{s.pendingFeed.Subscribe()}
 }
 
 // StorePending stores a pending block given that it is for the next height
@@ -626,8 +630,7 @@ func (s *Synchronizer) StorePending(p *Pending) error {
 	}
 
 	s.pending.Store(p)
-	// send the pending transactions to the feed
-	s.pendingTxsFeed.Send(p.Block.Transactions)
+	s.pendingFeed.Send(p.Block)
 
 	return nil
 }
diff --git a/sync/sync_test.go b/sync/sync_test.go
index 2b6d514e88..4e2e960ba8 100644
--- a/sync/sync_test.go
+++ b/sync/sync_test.go
@@ -286,7 +286,7 @@ func TestSubscribeNewHeads(t *testing.T) {
 	sub.Unsubscribe()
 }
 
-func TestSubscribePendingTxs(t *testing.T) {
+func TestSubscribePending(t *testing.T) {
 	t.Parallel()
 
 	client := feeder.NewTestClient(t, &utils.Mainnet)
@@ -298,15 +298,15 @@ func TestSubscribePending(t *testing.T) {
 	synchronizer := sync.New(bc, gw, log, time.Millisecond*100, false, testDB)
 	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
 
-	sub := synchronizer.SubscribePendingTxs()
+	sub := synchronizer.SubscribePending()
 
 	require.NoError(t, synchronizer.Run(ctx))
 	cancel()
 
 	pending, err := synchronizer.Pending()
 	require.NoError(t, err)
-	pendingTxs, ok := <-sub.Recv()
+	pendingBlock, ok := <-sub.Recv()
 	require.True(t, ok)
-	require.Equal(t, pending.Block.Transactions, pendingTxs)
+	require.Equal(t, pending.Block, pendingBlock)
 
 	sub.Unsubscribe()
}
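Note: the synchronizer now publishes the entire pending block on a single feed, and each subscriber extracts what it needs; the old feed carried only the transaction slice. A minimal channel-based sketch of that publish/subscribe flow; the Feed type here is a simplified stand-in for juno's feed package, not its actual API:

package main

import "fmt"

type Block struct {
	Number       uint64
	Transactions []string
}

type Feed struct{ subs []chan *Block }

func (f *Feed) Subscribe() chan *Block {
	ch := make(chan *Block, 1)
	f.subs = append(f.subs, ch)
	return ch
}

// Send fans the block out to every subscriber.
func (f *Feed) Send(b *Block) {
	for _, ch := range f.subs {
		ch <- b
	}
}

func main() {
	f := &Feed{}
	sub := f.Subscribe()

	// StorePending equivalent: publish the whole block once.
	f.Send(&Block{Number: 100, Transactions: []string{"0x1", "0x2"}})

	b := <-sub
	// One feed now serves both the pending-tx subscription (b.Transactions)
	// and consumers that need header fields such as b.Number.
	fmt.Println(b.Number, b.Transactions)
}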