From a312ed73014107e1f9d58a04f1c3a156d61bf3c5 Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Wed, 21 Feb 2024 07:49:29 -0700 Subject: [PATCH] Change global behavior (#219) * test: Add test for global rate limiting with load balancing * fix global update behavior * Added findNonOwningPeer() and getClientToNonOwningPeer() * Fix global mode * remove logs and add comment Co-authored-by: Yamil Asusta * fixed global over-consume issue and improved global testing --------- Co-authored-by: Philip Gough Co-authored-by: Yamil Asusta Co-authored-by: Maria Ines Parnisari --- .github/workflows/on-pull-request.yml | 2 +- Makefile | 2 +- algorithms.go | 23 +- cluster/cluster.go | 51 ++- config.go | 2 + daemon.go | 77 +++- functional_test.go | 533 +++++++++++++++++++++++--- global.go | 57 ++- gubernator.go | 51 ++- interval_test.go | 9 +- 10 files changed, 673 insertions(+), 134 deletions(-) diff --git a/.github/workflows/on-pull-request.yml b/.github/workflows/on-pull-request.yml index d89825f7..23854e74 100644 --- a/.github/workflows/on-pull-request.yml +++ b/.github/workflows/on-pull-request.yml @@ -50,7 +50,7 @@ jobs: skip-cache: true - name: Test - run: go test -v -race -p=1 -count=1 + run: go test -v -race -p=1 -count=1 -tags holster_test_mode go-bench: runs-on: ubuntu-latest timeout-minutes: 30 diff --git a/Makefile b/Makefile index 7c77cca8..75240d97 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ lint: $(GOLANGCI_LINT) .PHONY: test test: - (go test -v -race -p=1 -count=1 -coverprofile coverage.out ./...; ret=$$?; \ + (go test -v -race -p=1 -count=1 -tags holster_test_mode -coverprofile coverage.out ./...; ret=$$?; \ go tool cover -func coverage.out; \ go tool cover -html coverage.out -o coverage.html; \ exit $$ret) diff --git a/algorithms.go b/algorithms.go index 1fb8f9dd..f2ed4a82 100644 --- a/algorithms.go +++ b/algorithms.go @@ -26,6 +26,13 @@ import ( "go.opentelemetry.io/otel/trace" ) +// ### NOTE ### +// The both token and leaky follow the same semantic which allows for requests of more than the limit +// to be rejected, but subsequent requests within the same window that are under the limit to succeed. +// IE: client attempts to send 1000 emails but 100 is their limit. The request is rejected as over the +// limit, but we do not set the remainder to 0 in the cache. The client can retry within the same window +// with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` + // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { @@ -82,12 +89,6 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * ResetTime: 0, }, nil } - - // The following semantic allows for requests of more than the limit to be rejected, but subsequent - // requests within the same duration that are under the limit to succeed. IE: client attempts to - // send 1000 emails but 100 is their limit. The request is rejected as over the limit, but since we - // don't store OVER_LIMIT in the cache the client can retry within the same rate limit duration with - // 100 emails and the request will succeed. t, ok := item.Value.(*TokenBucketItem) if !ok { // Client switched algorithms; perhaps due to a migration? @@ -388,22 +389,24 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * // If requested hits takes the remainder if int64(b.Remaining) == r.Hits { - b.Remaining -= float64(r.Hits) - rl.Remaining = 0 + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } // If requested is more than available, then return over the limit - // without updating the bucket. + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. if r.Hits > int64(b.Remaining) { metricOverLimitCounter.Add(1) rl.Status = Status_OVER_LIMIT + + // DRAIN_OVER_LIMIT behavior drains the remaining counter. if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. b.Remaining = 0 rl.Remaining = 0 } + return rl, nil } diff --git a/cluster/cluster.go b/cluster/cluster.go index 493aa71c..4c18efd6 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -68,6 +68,47 @@ func PeerAt(idx int) gubernator.PeerInfo { return peers[idx] } +// FindOwningPeer finds the peer which owns the rate limit with the provided name and unique key +func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return gubernator.PeerInfo{}, err + } + return p.Info(), nil +} + +// FindOwningDaemon finds the daemon which owns the rate limit with the provided name and unique key +func FindOwningDaemon(name, key string) (*gubernator.Daemon, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return &gubernator.Daemon{}, err + } + + for i, d := range daemons { + if d.PeerInfo.GRPCAddress == p.Info().GRPCAddress { + return daemons[i], nil + } + } + return &gubernator.Daemon{}, errors.New("unable to find owning daemon") +} + +// ListNonOwningDaemons returns a list of daemons in the cluster that do not own the rate limit +// for the name and key provided. +func ListNonOwningDaemons(name, key string) ([]*gubernator.Daemon, error) { + owner, err := FindOwningDaemon(name, key) + if err != nil { + return []*gubernator.Daemon{}, err + } + + var daemons []*gubernator.Daemon + for _, d := range GetDaemons() { + if d.PeerInfo.GRPCAddress != owner.PeerInfo.GRPCAddress { + daemons = append(daemons, d) + } + } + return daemons, nil +} + // DaemonAt returns a specific daemon func DaemonAt(idx int) *gubernator.Daemon { return daemons[idx] @@ -112,6 +153,7 @@ func StartWith(localPeers []gubernator.PeerInfo) error { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) d, err := gubernator.SpawnDaemon(ctx, gubernator.DaemonConfig{ Logger: logrus.WithField("instance", peer.GRPCAddress), + InstanceID: peer.GRPCAddress, GRPCListenAddress: peer.GRPCAddress, HTTPListenAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, @@ -127,12 +169,15 @@ func StartWith(localPeers []gubernator.PeerInfo) error { return errors.Wrapf(err, "while starting server for addr '%s'", peer.GRPCAddress) } - // Add the peers and daemons to the package level variables - peers = append(peers, gubernator.PeerInfo{ + p := gubernator.PeerInfo{ GRPCAddress: d.GRPCListeners[0].Addr().String(), HTTPAddress: d.HTTPListener.Addr().String(), DataCenter: peer.DataCenter, - }) + } + d.PeerInfo = p + + // Add the peers and daemons to the package level variables + peers = append(peers, p) daemons = append(daemons, d) } diff --git a/config.go b/config.go index 122ffa22..19f9f06f 100644 --- a/config.go +++ b/config.go @@ -71,6 +71,8 @@ type BehaviorConfig struct { // Config for a gubernator instance type Config struct { + InstanceID string + // (Required) A list of GRPC servers to register our instance with GRPCServers []*grpc.Server diff --git a/daemon.go b/daemon.go index a220136b..97602075 100644 --- a/daemon.go +++ b/daemon.go @@ -19,6 +19,7 @@ package gubernator import ( "context" "crypto/tls" + "fmt" "log" "net" "net/http" @@ -40,6 +41,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/resolver" "google.golang.org/protobuf/encoding/protojson" ) @@ -47,6 +49,8 @@ type Daemon struct { GRPCListeners []net.Listener HTTPListener net.Listener V1Server *V1Instance + InstanceID string + PeerInfo PeerInfo log FieldLogger pool PoolInterface @@ -59,6 +63,7 @@ type Daemon struct { promRegister *prometheus.Registry gwCancel context.CancelFunc instanceConf Config + client V1Client } // SpawnDaemon starts a new gubernator daemon according to the provided DaemonConfig. @@ -67,8 +72,9 @@ type Daemon struct { func SpawnDaemon(ctx context.Context, conf DaemonConfig) (*Daemon, error) { s := &Daemon{ - log: conf.Logger, - conf: conf, + InstanceID: conf.InstanceID, + log: conf.Logger, + conf: conf, } return s, s.Start(ctx) } @@ -77,8 +83,8 @@ func (s *Daemon) Start(ctx context.Context) error { var err error setter.SetDefault(&s.log, logrus.WithFields(logrus.Fields{ - "instance-id": s.conf.InstanceID, - "category": "gubernator", + "instance": s.conf.InstanceID, + "category": "gubernator", })) s.promRegister = prometheus.NewRegistry() @@ -148,6 +154,7 @@ func (s *Daemon) Start(ctx context.Context) error { Behaviors: s.conf.Behaviors, CacheSize: s.conf.CacheSize, Workers: s.conf.Workers, + InstanceID: s.conf.InstanceID, } s.V1Server, err = NewV1Instance(s.instanceConf) @@ -411,6 +418,30 @@ func (s *Daemon) Peers() []PeerInfo { return peers } +func (s *Daemon) MustClient() V1Client { + c, err := s.Client() + if err != nil { + panic(fmt.Sprintf("[%s] failed to init daemon client - '%s'", s.InstanceID, err)) + } + return c +} + +func (s *Daemon) Client() (V1Client, error) { + if s.client != nil { + return s.client, nil + } + + conn, err := grpc.DialContext(context.Background(), + fmt.Sprintf("static:///%s", s.PeerInfo.GRPCAddress), + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + s.client = NewV1Client(conn) + return s.client, nil +} + // WaitForConnect returns nil if the list of addresses is listening // for connections; will block until context is cancelled. func WaitForConnect(ctx context.Context, addresses []string) error { @@ -451,3 +482,41 @@ func WaitForConnect(ctx context.Context, addresses []string) error { } return nil } + +type staticBuilder struct{} + +var _ resolver.Builder = (*staticBuilder)(nil) + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + } + if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC +// to connect a specific peer in the cluster. +func newStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} + +var _ resolver.Resolver = (*staticResolver)(nil) diff --git a/functional_test.go b/functional_test.go index 4abd2e25..b377e86a 100644 --- a/functional_test.go +++ b/functional_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "testing" + "time" guber "github.com/mailgun/gubernator/v2" "github.com/mailgun/gubernator/v2/cluster" @@ -37,14 +38,6 @@ import ( json "google.golang.org/protobuf/encoding/protojson" ) -var algos = []struct { - Name string - Algorithm guber.Algorithm -}{ - {Name: "Token bucket", Algorithm: guber.Algorithm_TOKEN_BUCKET}, - {Name: "Leaky bucket", Algorithm: guber.Algorithm_LEAKY_BUCKET}, -} - // Setup and shutdown the mock gubernator cluster for the entire test suite func TestMain(m *testing.M) { if err := cluster.StartWith([]guber.PeerInfo{ @@ -405,8 +398,8 @@ func TestDrainOverLimit(t *testing.T) { }, } - for idx, algoCase := range algos { - t.Run(algoCase.Name, func(t *testing.T) { + for idx, algoCase := range []guber.Algorithm{guber.Algorithm_TOKEN_BUCKET, guber.Algorithm_LEAKY_BUCKET} { + t.Run(guber.Algorithm_name[int32(algoCase)], func(t *testing.T) { for _, test := range tests { ctx := context.Background() t.Run(test.Name, func(t *testing.T) { @@ -415,7 +408,7 @@ func TestDrainOverLimit(t *testing.T) { { Name: "test_drain_over_limit", UniqueKey: fmt.Sprintf("account:1234:%d", idx), - Algorithm: algoCase.Algorithm, + Algorithm: algoCase, Behavior: guber.Behavior_DRAIN_OVER_LIMIT, Duration: guber.Second * 30, Hits: test.Hits, @@ -437,6 +430,49 @@ func TestDrainOverLimit(t *testing.T) { } } +func TestTokenBucketRequestMoreThanAvailable(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hit int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_token_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hit, + Limit: 2000, + }, + }, + }) + require.NoError(t, err, hit) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the bucket has and the remainder is still 1000. + // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestLeakyBucket(t *testing.T) { defer clock.Freeze(clock.Now()).Unfreeze() @@ -696,7 +732,7 @@ func TestLeakyBucketGregorian(t *testing.T) { Hits: 1, Remaining: 58, Status: guber.Status_UNDER_LIMIT, - Sleep: clock.Second, + Sleep: clock.Millisecond * 1200, }, { Name: "third hit; leak one hit", @@ -706,7 +742,12 @@ func TestLeakyBucketGregorian(t *testing.T) { }, } + // Truncate to the nearest minute now := clock.Now() + now = now.Truncate(1 * time.Minute) + // So we don't start on the minute boundary + now = now.Add(time.Millisecond * 100) + for _, test := range tests { t.Run(test.Name, func(t *testing.T) { resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ @@ -807,6 +848,50 @@ func TestLeakyBucketNegativeHits(t *testing.T) { } } +func TestLeakyBucketRequestMoreThanAvailable(t *testing.T) { + // Freeze time so we don't leak during the test + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hits int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_leaky_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hits, + Limit: 2000, + }, + }, + }) + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the rate limit has and the remainder is still 1000. + // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestMissingFields(t *testing.T) { client, errs := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) require.Nil(t, errs) @@ -871,12 +956,16 @@ func TestMissingFields(t *testing.T) { } func TestGlobalRateLimits(t *testing.T) { - peer := cluster.PeerAt(0).GRPCAddress - client, errs := guber.DialV1Server(peer, nil) - require.NoError(t, errs) + const ( + name = "test_global" + key = "account:12345" + ) - sendHit := func(status guber.Status, remain int64, i int) string { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*5) + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remain int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ @@ -885,59 +974,346 @@ func TestGlobalRateLimits(t *testing.T) { UniqueKey: "account:12345", Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, - Duration: guber.Second * 3, - Hits: 1, + Duration: guber.Minute * 3, + Hits: hits, Limit: 5, }, }, }) - require.NoError(t, err, i) - assert.Equal(t, "", resp.Responses[0].Error, i) - assert.Equal(t, status, resp.Responses[0].Status, i) - assert.Equal(t, remain, resp.Responses[0].Remaining, i) - assert.Equal(t, int64(5), resp.Responses[0].Limit, i) - - // ensure that we have a canonical host - assert.NotEmpty(t, resp.Responses[0].Metadata["owner"]) - - // name/key should ensure our connected peer is NOT the owner, - // the peer we are connected to should forward requests asynchronously to the owner. - assert.NotEqual(t, peer, resp.Responses[0].Metadata["owner"]) - - return resp.Responses[0].Metadata["owner"] + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, int64(5), resp.Responses[0].Limit) } - // Our first hit should create the request on the peer and queue for async forward - sendHit(guber.Status_UNDER_LIMIT, 4, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1, 4) // Our second should be processed as if we own it since the async forward hasn't occurred yet - sendHit(guber.Status_UNDER_LIMIT, 3, 2) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 2, 2) testutil.UntilPass(t, 20, clock.Millisecond*200, func(t testutil.TestingT) { - // Inspect our metrics, ensure they collected the counts we expected during this test - d := cluster.DaemonAt(0) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_global_send_duration_count") + // Inspect peers metrics, ensure the peer sent the global rate limit to the owner + metricsURL := fmt.Sprintf("http://%s/metrics", peers[0].Config().HTTPListenAddress) + m, err := getMetricRequest(metricsURL, "gubernator_global_send_duration_count") + assert.NoError(t, err) assert.Equal(t, 1, int(m.Value)) + }) + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) - // Expect one peer (the owning peer) to indicate a broadcast. - var broadcastCount int - for i := 0; i < cluster.NumOfDaemons(); i++ { - d := cluster.DaemonAt(i) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_broadcast_duration_count") - broadcastCount += int(m.Value) - } + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + + // Check different peers, they should have gotten the broadcast from the owner + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + + // Non owning peer should calculate the rate limit remaining before forwarding + // to the owner. + sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 2, 0) + + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) + + sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0) +} + +func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { + const ( + name = "test_global_token_limit" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner + // to respond with OVER_LIMIT. + sendHit(guber.Status_OVER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) + // The status should still be OVER_LIMIT + sendHit(guber.Status_OVER_LIMIT, 0) +} + +func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { + const ( + name = "test_global_token_limit_leaky" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + // Ask a different peer if the status is over the limit + sendHit(peers[1].MustClient(), guber.Status_OVER_LIMIT, 1) +} + +func TestGlobalRequestMoreThanAvailable(t *testing.T) { + const ( + name = "test_global_more_than_available" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Ensure GRPC has connections to each peer before we start, as we want + // the actual test requests to happen quite fast. + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 0, 100) + } + + // Send a request for 50 hits from each non owning peer in the cluster. These requests + // will be queued and sent to the owner as accumulated hits. As a result of the async nature + // of `Behavior_GLOBAL` rate limit requests spread across peers like this will be allowed to + // over-consume their resource within the rate limit window until the owner is updated and + // a broadcast to all peers is received. + // + // The maximum number of resources that can be over-consumed can be calculated by multiplying + // the remainder by the number of peers in the cluster. For example: If you have a remainder of 100 + // and a cluster of 10 instances, then the maximum over-consumed resource is 1,000. If you need + // a more accurate remaining calculation, and wish to avoid over consuming a resource, then do + // not use `Behavior_GLOBAL`. + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) +} + +func TestGlobalNegativeHits(t *testing.T) { + const ( + name = "test_global_negative_hits" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 100, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, status, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Send a negative hit on a rate limit with no hits + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, -1, 3) + + // Wait for the negative remaining to propagate + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // Send another negative hit to a different peer + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, -1, 4) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Should have 4 in the remainder + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 4, 0) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+3)) + + sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 0, 0) +} + +func TestGlobalResetRemaining(t *testing.T) { + const ( + name = "test_global_reset" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit and remaining should be zero + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) - assert.Equal(t, 1, broadcastCount) + // Now reset the remaining + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL | guber.Behavior_RESET_REMAINING, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, + }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + + // Wait for the reset to propagate. + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Check a different peer to ensure remaining has been reset + resp, err = peers[1].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + } -func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { +func getMetricRequest(url string, name string) (*model.Sample, error) { resp, err := http.Get(url) - require.NoError(t, err) + if err != nil { + return nil, err + } defer resp.Body.Close() - return getMetric(t, resp.Body, name) + return getMetric(resp.Body, name) } func TestChangeLimit(t *testing.T) { @@ -1174,6 +1550,7 @@ func TestHealthCheck(t *testing.T) { } func TestLeakyBucketDivBug(t *testing.T) { + // Freeze time so we don't leak during the test defer clock.Freeze(clock.Now()).Unfreeze() client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) @@ -1321,7 +1698,7 @@ func TestGetPeerRateLimits(t *testing.T) { // TODO: Add a test for sending no rate limits RateLimitReqList.RateLimits = nil -func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { +func getMetric(in io.Reader, name string) (*model.Sample, error) { dec := expfmt.SampleDecoder{ Dec: expfmt.NewDecoder(in, expfmt.FmtText), Opts: &expfmt.DecodeOptions{ @@ -1336,14 +1713,58 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { if err == io.EOF { break } - assert.NoError(t, err) + if err != nil { + return nil, err + } all = append(all, smpls...) } for _, s := range all { if strings.Contains(s.Metric.String(), name) { - return s + return s, nil + } + } + return nil, nil +} + +// getBroadcastCount returns the current broadcast count for use with waitForBroadcast() +// TODO: Replace this with something else, we can call and reset via HTTP/GRPC calls in gubernator v3 +func getBroadcastCount(d *guber.Daemon) (int, error) { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return 0, err + } + + return int(m.Value), nil +} + +// waitForBroadcast waits until the broadcast count for the daemon passed +// changes to the expected value. Returns an error if the expected value is +// not found before the context is cancelled. +func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return err + } + + // It's possible a broadcast occurred twice if waiting for multiple peer to + // forward updates to the owner. + if int(m.Value) >= expect { + // Give the nodes some time to process the broadcasts + clock.Sleep(clock.Millisecond * 500) + return nil + } + + select { + case <-clock.After(time.Millisecond * 800): + case <-ctx.Done(): + return ctx.Err() } } - return nil } diff --git a/global.go b/global.go index cd113108..adbd8e44 100644 --- a/global.go +++ b/global.go @@ -21,14 +21,13 @@ import ( "github.com/mailgun/holster/v4/syncutil" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/protobuf/proto" ) // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. type globalManager struct { hitsQueue chan *RateLimitReq - updatesQueue chan *RateLimitReq + broadcastQueue chan *UpdatePeerGlobal wg syncutil.WaitGroup conf BehaviorConfig log FieldLogger @@ -41,11 +40,11 @@ type globalManager struct { func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager { gm := globalManager{ - log: instance.log, - hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - updatesQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - instance: instance, - conf: conf, + log: instance.log, + hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + broadcastQueue: make(chan *UpdatePeerGlobal, conf.GlobalBatchLimit), + instance: instance, + conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ Name: "gubernator_global_send_duration", Help: "The duration of GLOBAL async sends in seconds.", @@ -74,8 +73,12 @@ func (gm *globalManager) QueueHit(r *RateLimitReq) { gm.hitsQueue <- r } -func (gm *globalManager) QueueUpdate(r *RateLimitReq) { - gm.updatesQueue <- r +func (gm *globalManager) QueueUpdate(req *RateLimitReq, resp *RateLimitResp) { + gm.broadcastQueue <- &UpdatePeerGlobal{ + Key: req.HashKey(), + Algorithm: req.Algorithm, + Status: resp, + } } // runAsyncHits collects async hit requests in a forever loop, @@ -95,6 +98,11 @@ func (gm *globalManager) runAsyncHits() { key := r.HashKey() _, ok := hits[key] if ok { + // If any of our hits includes a request to RESET_REMAINING + // ensure the owning peer gets this behavior + if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + SetBehavior(&hits[key].Behavior, Behavior_RESET_REMAINING, true) + } hits[key].Hits += r.Hits } else { hits[key] = r @@ -142,7 +150,6 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { gm.log.WithError(err).Errorf("while getting peer for hash key '%s'", r.HashKey()) continue } - p, ok := peerRequests[peer.Info().GRPCAddress] if ok { p.req.Requests = append(p.req.Requests, r) @@ -179,18 +186,18 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { // and in a periodic frequency determined by GlobalSyncWait. func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) - updates := make(map[string]*RateLimitReq) + updates := make(map[string]*UpdatePeerGlobal) gm.wg.Until(func(done chan struct{}) bool { select { - case r := <-gm.updatesQueue: - updates[r.HashKey()] = r + case updateReq := <-gm.broadcastQueue: + updates[updateReq.Key] = updateReq // Send the hits if we reached our batch limit if len(updates) >= gm.conf.GlobalBatchLimit { gm.metricBroadcastCounter.WithLabelValues("queue_full").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) return true } @@ -204,7 +211,7 @@ func (gm *globalManager) runBroadcasts() { if len(updates) != 0 { gm.metricBroadcastCounter.WithLabelValues("timer").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) } else { gm.metricGlobalQueueLength.Set(0) } @@ -216,30 +223,14 @@ func (gm *globalManager) runBroadcasts() { } // broadcastPeers broadcasts global rate limit statuses to all other peers -func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) { +func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*UpdatePeerGlobal) { defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration() var req UpdatePeerGlobalsReq gm.metricGlobalQueueLength.Set(float64(len(updates))) for _, r := range updates { - // Copy the original since we are removing the GLOBAL behavior - rl := proto.Clone(r).(*RateLimitReq) - // We are only sending the status of the rate limit so, we - // clear the behavior flag, so we don't get queued for update again. - SetBehavior(&rl.Behavior, Behavior_GLOBAL, false) - rl.Hits = 0 - - status, err := gm.instance.getLocalRateLimit(ctx, rl) - if err != nil { - gm.log.WithError(err).Errorf("while getting local rate limit for: '%s'", rl.HashKey()) - continue - } - req.Globals = append(req.Globals, &UpdatePeerGlobal{ - Algorithm: rl.Algorithm, - Key: rl.HashKey(), - Status: status, - }) + req.Globals = append(req.Globals, r) } fan := syncutil.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) diff --git a/gubernator.go b/gubernator.go index 59c26eca..58f3f616 100644 --- a/gubernator.go +++ b/gubernator.go @@ -396,23 +396,9 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) tracing.EndScope(ctx, err) }() - item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) - if err != nil { - countError(err, "Error in workerPool.GetCacheItem") - return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") - } - if ok { - // Global rate limits are always stored as RateLimitResp regardless of algorithm - rl, ok := item.Value.(*RateLimitResp) - if ok { - return rl, nil - } - // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and - // our cache still holds the rate limit we created on the first hit. - } - cpy := proto.Clone(req).(*RateLimitReq) - cpy.Behavior = Behavior_NO_BATCHING + SetBehavior(&cpy.Behavior, Behavior_NO_BATCHING, true) + SetBehavior(&cpy.Behavior, Behavior_GLOBAL, false) // Process the rate limit like we own it resp, err = s.getLocalRateLimit(ctx, cpy) @@ -427,13 +413,29 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) // UpdatePeerGlobals updates the local cache with a list of global rate limits. This method should only // be called by a peer who is the owner of a global rate limit. func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) (*UpdatePeerGlobalsResp, error) { + now := MillisecondNow() for _, g := range r.Globals { item := &CacheItem{ ExpireAt: g.Status.ResetTime, Algorithm: g.Algorithm, - Value: g.Status, Key: g.Key, } + switch g.Algorithm { + case Algorithm_LEAKY_BUCKET: + item.Value = &LeakyBucketItem{ + Remaining: float64(g.Status.Remaining), + Limit: g.Status.Limit, + Burst: g.Status.Limit, + UpdatedAt: now, + } + case Algorithm_TOKEN_BUCKET: + item.Value = &TokenBucketItem{ + Status: g.Status.Status, + Limit: g.Status.Limit, + Remaining: g.Status.Remaining, + CreatedAt: now, + } + } err := s.workerPool.AddCacheItem(ctx, g.Key, item) if err != nil { return nil, errors.Wrap(err, "Error in workerPool.AddCacheItem") @@ -485,6 +487,15 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits // Extract the propagated context from the metadata in the request prop := propagation.TraceContext{} ctx := prop.Extract(ctx, &MetadataCarrier{Map: rin.req.Metadata}) + + // Forwarded global requests must have DRAIN_OVER_LIMIT set so token and leaky algorithms + // drain the remaining in the event a peer asks for more than is remaining. + // This is needed because with GLOBAL behavior peers will accumulate hits, which could + // result in requesting more hits than is remaining. + if HasBehavior(rin.req.Behavior, Behavior_GLOBAL) { + SetBehavior(&rin.req.Behavior, Behavior_DRAIN_OVER_LIMIT, true) + } + rl, err := s.getLocalRateLimit(ctx, rin.req) if err != nil { // Return the error for this request @@ -569,11 +580,9 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ } metricGetRateLimitCounter.WithLabelValues("local").Inc() - - // If global behavior and owning peer, broadcast update to all peers. - // Assuming that this peer does not own the ratelimit. + // If global behavior, then broadcast update to all peers. if HasBehavior(r.Behavior, Behavior_GLOBAL) { - s.global.QueueUpdate(r) + s.global.QueueUpdate(r, resp) } return resp, nil diff --git a/interval_test.go b/interval_test.go index 89642c3e..68c8b40d 100644 --- a/interval_test.go +++ b/interval_test.go @@ -18,9 +18,8 @@ package gubernator_test import ( "testing" - "time" - gubernator "github.com/mailgun/gubernator/v2" + "github.com/mailgun/gubernator/v2" "github.com/mailgun/holster/v4/clock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,18 +27,18 @@ import ( func TestInterval(t *testing.T) { t.Run("Happy path", func(t *testing.T) { - interval := gubernator.NewInterval(10 * time.Millisecond) + interval := gubernator.NewInterval(10 * clock.Millisecond) defer interval.Stop() interval.Next() assert.Empty(t, interval.C) - time.Sleep(10 * time.Millisecond) + clock.Sleep(10 * clock.Millisecond) // Wait for tick. select { case <-interval.C: - case <-time.After(100 * time.Millisecond): + case <-clock.After(100 * clock.Millisecond): require.Fail(t, "timeout") } })