From fe17176c8a69309343075c1a7cbfda18e462247a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20IRMAK?= Date: Fri, 2 Feb 2024 14:34:50 +0300 Subject: [PATCH] Allow toggling CORS support --- cmd/juno/juno.go | 4 ++++ node/http.go | 18 ++++++++++++++---- node/node.go | 5 +++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index b46c1f31a9..23db82c286 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -77,6 +77,7 @@ const ( cnCoreContractAddressF = "cn-core-contract-address" cnUnverifiableRangeF = "cn-unverifiable-range" callMaxStepsF = "rpc-call-max-steps" + corsEnableF = "rpc-cors-enable" defaultConfig = "" defaulHost = "localhost" @@ -109,6 +110,7 @@ const ( defaultCNCoreContractAddressStr = "" defaultCallMaxSteps = 4_000_000 defaultGwTimeout = 5 * time.Second + defaultCorsEnable = false configFlagUsage = "The yaml configuration file." logLevelFlagUsage = "Options: debug, info, warn, error." @@ -152,6 +154,7 @@ const ( gwAPIKeyUsage = "API key for gateway endpoints to avoid throttling" //nolint: gosec gwTimeoutUsage = "Timeout for requests made to the gateway" //nolint: gosec callMaxStepsUsage = "Maximum number of steps to be executed in starknet_call requests" + corsEnableUsage = "Enable CORS on RPC endpoints" ) var Version string @@ -328,6 +331,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr junoCmd.MarkFlagsMutuallyExclusive(networkF, cnNameF) junoCmd.Flags().Uint(callMaxStepsF, defaultCallMaxSteps, callMaxStepsUsage) junoCmd.Flags().Duration(gwTimeoutF, defaultGwTimeout, gwTimeoutUsage) + junoCmd.Flags().Bool(corsEnableF, defaultCorsEnable, corsEnableUsage) return junoCmd } diff --git a/node/http.go b/node/http.go index 93727388d8..c0a8f3acbe 100644 --- a/node/http.go +++ b/node/http.go @@ -74,7 +74,7 @@ func exactPathServer(path string, handler http.Handler) http.HandlerFunc { } func makeRPCOverHTTP(host string, port uint16, servers map[string]*jsonrpc.Server, - log utils.SimpleLogger, metricsEnabled bool, + log utils.SimpleLogger, metricsEnabled bool, corsEnabled bool, ) *httpService { var listener jsonrpc.NewRequestListener if metricsEnabled { @@ -89,11 +89,16 @@ func makeRPCOverHTTP(host string, port uint16, servers map[string]*jsonrpc.Serve } mux.Handle(path, exactPathServer(path, httpHandler)) } - return makeHTTPService(host, port, cors.Default().Handler(mux)) + + var handler http.Handler = mux + if corsEnabled { + handler = cors.Default().Handler(handler) + } + return makeHTTPService(host, port, handler) } func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc.Server, - log utils.SimpleLogger, metricsEnabled bool, + log utils.SimpleLogger, metricsEnabled bool, corsEnabled bool, ) *httpService { var listener jsonrpc.NewRequestListener if metricsEnabled { @@ -110,7 +115,12 @@ func makeRPCOverWebsocket(host string, port uint16, servers map[string]*jsonrpc. wsPrefixedPath := strings.TrimSuffix("/ws"+path, "/") mux.Handle(wsPrefixedPath, exactPathServer(wsPrefixedPath, wsHandler)) } - return makeHTTPService(host, port, cors.Default().Handler(mux)) + + var handler http.Handler = mux + if corsEnabled { + handler = cors.Default().Handler(handler) + } + return makeHTTPService(host, port, handler) } func makeMetrics(host string, port uint16) *httpService { diff --git a/node/node.go b/node/node.go index 0763e9f10c..035c8f2279 100644 --- a/node/node.go +++ b/node/node.go @@ -48,6 +48,7 @@ type Config struct { HTTP bool `mapstructure:"http"` HTTPHost string `mapstructure:"http-host"` HTTPPort uint16 `mapstructure:"http-port"` + RPCCorsEnable bool `mapstructure:"rpc-cors-enable"` Websocket bool `mapstructure:"ws"` WebsocketHost string `mapstructure:"ws-host"` WebsocketPort uint16 `mapstructure:"ws-port"` @@ -169,10 +170,10 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen "/rpc" + legacyPath: jsonrpcServerLegacy, } if cfg.HTTP { - services = append(services, makeRPCOverHTTP(cfg.HTTPHost, cfg.HTTPPort, rpcServers, log, cfg.Metrics)) + services = append(services, makeRPCOverHTTP(cfg.HTTPHost, cfg.HTTPPort, rpcServers, log, cfg.Metrics, cfg.RPCCorsEnable)) } if cfg.Websocket { - services = append(services, makeRPCOverWebsocket(cfg.WebsocketHost, cfg.WebsocketPort, rpcServers, log, cfg.Metrics)) + services = append(services, makeRPCOverWebsocket(cfg.WebsocketHost, cfg.WebsocketPort, rpcServers, log, cfg.Metrics, cfg.RPCCorsEnable)) } var metricsService service.Service if cfg.Metrics {