Skip to content

Commit

Permalink
Show: Just for show. This will fail test and this commit will be over…
Browse files Browse the repository at this point in the history
…written in next push.
  • Loading branch information
obasekiosa committed Aug 6, 2024
1 parent 657fa53 commit 1f8b4f2
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 100 deletions.
6 changes: 5 additions & 1 deletion jsonrpc/http.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jsonrpc

import (
"maps"
"net/http"

"github.com/NethermindEth/juno/utils"
Expand Down Expand Up @@ -46,8 +47,11 @@ func (h *HTTP) ServeHTTP(writer http.ResponseWriter, req *http.Request) {

req.Body = http.MaxBytesReader(writer, req.Body, MaxRequestBodySize)
h.listener.OnNewRequest("any")
resp, err := h.rpc.HandleReader(req.Context(), req.Body)
resp, header, err := h.rpc.HandleReader(req.Context(), req.Body)

writer.Header().Set("Content-Type", "application/json")
maps.Copy(writer.Header(), header) // overwrites duplicate headers

if err != nil {
h.log.Errorw("Handler failure", "err", err)
writer.WriteHeader(http.StatusInternalServerError)
Expand Down
63 changes: 40 additions & 23 deletions jsonrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"net/http"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -205,6 +206,7 @@ type connection struct {
w io.Writer
activated <-chan struct{}

//todo: guard against this in the code! don't depend on devs finding out about this!
// initialErr is not thread-safe. It must be set to its final value before the connection is activated.
initialErr error
}
Expand Down Expand Up @@ -255,7 +257,8 @@ func (s *Server) HandleReadWriter(ctx context.Context, rw io.ReadWriter) error {
activated: activated,
}
msgCtx := context.WithValue(ctx, ConnKey{}, conn)
resp, err := s.HandleReader(msgCtx, rw)
// header is unnecessary for read-writer(websocket)
resp, _, err := s.HandleReader(msgCtx, rw)
if err != nil {
conn.initialErr = err
return err
Expand All @@ -272,27 +275,31 @@ func (s *Server) HandleReadWriter(ctx context.Context, rw io.ReadWriter) error {
// HandleReader processes a request to the server
// It returns the response in a byte array, only returns an
// error if it can not create the response byte array
func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, error) {
func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, http.Header, error) {
bufferedReader := bufio.NewReaderSize(reader, bufferSize)
requestIsBatch := isBatch(bufferedReader)
res := &response{
Version: "2.0",
}

var header http.Header

dec := json.NewDecoder(bufferedReader)
dec.UseNumber()

if !requestIsBatch {
req := new(Request)
if jsonErr := dec.Decode(req); jsonErr != nil {
res.Error = Err(InvalidJSON, jsonErr.Error())
} else if resObject, handleErr := s.handleRequest(ctx, req); handleErr != nil {
} else if resObject, httpHeader, handleErr := s.handleRequest(ctx, req); handleErr != nil {
if !errors.Is(handleErr, ErrInvalidID) {
res.ID = req.ID
}
res.Error = Err(InvalidRequest, handleErr.Error())
header = httpHeader
} else {
res = resObject
header = httpHeader
}
} else {
var batchReq []json.RawMessage
Expand All @@ -306,24 +313,24 @@ func (s *Server) HandleReader(ctx context.Context, reader io.Reader) ([]byte, er
}
}

if res == nil {
return nil, nil
}
return json.Marshal(res)
result, err := json.Marshal(res)
return result, header, err
}

func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMessage) ([]byte, error) {
func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMessage) ([]byte, http.Header, error) {
var (
responses []json.RawMessage
mutex sync.Mutex
headers []http.Header
)

addResponse := func(response any) {
addResponse := func(response any, header http.Header) {
if responseJSON, err := json.Marshal(response); err != nil {
s.log.Errorw("failed to marshal response", "err", err)
} else {
mutex.Lock()
responses = append(responses, responseJSON)
headers = append(headers, header)
mutex.Unlock()
}
}
Expand All @@ -341,15 +348,15 @@ func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMess
addResponse(&response{
Version: "2.0",
Error: Err(InvalidRequest, err.Error()),
})
}, http.Header{})
continue
}

wg.Add(1)
s.pool.Go(func() {
defer wg.Done()

resp, err := s.handleRequest(ctx, req)
resp, header, err := s.handleRequest(ctx, req)
if err != nil {
resp = &response{
Version: "2.0",
Expand All @@ -359,20 +366,22 @@ func (s *Server) handleBatchRequest(ctx context.Context, batchReq []json.RawMess
resp.ID = req.ID
}
}
// for notification request response is nil
// for notification request response is nil and header is irrelevant for now
if resp != nil {
addResponse(resp)
addResponse(resp, header)
}
})
}

wg.Wait()
// according to the spec if there are no response objects server must not return empty array
if len(responses) == 0 {
return nil, nil
return nil, nil, nil
}

return json.Marshal(responses)
result, err := json.Marshal(responses)

return result, headers[0], err // todo: fix batch request aggregate header
}

func isBatch(reader *bufio.Reader) bool {
Expand All @@ -396,11 +405,13 @@ func isNil(i any) bool {
return i == nil || reflect.ValueOf(i).IsNil()
}

func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, error) {
func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, http.Header, error) {
s.log.Tracew("Received request", "req", req)

header := http.Header{}
if err := req.isSane(); err != nil {
s.log.Tracew("Request sanity check failed", "err", err)
return nil, err
return nil, header, err
}

res := &response{
Expand All @@ -412,7 +423,7 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, er
if !found {
res.Error = Err(MethodNotFound, nil)
s.log.Tracew("Method not found in request", "method", req.Method)
return res, nil
return res, header, nil
}

handlerTimer := time.Now()
Expand All @@ -421,7 +432,7 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, er
if err != nil {
res.Error = Err(InvalidParams, err.Error())
s.log.Tracew("Error building arguments for RPC call", "err", err)
return res, nil
return res, header, nil
}
defer func() {
s.listener.OnRequestHandled(req.Method, time.Since(handlerTimer))
Expand All @@ -430,22 +441,28 @@ func (s *Server) handleRequest(ctx context.Context, req *Request) (*response, er
tuple := reflect.ValueOf(calledMethod.Handler).Call(args)
if res.ID == nil { // notification
s.log.Tracew("Notification received, no response expected")
return nil, nil
return nil, header, nil
}

errorIndex := 1
if len(tuple) == 3 {
errorIndex = 2
header = (tuple[1].Interface()).(http.Header)
}

if errAny := tuple[1].Interface(); !isNil(errAny) {
if errAny := tuple[errorIndex].Interface(); !isNil(errAny) {
res.Error = errAny.(*Error)
if res.Error.Code == InternalError {
s.listener.OnRequestFailed(req.Method, res.Error)
reqJSON, _ := json.Marshal(req)
errJSON, _ := json.Marshal(res.Error)
s.log.Debugw("Failed handing RPC request", "req", string(reqJSON), "res", string(errJSON))
}
return res, nil
return res, header, nil
}
res.Result = tuple[0].Interface()

return res, nil
return res, header, nil
}

func (s *Server) buildArguments(ctx context.Context, params any, method Method) ([]reflect.Value, error) {
Expand Down
2 changes: 1 addition & 1 deletion p2p/starknet/handlers.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:generate protoc --go_out=./ --proto_path=./ --go_opt=Mp2p/proto/transaction.proto=./spec --go_opt=Mp2p/proto/state.proto=./spec --go_opt=Mp2p/proto/snapshot.proto=./spec --go_opt=Mp2p/proto/receipt.proto=./spec --go_opt=Mp2p/proto/mempool.proto=./spec --go_opt=Mp2p/proto/event.proto=./spec --go_opt=Mp2p/proto/block.proto=./spec --go_opt=Mp2p/proto/common.proto=./spec p2p/proto/transaction.proto p2p/proto/state.proto p2p/proto/snapshot.proto p2p/proto/common.proto p2p/proto/block.proto p2p/proto/event.proto p2p/proto/receipt.proto
//go:generate protoc --go_out=./ --proto_path=./ --go_opt=Mp2p/proto/transaction.proto=./spec --go_opt=Mp2p/proto/state.proto=./spec --go_opt=Mp2p/proto/snapshot.proto=./spec --go_opt=Mp2p/proto/receipt.proto=./spec --go_opt=Mp2p/proto/mempool.proto=./spec --go_opt=Mp2p/proto/event.proto=./spec --go_opt=Mp2p/proto/block.proto=./spec --go_opt=Mp2p/proto/common.proto=./spec p2p/proto/transaction.proto p2p/proto/state.proto p2p/proto/common.proto p2p/proto/event.proto p2p/proto/receipt.proto
package starknet

import (
Expand Down
39 changes: 20 additions & 19 deletions rpc/estimate_fee.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/jsonrpc"
Expand Down Expand Up @@ -64,52 +65,52 @@ func (f FeeEstimate) MarshalJSON() ([]byte, error) {

func (h *Handler) EstimateFee(broadcastedTxns []BroadcastedTransaction,
simulationFlags []SimulationFlag, id BlockID,
) ([]FeeEstimate, *jsonrpc.Error) {
result, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), false, true)
) ([]FeeEstimate, http.Header, *jsonrpc.Error) {
result, httpHeader, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), false, true)
if err != nil {
return nil, err
return nil, httpHeader, err
}

return utils.Map(result, func(tx SimulatedTransaction) FeeEstimate {
return tx.FeeEstimation
}), nil
}), httpHeader, nil
}

func (h *Handler) EstimateFeeV0_6(broadcastedTxns []BroadcastedTransaction,
simulationFlags []SimulationFlag, id BlockID,
) ([]FeeEstimate, *jsonrpc.Error) {
result, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), true, true)
) ([]FeeEstimate, http.Header, *jsonrpc.Error) {
result, httpHeader, err := h.simulateTransactions(id, broadcastedTxns, append(simulationFlags, SkipFeeChargeFlag), true, true)
if err != nil {
return nil, err
return nil, httpHeader, err
}

return utils.Map(result, func(tx SimulatedTransaction) FeeEstimate {
return tx.FeeEstimation
}), nil
}), httpHeader, nil
}

func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic
func (h *Handler) EstimateMessageFee(msg MsgFromL1, id BlockID) (*FeeEstimate, http.Header, *jsonrpc.Error) { //nolint:gocritic
return h.estimateMessageFee(msg, id, h.EstimateFee)
}

func (h *Handler) EstimateMessageFeeV0_6(msg MsgFromL1, id BlockID) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic
feeEstimate, rpcErr := h.estimateMessageFee(msg, id, h.EstimateFeeV0_6)
func (h *Handler) EstimateMessageFeeV0_6(msg MsgFromL1, id BlockID) (*FeeEstimate, http.Header, *jsonrpc.Error) { //nolint:gocritic
feeEstimate, httpHeader, rpcErr := h.estimateMessageFee(msg, id, h.EstimateFeeV0_6)
if rpcErr != nil {
return nil, rpcErr
return nil, httpHeader, rpcErr
}

feeEstimate.v0_6Response = true
feeEstimate.DataGasPrice = nil
feeEstimate.DataGasConsumed = nil

return feeEstimate, nil
return feeEstimate, httpHeader, nil
}

type estimateFeeHandler func(broadcastedTxns []BroadcastedTransaction,
simulationFlags []SimulationFlag, id BlockID,
) ([]FeeEstimate, *jsonrpc.Error)
) ([]FeeEstimate, http.Header, *jsonrpc.Error)

func (h *Handler) estimateMessageFee(msg MsgFromL1, id BlockID, f estimateFeeHandler) (*FeeEstimate, *jsonrpc.Error) { //nolint:gocritic
func (h *Handler) estimateMessageFee(msg MsgFromL1, id BlockID, f estimateFeeHandler) (*FeeEstimate, http.Header, *jsonrpc.Error) { //nolint:gocritic
calldata := make([]*felt.Felt, 0, len(msg.Payload)+1)
// The order of the calldata parameters matters. msg.From must be prepended.
calldata = append(calldata, new(felt.Felt).SetBytes(msg.From.Bytes()))
Expand All @@ -129,15 +130,15 @@ func (h *Handler) estimateMessageFee(msg MsgFromL1, id BlockID, f estimateFeeHan
// Must be greater than zero to successfully execute transaction.
PaidFeeOnL1: new(felt.Felt).SetUint64(1),
}
estimates, rpcErr := f([]BroadcastedTransaction{tx}, nil, id)
estimates, httpHeader, rpcErr := f([]BroadcastedTransaction{tx}, nil, id)
if rpcErr != nil {
if rpcErr.Code == ErrTransactionExecutionError.Code {
data := rpcErr.Data.(TransactionExecutionErrorData)
return nil, makeContractError(errors.New(data.ExecutionError))
return nil, httpHeader, makeContractError(errors.New(data.ExecutionError))
}
return nil, rpcErr
return nil, httpHeader, rpcErr
}
return &estimates[0], nil
return &estimates[0], httpHeader, nil
}

type ContractErrorData struct {
Expand Down
3 changes: 2 additions & 1 deletion rpc/estimate_fee_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ func TestEstimateMessageFee(t *testing.T) {

t.Run("block not found", func(t *testing.T) {
mockReader.EXPECT().HeadState().Return(nil, nil, db.ErrKeyNotFound)
_, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true})
_, httpHeader, err := handler.EstimateMessageFeeV0_6(msg, rpc.BlockID{Latest: true})
require.Equal(t, rpc.ErrBlockNotFound, err)
require.NotNil(t, httpHeader)
})

latestHeader := &core.Header{
Expand Down
6 changes: 4 additions & 2 deletions rpc/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ func TestThrottledVMError(t *testing.T) {
t.Run("simulate", func(t *testing.T) {
mockReader.EXPECT().HeadState().Return(mockState, nopCloser, nil)
mockReader.EXPECT().HeadsHeader().Return(&core.Header{}, nil)
_, rpcErr := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
_, httpHeader, rpcErr := handler.SimulateTransactions(rpc.BlockID{Latest: true}, []rpc.BroadcastedTransaction{}, []rpc.SimulationFlag{rpc.SkipFeeChargeFlag})
assert.Equal(t, throttledErr, rpcErr.Data)
require.NotNil(t, httpHeader)
})

t.Run("trace", func(t *testing.T) {
Expand Down Expand Up @@ -95,7 +96,8 @@ func TestThrottledVMError(t *testing.T) {
headState := mocks.NewMockStateHistoryReader(mockCtrl)
headState.EXPECT().Class(declareTx.ClassHash).Return(declaredClass, nil)
mockReader.EXPECT().PendingState().Return(headState, nopCloser, nil)
_, rpcErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
_, httpHeader, rpcErr := handler.TraceBlockTransactions(context.Background(), rpc.BlockID{Hash: blockHash})
assert.Equal(t, throttledErr, rpcErr.Data)
require.NotNil(t, httpHeader)
})
}
Loading

0 comments on commit 1f8b4f2

Please sign in to comment.