-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrl_test.go
117 lines (110 loc) · 3.49 KB
/
rl_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package rl_test
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/2manymws/rl"
"github.com/2manymws/rl/testutil"
"github.com/go-chi/httprate"
)
// Compile-time assertion that *testutil.Limiter satisfies the rl.Limiter interface.
var _ rl.Limiter = (*testutil.Limiter)(nil)
// TestRateLimit drives the middleware built by rl.New with real HTTP requests
// against an httptest server and counts how many requests are accepted before
// the first rejection.
//
// Each case cycles through its hosts (sent via the Host header) until either a
// response body contains "Too many requests" or noLimitReq requests have been
// accepted. It then verifies:
//   - the number of accepted requests (wantReqCount),
//   - the status code of the rejecting response (wantStatusCode; only
//     inspected when a rejection is actually observed),
//   - whether the X-RateLimit-Limit header is present on every response
//     (hasXRateLimitHeaders).
func TestRateLimit(t *testing.T) {
	const noLimitReq = 100
	tests := []struct {
		name                 string
		limiters             []rl.Limiter
		hosts                []string
		wantReqCount         int
		wantStatusCode       int
		hasXRateLimitHeaders bool
	}{
		{"key by ip", []rl.Limiter{testutil.NewLimiter(10, httprate.KeyByIP, 0)}, []string{"a.example.com", "b.example.com"}, 10, http.StatusTooManyRequests, true},
		{"key by host", []rl.Limiter{testutil.NewLimiter(10, testutil.KeyByHost, 0)}, []string{"a.example.com", "b.example.com"}, 20, http.StatusTooManyRequests, true},
		{"no limit", []rl.Limiter{testutil.NewLimiter(-1, httprate.KeyByIP, 0)}, []string{"a.example.com", "b.example.com"}, noLimitReq, http.StatusTooManyRequests, false},
		{"set other statusCode", []rl.Limiter{testutil.NewLimiter(10, httprate.KeyByIP, http.StatusOK)}, []string{"a.example.com", "b.example.com"}, 10, http.StatusOK, true},
		{"b.example.com is limited", []rl.Limiter{testutil.NewSkipper("a.example.com"), testutil.NewLimiter(10, testutil.KeyByHost, 0)}, []string{"b.example.com"}, 10, http.StatusTooManyRequests, true},
		{"a.example.com allows unlimited requests", []rl.Limiter{testutil.NewSkipper("a.example.com"), testutil.NewLimiter(10, testutil.KeyByHost, 0)}, []string{"a.example.com"}, noLimitReq, 0, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			r := http.NewServeMux()
			r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				// The handler runs on a server goroutine, where FailNow
				// (t.Fatal) must not be called; report with t.Error instead.
				if _, err := w.Write([]byte("Hello, world")); err != nil {
					t.Error(err)
				}
			})
			m := rl.New(tt.limiters...)
			ts := httptest.NewServer(m(r))
			t.Cleanup(ts.Close)

			got := 0
		requests:
			for {
				for _, host := range tt.hosts {
					req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
					if err != nil {
						t.Fatal(err)
					}
					req.Host = host
					res, err := http.DefaultClient.Do(req)
					if err != nil {
						t.Fatal(err)
					}
					// Close the body before checking either error so it is
					// released even when the read fails.
					b, readErr := io.ReadAll(res.Body)
					closeErr := res.Body.Close()
					if readErr != nil {
						t.Fatal(readErr)
					}
					if closeErr != nil {
						t.Fatal(closeErr)
					}

					hasHeader := res.Header.Get("X-RateLimit-Limit") != ""
					if hasHeader != tt.hasXRateLimitHeaders {
						t.Errorf("got %v want %v", hasHeader, tt.hasXRateLimitHeaders)
					}
					// A rejection is recognised by its body text, not its
					// status code, because a case may configure a non-429 code.
					if strings.Contains(string(b), "Too many requests") {
						if res.StatusCode != tt.wantStatusCode {
							t.Errorf("got %v want %v", res.StatusCode, tt.wantStatusCode)
						}
						break requests
					}
					got++
					if got == noLimitReq { // circuit breaker
						break requests
					}
				}
			}
			if got != tt.wantReqCount {
				t.Errorf("got %v want %v", got, tt.wantReqCount)
			}
		})
	}
}
// BenchmarkRL measures one HTTP round trip per iteration through a middleware
// built by rl.New from two limiters: one keyed by httprate.KeyByIP and one
// keyed by testutil.KeyByHost. Every request is sent with the Host header
// "a.example.com".
//
// Server setup is excluded from the timed region, and each response body is
// drained before it is closed so the client can reuse its keep-alive
// connection instead of dialing anew on every iteration.
func BenchmarkRL(b *testing.B) { //nostyle:all
	r := http.NewServeMux()
	r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, where FailNow (b.Fatal)
		// must not be called; report with b.Error instead.
		if _, err := w.Write([]byte("Hello, world")); err != nil {
			b.Error(err)
		}
	})
	m := rl.New(testutil.NewLimiter(10, httprate.KeyByIP, 0), testutil.NewLimiter(10, testutil.KeyByHost, 0))
	ts := httptest.NewServer(m(r))
	b.Cleanup(ts.Close)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
		if err != nil {
			b.Fatal(err)
		}
		req.Host = "a.example.com"
		res, err := http.DefaultClient.Do(req)
		if err != nil {
			b.Fatal(err)
		}
		// Drain, then close; check both errors only after the body has been
		// released.
		_, drainErr := io.Copy(io.Discard, res.Body)
		closeErr := res.Body.Close()
		if drainErr != nil {
			b.Fatal(drainErr)
		}
		if closeErr != nil {
			b.Fatal(closeErr)
		}
	}
}