-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathprotocol.go
290 lines (257 loc) · 7.25 KB
/
protocol.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
package proxyproto
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"strings"
"sync"
"time"
)
// ErrInvalidUpstream is the error a SourceChecker returns to signal that the
// connecting upstream is not allowed to supply PROXY information.
var ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")

var (
	// prefix is the signature that opens a PROXY protocol v1 header; the
	// first bytes of every connection are compared against it.
	prefix = []byte("PROXY ")
	// prefixLen caches len(prefix) for the incremental prefix scan.
	prefixLen = len(prefix)
)
// SourceChecker decides, for each accepted connection, whether the PROXY
// information sent on it should be trusted. It is called with the
// connection's real remote address.
//
// Its results are interpreted as follows:
//   - (true, nil): the addresses claimed in the PROXY header are used.
//   - (false, nil): the connection's own addresses are reported instead of
//     the ones claimed in the PROXY header.
//   - (_, ErrInvalidUpstream): the source is disallowed; implementations
//     should return exactly this error in that case.
//   - (_, any other error): the call to Accept() fails with that error.
type SourceChecker func(net.Addr) (bool, error)
// Listener wraps an underlying net.Listener whose connections may be using
// the HAProxy PROXY protocol (version 1). For such connections, RemoteAddr()
// on the returned Conn reports the client address carried in the header.
type Listener struct {
	// Listener is the wrapped listener that actually accepts connections.
	Listener net.Listener

	// ProxyHeaderTimeout optionally bounds how long to wait for the PROXY
	// header to arrive. Zero means no timeout.
	ProxyHeaderTimeout time.Duration

	// SourceCheck, when non-nil, is consulted for every accepted connection
	// to decide whether its PROXY information is trusted.
	SourceCheck SourceChecker

	// UnknownOK allows "PROXY UNKNOWN" headers; when false they are rejected.
	UnknownOK bool
}
// Conn wraps an underlying connection that may be speaking the PROXY
// protocol. When a header is received from a trusted upstream, RemoteAddr()
// returns the client address it carries instead of the proxy's address.
type Conn struct {
	// bufReader buffers reads from conn so the header can be peeked at
	// without losing the application data that follows it.
	bufReader *bufio.Reader
	// conn is the wrapped connection; writes and deadlines go straight to it.
	conn net.Conn
	// dstAddr and srcAddr hold the addresses parsed from the PROXY header.
	dstAddr *net.TCPAddr
	srcAddr *net.TCPAddr
	// useConnAddr forces the socket's own addresses to be reported (set for
	// untrusted upstreams and for PROXY UNKNOWN headers).
	useConnAddr bool
	// once guards the one-time header check.
	once sync.Once
	// proxyHeaderTimeout bounds the header check; zero disables the deadline.
	proxyHeaderTimeout time.Duration
	// unknownOK permits "PROXY UNKNOWN" headers.
	unknownOK bool
}
// Accept waits for and returns the next connection to the listener.
//
// When SourceCheck is set, it is consulted for every accepted connection:
//   - an error that is (or wraps) ErrInvalidUpstream drops the connection,
//     and Accept keeps waiting for the next one;
//   - any other error closes the connection and is returned to the caller;
//   - a false result makes the returned Conn report the socket's own
//     addresses instead of those claimed in a PROXY header.
func (p *Listener) Accept() (net.Conn, error) {
	for {
		conn, err := p.Listener.Accept()
		if err != nil {
			return nil, err
		}

		useConnAddr := false
		if p.SourceCheck != nil {
			allowed, checkErr := p.SourceCheck(conn.RemoteAddr())
			if checkErr != nil {
				// The connection is never handed to the caller on this path,
				// so close it here instead of leaking the socket.
				conn.Close()
				if errors.Is(checkErr, ErrInvalidUpstream) {
					continue
				}
				return nil, checkErr
			}
			useConnAddr = !allowed
		}

		newConn := NewConn(conn, p.ProxyHeaderTimeout)
		newConn.useConnAddr = useConnAddr
		newConn.unknownOK = p.UnknownOK
		return newConn, nil
	}
}
// Close shuts down the wrapped listener by delegating to its Close method.
func (p *Listener) Close() error {
	inner := p.Listener
	return inner.Close()
}
// Addr reports the network address of the wrapped listener.
func (p *Listener) Addr() net.Addr {
	inner := p.Listener
	return inner.Addr()
}
// NewConn wraps conn, which may be speaking the PROXY protocol, in a
// *Conn. timeout bounds how long the header check may wait for data;
// zero disables that deadline.
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
	return &Conn{
		conn:               conn,
		bufReader:          bufio.NewReader(conn),
		proxyHeaderTimeout: timeout,
	}
}
// Read checks for the PROXY protocol header on its first call and then
// reads application data from the buffered stream. If the header check
// fails, that error is returned (on the first call only).
func (p *Conn) Read(b []byte) (int, error) {
	var headerErr error
	p.once.Do(func() {
		headerErr = p.checkPrefix()
	})
	if headerErr != nil {
		return 0, headerErr
	}
	return p.bufReader.Read(b)
}
// ReadFrom copies r into the underlying connection, using the connection's
// own io.ReaderFrom implementation when it has one and io.Copy otherwise.
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
	rf, ok := p.conn.(io.ReaderFrom)
	if !ok {
		return io.Copy(p.conn, r)
	}
	return rf.ReadFrom(r)
}
// WriteTo runs the one-time PROXY header check (returning its error, if
// any) and then drains the buffered stream into w.
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
	var headerErr error
	p.once.Do(func() {
		headerErr = p.checkPrefix()
	})
	if headerErr != nil {
		return 0, headerErr
	}
	return p.bufReader.WriteTo(w)
}
// Write sends b directly to the underlying connection.
func (p *Conn) Write(b []byte) (int, error) {
	n, err := p.conn.Write(b)
	return n, err
}
// Close closes the underlying connection.
func (p *Conn) Close() error {
	err := p.conn.Close()
	return err
}
// LocalAddr returns the destination address from the PROXY header when one
// was parsed and the upstream is trusted; otherwise it returns the socket's
// local address. The first call may trigger (and block on) the one-time
// header check.
func (p *Conn) LocalAddr() net.Addr {
	p.checkPrefixOnce()
	if p.useConnAddr || p.dstAddr == nil {
		return p.conn.LocalAddr()
	}
	return p.dstAddr
}
// RemoteAddr returns the client address from the PROXY header when one was
// parsed and the upstream is trusted; otherwise it returns the address of
// the socket peer.
//
// If Read has not run the header check yet, the first call runs it here.
// It can therefore block until enough bytes arrive to decide whether a
// header is present. If that check fails with an error other than io.EOF,
// the error is logged and the connection is closed. One implication is that
// a slow client can stall this call; using a deadline is recommended when
// RemoteAddr is called before Read.
func (p *Conn) RemoteAddr() net.Addr {
	p.checkPrefixOnce()
	if p.useConnAddr || p.srcAddr == nil {
		return p.conn.RemoteAddr()
	}
	return p.srcAddr
}
// SetDeadline sets the read and write deadlines on the underlying connection.
func (p *Conn) SetDeadline(t time.Time) error {
	err := p.conn.SetDeadline(t)
	return err
}
// SetReadDeadline sets the read deadline on the underlying connection.
//
// NOTE(review): when a header timeout is configured, the one-time header
// check overwrites the read deadline and clears it when done. A deadline
// set before that check runs is therefore not preserved.
func (p *Conn) SetReadDeadline(t time.Time) error {
	err := p.conn.SetReadDeadline(t)
	return err
}
// SetWriteDeadline sets the write deadline on the underlying connection.
func (p *Conn) SetWriteDeadline(t time.Time) error {
	err := p.conn.SetWriteDeadline(t)
	return err
}
// checkPrefixOnce runs the PROXY header check at most once, for callers
// that cannot return an error (LocalAddr, RemoteAddr). A failure other than
// io.EOF is logged, and the connection is closed. The buffered reader is
// then replaced by a fresh one over the closed connection, so later reads
// surface the closed-connection error.
func (p *Conn) checkPrefixOnce() {
	p.once.Do(func() {
		err := p.checkPrefix()
		if err == nil || err == io.EOF {
			return
		}
		log.Printf("[ERR] Failed to read proxy prefix: %v", err)
		p.Close()
		p.bufReader = bufio.NewReader(p.conn)
	})
}
// checkPrefix looks for a PROXY protocol v1 header at the start of the
// connection. If one is present, it consumes the header and records the
// addresses it carries.
//
// A connection that does not start with "PROXY " is treated as a plain
// connection: nil is returned and nothing is consumed. This includes a
// connection that ends, or stays silent past proxyHeaderTimeout, before a
// full prefix arrives. Any other read error during the prefix scan is
// returned without closing the connection.
//
// Once the prefix has matched, any failure to read or parse the rest of the
// header closes the underlying connection and returns an error.
// srcAddr/dstAddr are assigned only when the whole header is valid.
//
// NOTE(review): when proxyHeaderTimeout is non-zero, the read deadline is
// cleared on return, discarding any deadline the caller set beforehand.
func (p *Conn) checkPrefix() (err error) {
	// Longest legal v1 header line, including the trailing CRLF.
	const maxV1HeaderLen = 107

	if p.proxyHeaderTimeout != 0 {
		p.conn.SetReadDeadline(time.Now().Add(p.proxyHeaderTimeout))
		defer p.conn.SetReadDeadline(time.Time{})
	}

	isProxy, err := p.peekProxyPrefix()
	if err != nil || !isProxy {
		return err
	}

	// The peer announced the PROXY protocol, so a malformed header is fatal
	// for the connection.
	defer func() {
		if err != nil {
			p.conn.Close()
		}
	}()

	// ReadSlice never buffers more than the bufio.Reader's own buffer, so a
	// peer cannot make us accumulate an unbounded "header".
	line, err := p.bufReader.ReadSlice('\n')
	switch {
	case errors.Is(err, bufio.ErrBufferFull):
		return fmt.Errorf("Invalid header line: exceeds %d bytes", maxV1HeaderLen)
	case errors.Is(err, io.EOF):
		// The peer hung up part-way through a header it had started sending.
		return io.ErrUnexpectedEOF
	case err != nil:
		return err
	}
	if len(line) > maxV1HeaderLen {
		return fmt.Errorf("Invalid header line: exceeds %d bytes", maxV1HeaderLen)
	}

	// Strip the terminator. The spec mandates CRLF, but a bare LF must not
	// cause the last real character (e.g. a port digit) to be chopped off.
	header := strings.TrimSuffix(string(line), "\n")
	header = strings.TrimSuffix(header, "\r")

	// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
	parts := strings.Split(header, " ")
	if len(parts) < 2 {
		return fmt.Errorf("Invalid header line: %s", header)
	}

	switch parts[1] {
	case "UNKNOWN":
		if !p.unknownOK || len(parts) != 2 {
			return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
		}
		p.useConnAddr = true
		return nil
	case "TCP4", "TCP6":
	default:
		return fmt.Errorf("Unhandled address type: %s", parts[1])
	}

	if len(parts) != 6 {
		return fmt.Errorf("Invalid header line: %s", header)
	}

	src, err := parseProxyAddr("source", parts[2], parts[4])
	if err != nil {
		return err
	}
	dst, err := parseProxyAddr("destination", parts[3], parts[5])
	if err != nil {
		return err
	}
	// Publish both addresses together so a failed parse never leaves a
	// half-applied (and possibly spoofed) source address behind.
	p.srcAddr, p.dstAddr = src, dst
	return nil
}

// peekProxyPrefix reports whether the buffered stream starts with prefix,
// without consuming anything. It peeks one more byte per step, so a
// non-PROXY peer is recognised as soon as a mismatching byte arrives.
func (p *Conn) peekProxyPrefix() (bool, error) {
	for i := 1; i <= prefixLen; i++ {
		inp, err := p.bufReader.Peek(i)
		if err != nil {
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				// The peer went quiet before sending a full prefix; treat
				// it as a plain connection.
				return false, nil
			}
			if errors.Is(err, io.EOF) {
				// The stream ended before a full prefix, so it cannot carry
				// a header; leave any buffered bytes for Read to deliver.
				return false, nil
			}
			return false, err
		}
		if !bytes.Equal(inp, prefix[:i]) {
			return false, nil
		}
	}
	return true, nil
}

// parseProxyAddr converts the textual IP and port fields of a PROXY v1
// header into a TCP address. kind ("source" or "destination") only labels
// error messages.
func parseProxyAddr(kind, ipField, portField string) (*net.TCPAddr, error) {
	ip := net.ParseIP(ipField)
	if ip == nil {
		return nil, fmt.Errorf("Invalid %s ip: %s", kind, ipField)
	}
	port, err := strconv.Atoi(portField)
	if err != nil || port < 0 || port > 65535 {
		return nil, fmt.Errorf("Invalid %s port: %s", kind, portField)
	}
	return &net.TCPAddr{IP: ip, Port: port}, nil
}