-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrp.go
204 lines (186 loc) · 5.51 KB
/
rp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
package rp
import (
"bytes"
"crypto/tls"
"io"
"net/http"
"net/http/httputil"
"net/url"
"path/filepath"
"strings"
)
// errorKey names the outbound request header that carries an error from the
// request-rewriting stage (GetUpstream or Rewriter.Rewrite) to the proxy's
// transport, which then answers with an error response instead of contacting
// an upstream.
const (
	errorKey = "X-Proxy-Error"
)
// Relayer is the one mandatory interface of the reverse proxy: it decides, for
// each incoming request, which upstream the request is forwarded to.
// Optional behaviors (Rewriter, CertGetter, RoundTripper, ...) are picked up
// when the same value also implements those interfaces.
type Relayer interface { //nostyle:ifacenames
	// GetUpstream reports the upstream URL that req should be proxied to.
	// It may return a nil URL when no upstream could be determined.
	// Implementations must treat req as read-only: DO NOT modify it.
	GetUpstream(req *http.Request) (*url.URL, error)
}
// Rewriter is an optional interface for a Relayer that wants to adjust the
// outbound request itself. When it is implemented, the proxy no longer calls
// SetXForwarded on its own; Rewrite is responsible for such headers.
type Rewriter interface {
	// Rewrite is called after the upstream URL has been applied and may
	// modify pr.Out further before it is sent upstream. Returning an error
	// makes the proxy answer with an error response instead.
	// For example, `X-Forwarded-*` headers can be set here with
	// [httputil.ProxyRequest.SetXForwarded](https://pkg.go.dev/net/http/httputil#ProxyRequest.SetXForwarded).
	Rewrite(pr *httputil.ProxyRequest) error
}
// CertGetter is an optional interface for a Relayer served over TLS. When it
// is implemented, NewTLSServer installs GetCertificate as the server's
// tls.Config.GetCertificate callback.
type CertGetter interface { //nostyle:ifacenames
	// GetCertificate selects the certificate to present for the handshake
	// described by hello.
	GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
// RoundTripper is an optional interface for a Relayer that wants to replace
// the transport used to send requests upstream. When it is not implemented,
// http.DefaultTransport is used.
//
// It deliberately contains only RoundTrip: handling of requests whose upstream
// could not be determined is the separate, optional RoundTipperOnErrorer
// interface. Requiring both here would make a Relayer that implements only
// RoundTrip silently fall back to http.DefaultTransport.
type RoundTripper interface {
	// RoundTrip performs the round trip of the request.
	// It is necessary to implement the functions that http.Transport is responsible for (e.g. MaxIdleConnsPerHost).
	RoundTrip(r *http.Request) (*http.Response, error)
}
// RoundTipperOnErrorer is an optional interface for a Relayer that wants to
// build the response itself when the request could not be prepared for an
// upstream.
type RoundTipperOnErrorer interface {
	// RoundTripOnError is called instead of contacting any upstream when
	// GetUpstream (or Rewriter.Rewrite) returned an error; the error message
	// is available in the request's "X-Proxy-Error" header.
	// If this method is not implemented, the proxy responds with
	// 502 Bad Gateway and the error message as the response body.
	RoundTripOnError(r *http.Request) (*http.Response, error)
}
// ErrorHandler is an optional interface for a Relayer that wants to customize
// the response sent when proxying a request fails.
type ErrorHandler interface {
	// ErrorHandler is installed as httputil.ReverseProxy.ErrorHandler, so it
	// is called when the round trip to the upstream fails (including an error
	// returned by RoundTripOnError). Errors from GetUpstream or Rewrite
	// normally do not reach it: they are turned into a response by the proxy's
	// transport.
	// If this method is not implemented, httputil.ReverseProxy's default
	// handler logs the error and responds with 502 Bad Gateway.
	ErrorHandler(http.ResponseWriter, *http.Request, error)
}
// relayer bundles a Relayer with the optional capabilities newRelayer found it
// to implement. A func field is nil when the matching interface is not
// implemented, except RoundTrip, which newRelayer always fills in.
type relayer struct {
	Relayer

	// Rewrite comes from Rewriter.
	Rewrite func(pr *httputil.ProxyRequest) error
	// GetCertificate comes from CertGetter.
	GetCertificate func(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
	// RoundTrip comes from RoundTripper, else http.DefaultTransport.
	RoundTrip func(req *http.Request) (*http.Response, error)
	// RoundTripOnError comes from RoundTipperOnErrorer.
	RoundTripOnError func(req *http.Request) (*http.Response, error)
	// ErrorHandler comes from the ErrorHandler interface.
	ErrorHandler func(w http.ResponseWriter, req *http.Request, err error)
}
// newRelayer inspects r for each optional interface and records the methods it
// provides. RoundTrip defaults to http.DefaultTransport.RoundTrip unless r
// implements RoundTripper.
func newRelayer(r Relayer) *relayer {
	out := &relayer{
		Relayer:   r,
		RoundTrip: http.DefaultTransport.RoundTrip,
	}
	if rw, ok := r.(Rewriter); ok {
		out.Rewrite = rw.Rewrite
	}
	if cg, ok := r.(CertGetter); ok {
		out.GetCertificate = cg.GetCertificate
	}
	if rt, ok := r.(RoundTripper); ok {
		out.RoundTrip = rt.RoundTrip
	}
	if oe, ok := r.(RoundTipperOnErrorer); ok {
		out.RoundTripOnError = oe.RoundTripOnError
	}
	if eh, ok := r.(ErrorHandler); ok {
		out.ErrorHandler = eh.ErrorHandler
	}
	return out
}
// NewRouter returns an http.Handler that reverse-proxies every request to the
// upstream URL chosen by r.GetUpstream.
//
// The upstream URL replaces the outbound request URL. The outbound Host is set
// to the upstream host, except when that host is a Unix domain socket path, in
// which case the client's Host is kept.
//
// If r implements Rewriter, its Rewrite method is called afterwards (even when
// GetUpstream returned a nil URL) and is solely responsible for further
// changes, including X-Forwarded-* headers. Otherwise
// httputil.ProxyRequest.SetXForwarded is applied whenever an upstream URL was
// returned.
//
// An error from GetUpstream or Rewrite is not returned to the caller. It is
// stored in the outbound request's errorKey header, and the proxy's transport
// turns it into a response instead of contacting an upstream.
func NewRouter(r Relayer) http.Handler {
	rr := newRelayer(r)
	return &httputil.ReverseProxy{
		Rewrite: func(pr *httputil.ProxyRequest) {
			// pr.Out starts as a copy of the inbound request. Without this, a
			// client could send its own errorKey header and force the error
			// path with a response body of its choosing.
			pr.Out.Header.Del(errorKey)

			u, err := rr.GetUpstream(pr.In)
			if err != nil {
				pr.Out.Header.Set(errorKey, err.Error())
				return
			}
			if u != nil {
				applyUpstream(pr, u)
			}
			if rr.Rewrite == nil {
				if u != nil {
					pr.SetXForwarded()
				}
				return
			}
			if err := rr.Rewrite(pr); err != nil {
				pr.Out.Header.Set(errorKey, err.Error())
			}
		},
		Transport:    newTransport(rr),
		ErrorHandler: rr.ErrorHandler,
	}
}

// applyUpstream points the outbound request of pr at upstream.
func applyUpstream(pr *httputil.ProxyRequest, upstream *url.URL) {
	if isSocketPathHost(upstream.Host) {
		// Unix domain socket path: it is not a meaningful Host header, so keep
		// the one the client sent.
		pr.Out.Host = pr.In.Host
	} else {
		pr.Out.Host = upstream.Host
	}
	// Use a copy so that later edits of pr.Out.URL (e.g. by a Rewriter) cannot
	// mutate a *url.URL the Relayer may hand out to other requests.
	out := *upstream
	pr.Out.URL = &out
}

// isSocketPathHost reports whether host is an absolute filesystem path (a Unix
// domain socket) rather than a network host. It accepts both the leading-slash
// form and the OS-specific absolute-path form, so detection is the same with
// and without a Rewriter.
func isSocketPathHost(host string) bool {
	return strings.HasPrefix(host, "/") || filepath.IsAbs(host)
}
// NewServer builds, but does not start, an *http.Server that listens on addr
// and serves the reverse proxy returned by NewRouter(r).
func NewServer(addr string, r Relayer) *http.Server {
	handler := NewRouter(r)
	srv := &http.Server{Addr: addr}
	srv.Handler = handler
	return srv
}
// NewTLSServer builds, but does not start, an *http.Server for HTTPS that
// listens on addr and serves the reverse proxy returned by NewRouter(r).
// If r implements CertGetter, its GetCertificate is installed on the server's
// TLS configuration; otherwise the TLS configuration is left empty.
func NewTLSServer(addr string, r Relayer) *http.Server {
	handler := NewRouter(r)
	getCert := newRelayer(r).GetCertificate
	cfg := &tls.Config{}
	if getCert != nil {
		cfg.GetCertificate = getCert
	}
	srv := &http.Server{
		Addr:      addr,
		Handler:   handler,
		TLSConfig: cfg,
	}
	return srv
}
// ListenAndServe listens on the TCP network address addr and proxies every
// request according to Relayer r. It blocks until the server stops and always
// returns a non-nil error.
func ListenAndServe(addr string, r Relayer) error {
	return NewServer(addr, r).ListenAndServe()
}
// ListenAndServeTLS acts identically to ListenAndServe, except that it expects
// HTTPS connections. No certificate files are passed, so certificates must
// come from the TLS configuration, i.e. r should implement CertGetter.
func ListenAndServeTLS(addr string, r Relayer) error {
	return NewTLSServer(addr, r).ListenAndServeTLS("", "")
}
// transport is the RoundTripper installed on the reverse proxy. It answers
// requests flagged with errorKey itself and hands everything else to the
// relayer's RoundTrip.
type transport struct {
	rr *relayer // supplies RoundTrip and the optional RoundTripOnError
}
// RoundTrip sends r upstream using the relayer's RoundTrip.
//
// When r carries the errorKey header (set while rewriting the request because
// GetUpstream or Rewrite failed), no upstream is contacted. The request is
// handed to RoundTripOnError if the relayer provides one. Otherwise a synthetic
// 502 Bad Gateway response is returned, with the error message as a plain-text
// body.
func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
	msg := r.Header.Get(errorKey)
	if msg == "" {
		return t.rr.RoundTrip(r)
	}
	if t.rr.RoundTripOnError != nil {
		return t.rr.RoundTripOnError(r)
	}
	// The message can echo request-derived data. Pin it to plain text and
	// forbid sniffing so a browser never renders it as HTML.
	header := make(http.Header, 2)
	header.Set("Content-Type", "text/plain; charset=utf-8")
	header.Set("X-Content-Type-Options", "nosniff")
	return &http.Response{
		Status:        "502 Bad Gateway", // net/http convention: "<code> <reason>"
		StatusCode:    http.StatusBadGateway,
		Proto:         r.Proto,
		ProtoMajor:    r.ProtoMajor,
		ProtoMinor:    r.ProtoMinor,
		Header:        header,
		Body:          io.NopCloser(bytes.NewBufferString(msg)),
		ContentLength: int64(len(msg)),
		Request:       r,
	}, nil
}
// newTransport wraps rr in the proxy's transport.
func newTransport(rr *relayer) *transport {
	t := new(transport)
	t.rr = rr
	return t
}