Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Web socket updates #2409

Merged
merged 16 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions jsonrpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,28 @@ import (
"io"
"net/http"
"strings"
"sync/atomic"
"time"

"github.com/NethermindEth/juno/utils"
"github.com/coder/websocket"
)

const closeReasonMaxBytes = 125
const (
closeReasonMaxBytes = 125
maxConns = 2048 // TODO: an arbitrary default number, should be revisited after monitoring
)

type Websocket struct {
rpc *Server
log utils.SimpleLogger
connParams *WebsocketConnParams
listener NewRequestListener
shutdown <-chan struct{}

shutdown <-chan struct{}
// Add connection tracking
connCount atomic.Int32
maxConns int32
}

func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger) *Websocket {
Expand All @@ -29,11 +36,18 @@ func NewWebsocket(rpc *Server, shutdown <-chan struct{}, log utils.SimpleLogger)
connParams: DefaultWebsocketConnParams(),
listener: &SelectiveListener{},
shutdown: shutdown,
maxConns: maxConns,
}

return ws
}

// WithMaxConnections sets the maximum number of concurrent websocket connections
func (ws *Websocket) WithMaxConnections(maxConns int32) *Websocket {
ws.maxConns = maxConns
return ws
}

// WithConnParams sanity checks and applies the provided params.
func (ws *Websocket) WithConnParams(p *WebsocketConnParams) *Websocket {
ws.connParams = p
Expand All @@ -49,12 +63,23 @@ func (ws *Websocket) WithListener(listener NewRequestListener) *Websocket {
// ServeHTTP processes an HTTP request and upgrades it to a websocket connection.
// The connection's entire "lifetime" is spent in this function.
func (ws *Websocket) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Check connection limit
if ws.connCount.Load() >= ws.maxConns {
ws.log.Warnw("Max websocket connections reached", "maxConns", ws.maxConns)
http.Error(w, "Too many connections", http.StatusServiceUnavailable)
return
}
ws.connCount.Add(1)

conn, err := websocket.Accept(w, r, nil /* TODO: options */)
if err != nil {
ws.log.Errorw("Failed to upgrade connection", "err", err)
return
}

// Ensure we decrease the connection count when the connection closes
defer ws.connCount.Add(-1)

// TODO include connection information, such as the remote address, in the logs.

ctx, cancel := context.WithCancel(r.Context())
Expand Down
34 changes: 34 additions & 0 deletions jsonrpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/NethermindEth/juno/jsonrpc"
"github.com/NethermindEth/juno/utils"
Expand Down Expand Up @@ -93,3 +94,36 @@ func TestSendFromHandler(t *testing.T) {

require.NoError(t, conn.Close(websocket.StatusNormalClosure, ""))
}

func TestWebsocketConnectionLimit(t *testing.T) {
rpc := jsonrpc.NewServer(1, utils.NewNopZapLogger())
ws := jsonrpc.NewWebsocket(rpc, nil, utils.NewNopZapLogger()).WithMaxConnections(2)
httpSrv := httptest.NewServer(ws)
defer httpSrv.Close()

// First connection should succeed
conn1, resp1, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp1.StatusCode)
defer conn1.Close(websocket.StatusNormalClosure, "")

// Second connection should succeed
conn2, resp2, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp2.StatusCode)
defer conn2.Close(websocket.StatusNormalClosure, "")

// Third connection should fail with 503 Service Unavailable
_, resp3, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
require.Error(t, err)
require.Equal(t, http.StatusServiceUnavailable, resp3.StatusCode)

// Close one connection and try again - should succeed
require.NoError(t, conn1.Close(websocket.StatusNormalClosure, ""))
time.Sleep(10 * time.Millisecond) // Give the server time to clean up

conn4, resp4, err := websocket.Dial(context.Background(), httpSrv.URL, nil) //nolint:bodyclose
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp4.StatusCode)
require.NoError(t, conn4.Close(websocket.StatusNormalClosure, ""))
}
14 changes: 7 additions & 7 deletions mocks/mock_synchronizer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions rpc/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ type BlockID struct {
Number uint64
}

func (b *BlockID) IsLatest() bool {
return b.Latest
}

func (b *BlockID) IsPending() bool {
return b.Pending
}

func (b *BlockID) GetHash() *felt.Felt {
return b.Hash
}

func (b *BlockID) GetNumber() uint64 {
return b.Number
}
AnkushinDaniil marked this conversation as resolved.
Show resolved Hide resolved

func (b *BlockID) UnmarshalJSON(data []byte) error {
if string(data) == `"latest"` {
b.Latest = true
Expand Down
28 changes: 14 additions & 14 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ var (
ErrInvalidSubscriptionID = &jsonrpc.Error{Code: 66, Message: "Invalid subscription id"}
ErrTooManyAddressesInFilter = &jsonrpc.Error{Code: 67, Message: "Too many addresses in filter sender_address filter"}
ErrTooManyBlocksBack = &jsonrpc.Error{Code: 68, Message: fmt.Sprintf("Cannot go back more than %v blocks", maxBlocksBack)}
ErrCallOnPending = &jsonrpc.Error{Code: 69, Message: "This method does not support being called on the pending block"}
)

const (
Expand All @@ -93,11 +92,11 @@ type Handler struct {
vm vm.VM
log utils.Logger

version string
newHeads *feed.Feed[*core.Header]
reorgs *feed.Feed[*sync.ReorgBlockRange]
pendingTxs *feed.Feed[[]core.Transaction]
l1Heads *feed.Feed[*core.L1Head]
version string
newHeads *feed.Feed[*core.Header]
reorgs *feed.Feed[*sync.ReorgBlockRange]
pendingBlock *feed.Feed[*core.Block]
l1Heads *feed.Feed[*core.L1Head]

idgen func() uint64
subscriptions stdsync.Map // map[uint64]*subscription
Expand Down Expand Up @@ -135,11 +134,11 @@ func New(bcReader blockchain.Reader, syncReader sync.Reader, virtualMachine vm.V
}
return n
},
version: version,
newHeads: feed.New[*core.Header](),
reorgs: feed.New[*sync.ReorgBlockRange](),
pendingTxs: feed.New[[]core.Transaction](),
l1Heads: feed.New[*core.L1Head](),
version: version,
newHeads: feed.New[*core.Header](),
reorgs: feed.New[*sync.ReorgBlockRange](),
pendingBlock: feed.New[*core.Block](),
l1Heads: feed.New[*core.L1Head](),

blockTraceCache: lru.NewCache[traceCacheKey, []TracedBlockTransaction](traceCacheSize),
filterLimit: math.MaxUint,
Expand Down Expand Up @@ -181,16 +180,17 @@ func (h *Handler) WithGateway(gatewayClient Gateway) *Handler {
func (h *Handler) Run(ctx context.Context) error {
newHeadsSub := h.syncReader.SubscribeNewHeads().Subscription
reorgsSub := h.syncReader.SubscribeReorg().Subscription
pendingTxsSub := h.syncReader.SubscribePendingTxs().Subscription
l1HeadsSub := h.bcReader.SubscribeL1Head().Subscription
pendingBlock := h.syncReader.SubscribePending().Subscription
defer newHeadsSub.Unsubscribe()
defer reorgsSub.Unsubscribe()
defer pendingTxsSub.Unsubscribe()
defer l1HeadsSub.Unsubscribe()
defer pendingBlock.Unsubscribe()

feed.Tee(newHeadsSub, h.newHeads)
feed.Tee(reorgsSub, h.reorgs)
feed.Tee(pendingTxsSub, h.pendingTxs)
feed.Tee(l1HeadsSub, h.l1Heads)
feed.Tee(pendingBlock, h.pendingBlock)

<-ctx.Done()
h.subscriptions.Range(func(key, value any) bool {
Expand Down
13 changes: 7 additions & 6 deletions rpc/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,23 @@ func (h *Handler) blockByID(id *BlockID) (*core.Block, *jsonrpc.Error) {
return block, nil
}

func (h *Handler) blockHeaderByID(id *BlockID) (*core.Header, *jsonrpc.Error) {
func (h *Handler) blockHeaderByID(id BlockIdentifier) (*core.Header, *jsonrpc.Error) {
weiihann marked this conversation as resolved.
Show resolved Hide resolved
var header *core.Header
var err error

switch {
case id.Latest:
case id.IsLatest():
header, err = h.bcReader.HeadsHeader()
case id.Hash != nil:
header, err = h.bcReader.BlockHeaderByHash(id.Hash)
case id.Pending:
case id.GetHash() != nil:
header, err = h.bcReader.BlockHeaderByHash(id.GetHash())
case id.IsPending():
var pending *sync.Pending
pending, err = h.syncReader.Pending()
if err == nil {
header = pending.Block.Header
}
default:
header, err = h.bcReader.BlockHeaderByNumber(id.Number)
header, err = h.bcReader.BlockHeaderByNumber(id.GetNumber())
}

if err != nil {
Expand Down
Loading