diff --git a/_example/client/main.go b/_example/client/main.go index a50061d..c8b3dc0 100644 --- a/_example/client/main.go +++ b/_example/client/main.go @@ -21,7 +21,7 @@ func main() { fmt.Println(err) return } - req.Header.Add("Content-Type", "application/json") + req.Header.Add("x-custom-key", "foobar") res, err := client.Do(req) if err != nil { diff --git a/_example/server/main.go b/_example/server/main.go index 8d480a2..c102fb2 100644 --- a/_example/server/main.go +++ b/_example/server/main.go @@ -50,7 +50,7 @@ func main() { log.Fatal(err) } - httpSv, jrpcLis, err := newJrpcSv(proto.RegisterEchoServiceJsonRpc(grpcConn)) + httpSv, jrpcLis, err := newJrpcSv(proto.RegisterEchoServiceJsonRPC(grpcConn)) if err != nil { log.Fatal(err) } @@ -113,8 +113,8 @@ func newGRPCSv(server *Server) (*grpc.Server, net.Listener, error) { return sv, listener, nil } -func newJrpcSv(echoClient *proto.EchoServiceJsonRpcService) (*jrpc.Server, net.Listener, error) { - jrpcSv := jrpc.NewServer() +func newJrpcSv(echoClient *proto.EchoServiceJsonRPC) (*jrpc.Server, net.Listener, error) { + jrpcSv := jrpc.NewServer(jrpc.WithCustomHeaders("x-custom-key")) jrpcSv.RegisterServices(echoClient) diff --git a/jrpc/server.go b/jrpc/server.go index 705465e..dfd3bfe 100644 --- a/jrpc/server.go +++ b/jrpc/server.go @@ -4,14 +4,13 @@ import ( "context" "encoding/json" "github.com/creachadair/jrpc2" + "github.com/creachadair/jrpc2/handler" + "github.com/creachadair/jrpc2/jhttp" "google.golang.org/grpc/metadata" "io" "net" "net/http" - "time" - - "github.com/creachadair/jrpc2/handler" - "github.com/creachadair/jrpc2/jhttp" + "strings" ) type method = func(ctx context.Context, message json.RawMessage) (any, error) @@ -21,8 +20,9 @@ type Service interface { } type Server struct { - sv *http.Server - handler http.Handler + sv *http.Server + handler http.Handler + customHeaders []string } type paramsAndHeaders struct { @@ -31,16 +31,22 @@ type paramsAndHeaders struct { } // NewServer create json rpc server -func NewServer() *Server { +func NewServer(opts ...Option) *Server { sv := new(Server) + opt := defaultOpt() mux := http.NewServeMux() mux.HandleFunc("/", sv.httpHandler) server := &http.Server{ - ReadHeaderTimeout: 3 * time.Second, + ReadHeaderTimeout: opt.ReadHeaderTimeout, Handler: mux, } + for _, o := range opts { + o(opt) + } + + sv.customHeaders = opt.CustomHeadersKey sv.sv = server return sv @@ -80,7 +86,7 @@ func (s *Server) RegisterServices(svs ...Service) { // Decorate the incoming request parameters with the headers. for _, pr := range prs { w, err := json.Marshal(paramsAndHeaders{ - Headers: headersToMetadata(req), + Headers: s.headersToMetadata(req), Params: pr.Params, }) if err != nil { @@ -97,12 +103,17 @@ func (s *Server) httpHandler(w http.ResponseWriter, r *http.Request) { s.handler.ServeHTTP(w, r) } -func headersToMetadata(r *http.Request) metadata.MD { +func (s *Server) headersToMetadata(r *http.Request) metadata.MD { headersMap := make(map[string]string) - for key, values := range r.Header { - if len(values) > 0 { - headersMap[key] = values[0] + + for _, header := range s.customHeaders { + canonicalHeader := http.CanonicalHeaderKey(header) + if v, ok := r.Header[canonicalHeader]; ok { + if len(v) > 0 { + headersMap[strings.ToLower(canonicalHeader)] = v[0] + } } } + return metadata.New(headersMap) }