Skip to content

Commit

Permalink
Merge pull request #84 from RyoJerryYu/feat-add-new-gatewayx
Browse files Browse the repository at this point in the history
feat: add new gatewayx marshaler
  • Loading branch information
RyoJerryYu authored Jan 15, 2025
2 parents 62d2f7f + 6f46ae7 commit 3505817
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 2 deletions.
7 changes: 7 additions & 0 deletions pkg/gatewayx/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package gatewayx

const (
MIMETextEventStream = "text/event-stream"
MIMETextPlain = "text/plain"
MIMEApplicationJSON = "application/json"
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
)

// EventStreamMarshaler is a marshaler that returns each stream message as:
// data: {"result": {...result}}
// This will impliment the MDN EventStream for gRPC-Gateway server stream methods.
// spec: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events
type EventStreamMarshaler struct {
runtime.Marshaler
}

func (m *EventStreamMarshaler) ContentType(_ interface{}) string {
return "text/event-stream"
return MIMETextEventStream
}

func (m *EventStreamMarshaler) Marshal(v interface{}) ([]byte, error) {
Expand All @@ -39,12 +43,17 @@ func (m EventStreamMarshaler) NewEncoder(w io.Writer) runtime.Encoder {
})
}

// UnwrapEventStreamMarshaler is a marshaler that unwraps the result or error field from the response
// it will return the result as:
// data: {...result}
// instead of:
// data: {"result": {...result}}
type UnwrapEventStreamMarshaler struct {
runtime.Marshaler
}

func (m *UnwrapEventStreamMarshaler) ContentType(_ interface{}) string {
return "text/event-stream"
return MIMETextEventStream
}

func (m *UnwrapEventStreamMarshaler) Marshal(v interface{}) ([]byte, error) {
Expand Down
File renamed without changes.
38 changes: 38 additions & 0 deletions pkg/gatewayx/marshaler_plain_text.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package gatewayx

import (
"io"

"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
)

type PlainTextMarshaler struct {
runtime.Marshaler
}

func (m *PlainTextMarshaler) ContentType(_ interface{}) string {
return MIMETextPlain
}

func (m *PlainTextMarshaler) Marshal(v interface{}) ([]byte, error) {
switch v := v.(type) {
case string:
return []byte(v), nil
case []byte:
return v, nil
}
// if number, bool, or other types, use the default marshaller,
// JSON marshal would work well for them.
return m.Marshaler.Marshal(v)
}

func (m PlainTextMarshaler) NewEncoder(w io.Writer) runtime.Encoder {
return runtime.EncoderFunc(func(v interface{}) error {
data, err := m.Marshal(v)
if err != nil {
return err
}
_, err = w.Write(data)
return err
})
}
54 changes: 54 additions & 0 deletions pkg/gatewayx/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package gatewayx

import "net/http"

/**
Sometime Only use the gRPC-Gateway is not enough.
You may need to add some middleware to the gateway to handle some special cases.
*/

type pathMatcher func(string) bool

func PathEqual(path string) pathMatcher {
return func(p string) bool {
return p == path
}
}

func PathPrefix(prefix string) pathMatcher {
return func(p string) bool {
return len(p) >= len(prefix) && p[:len(prefix)] == prefix
}
}

// OverwriteAccept returns a middleware that overwrites the Accept header of the request
// It's useful when you cannot force the client to send the correct Accept header
func OverwriteAccept(accept string, matchers ...pathMatcher) func(handler http.Handler) http.Handler {
return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, matcher := range matchers {
if matcher(r.URL.Path) {
r.Header.Set("Accept", accept)
break
}
}
handler.ServeHTTP(w, r)
})
}
}

// OverwriteContentType returns a middleware that overwrites the Content-Type header of the http request
// It's useful when you cannot force the client to send the correct Content-Type header
func OverwriteContentType(contentType string, matchers ...pathMatcher) func(handler http.Handler) http.Handler {
return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, matcher := range matchers {
if matcher(r.URL.Path) {
r.Header.Set("Content-Type", contentType)
break
}
}
handler.ServeHTTP(w, r)
})
}
}

0 comments on commit 3505817

Please sign in to comment.