diff --git a/.github/workflows/on-pull-request.yml b/.github/workflows/on-pull-request.yml index 76da76b0..694dd57a 100644 --- a/.github/workflows/on-pull-request.yml +++ b/.github/workflows/on-pull-request.yml @@ -41,7 +41,7 @@ jobs: run: go mod download - name: Test - run: go test -v -race -p=1 -count=1 + run: go test -v -race -p=1 -count=1 -tags holster_test_mode go-bench: runs-on: ubuntu-latest timeout-minutes: 30 diff --git a/Makefile b/Makefile index fde55c98..f93266a8 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ lint: $(GOLANGCI_LINT) .PHONY: test test: - (go test -v -race -p=1 -count=1 -coverprofile coverage.out ./...; ret=$$?; \ + (go test -v -race -p=1 -count=1 -tags holster_test_mode -coverprofile coverage.out ./...; ret=$$?; \ go tool cover -func coverage.out; \ go tool cover -html coverage.out -o coverage.html; \ exit $$ret) diff --git a/algorithms.go b/algorithms.go index c7e6315c..ce31735b 100644 --- a/algorithms.go +++ b/algorithms.go @@ -383,16 +383,17 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * // If requested hits takes the remainder if int64(b.Remaining) == r.Hits { - b.Remaining -= float64(r.Hits) - rl.Remaining = 0 + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } - // If requested is more than available, then return over the limit - // without updating the bucket. + // If requested is more than available, drain bucket in order to converge as everything is returning OVER_LIMIT. if r.Hits > int64(b.Remaining) { metricOverLimitCounter.Add(1) + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) rl.Status = Status_OVER_LIMIT return rl, nil } diff --git a/cluster/cluster.go b/cluster/cluster.go index 493aa71c..bacdde30 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -68,6 +68,15 @@ func PeerAt(idx int) gubernator.PeerInfo { return peers[idx] } +// FindOwningPeer finds the peer which owns the rate limit with the provided name and unique key +func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return gubernator.PeerInfo{}, err + } + return p.Info(), nil +} + // DaemonAt returns a specific daemon func DaemonAt(idx int) *gubernator.Daemon { return daemons[idx] diff --git a/functional_test.go b/functional_test.go index 6225fc3e..9dff61e2 100644 --- a/functional_test.go +++ b/functional_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "testing" + "time" guber "github.com/mailgun/gubernator/v2" "github.com/mailgun/gubernator/v2/cluster" @@ -34,6 +35,10 @@ import ( "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/resolver" json "google.golang.org/protobuf/encoding/protojson" ) @@ -859,6 +864,88 @@ func TestGlobalRateLimits(t *testing.T) { }) } +func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { + const ( + name = "test_global_token_limit" + key = "account:12345" + ) + + // Make a connection to a peer in the cluster which does not own this rate limit + client, err := getClientToNonOwningPeer(name, key) + require.NoError(t, err) + + sendHit := func(expectedStatus guber.Status, hits int) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: 1, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner + // to respond with OVER_LIMIT. + sendHit(guber.Status_OVER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // The status should still be OVER_LIMIT + sendHit(guber.Status_OVER_LIMIT, 0) +} + +func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { + const ( + name = "test_global_token_limit_leaky" + key = "account:12345" + ) + + // Make a connection to a peer in the cluster which does not own this rate limit + client, err := getClientToNonOwningPeer(name, key) + require.NoError(t, err) + + sendHit := func(expectedStatus guber.Status, hits int) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: 1, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + time.Sleep(time.Second * 3) + sendHit(guber.Status_OVER_LIMIT, 1) +} + func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { resp, err := http.Get(url) require.NoError(t, err) @@ -1273,3 +1360,76 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { } return nil } + +type staticBuilder struct{} + +var _ resolver.Builder = (*staticBuilder)(nil) + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + } + if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC +// to connect a specific peer in the cluster. +func newStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} + +var _ resolver.Resolver = (*staticResolver)(nil) + +// findNonOwningPeer returns peer info for a peer in the cluster which does not +// own the rate limit for the name and key provided. +func findNonOwningPeer(name, key string) (guber.PeerInfo, error) { + owner, err := cluster.FindOwningPeer(name, key) + if err != nil { + return guber.PeerInfo{}, err + } + + for _, p := range cluster.GetPeers() { + if p.HashKey() != owner.HashKey() { + return p, nil + } + } + return guber.PeerInfo{}, fmt.Errorf("unable to find non-owning peer in '%d' node cluster", + len(cluster.GetPeers())) +} + +// getClientToNonOwningPeer returns a connection to a peer in the cluster which does not own +// the rate limit for the name and key provided. +func getClientToNonOwningPeer(name, key string) (guber.V1Client, error) { + p, err := findNonOwningPeer(name, key) + if err != nil { + return nil, err + } + conn, err := grpc.DialContext(context.Background(), + fmt.Sprintf("static:///%s", p.GRPCAddress), + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + return guber.NewV1Client(conn), nil + +} diff --git a/global.go b/global.go index 78431960..568c6f03 100644 --- a/global.go +++ b/global.go @@ -21,14 +21,13 @@ import ( "github.com/mailgun/holster/v4/syncutil" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/protobuf/proto" ) // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. type globalManager struct { asyncQueue chan *RateLimitReq - broadcastQueue chan *RateLimitReq + broadcastQueue chan *UpdatePeerGlobal wg syncutil.WaitGroup conf BehaviorConfig log FieldLogger @@ -43,7 +42,7 @@ func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager gm := globalManager{ log: instance.log, asyncQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - broadcastQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + broadcastQueue: make(chan *UpdatePeerGlobal, conf.GlobalBatchLimit), instance: instance, conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ @@ -74,8 +73,12 @@ func (gm *globalManager) QueueHit(r *RateLimitReq) { gm.asyncQueue <- r } -func (gm *globalManager) QueueUpdate(r *RateLimitReq) { - gm.broadcastQueue <- r +func (gm *globalManager) QueueUpdate(req *RateLimitReq, resp *RateLimitResp) { + gm.broadcastQueue <- &UpdatePeerGlobal{ + Key: req.HashKey(), + Algorithm: req.Algorithm, + Status: resp, + } } // runAsyncHits collects async hit requests and queues them to @@ -173,18 +176,18 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { // runBroadcasts collects status changes for global rate limits and broadcasts the changes to each peer in the cluster. func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) - updates := make(map[string]*RateLimitReq) + updates := make(map[string]*UpdatePeerGlobal) gm.wg.Until(func(done chan struct{}) bool { select { - case r := <-gm.broadcastQueue: - updates[r.HashKey()] = r + case updateReq := <-gm.broadcastQueue: + updates[updateReq.Key] = updateReq // Send the hits if we reached our batch limit if len(updates) >= gm.conf.GlobalBatchLimit { gm.metricBroadcastCounter.WithLabelValues("queue_full").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) return true } @@ -198,7 +201,7 @@ func (gm *globalManager) runBroadcasts() { if len(updates) != 0 { gm.metricBroadcastCounter.WithLabelValues("timer").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) } else { gm.metricGlobalQueueLength.Set(0) } @@ -210,31 +213,14 @@ func (gm *globalManager) runBroadcasts() { } // broadcastPeers broadcasts global rate limit statuses to all other peers -func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) { +func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*UpdatePeerGlobal) { defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration() var req UpdatePeerGlobalsReq gm.metricGlobalQueueLength.Set(float64(len(updates))) for _, r := range updates { - // Copy the original since we are removing the GLOBAL behavior - rl := proto.Clone(r).(*RateLimitReq) - // We are only sending the status of the rate limit so, we - // clear the behavior flag, so we don't get queued for update again. - SetBehavior(&rl.Behavior, Behavior_GLOBAL, false) - rl.Hits = 0 - - status, err := gm.instance.getLocalRateLimit(ctx, rl) - if err != nil { - gm.log.WithError(err).Errorf("while broadcasting update to peers for: '%s'", rl.HashKey()) - continue - } - // Build an UpdatePeerGlobalsReq - req.Globals = append(req.Globals, &UpdatePeerGlobal{ - Algorithm: rl.Algorithm, - Key: rl.HashKey(), - Status: status, - }) + req.Globals = append(req.Globals, r) } fan := syncutil.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) diff --git a/go.mod b/go.mod index 0f5275c6..7398f859 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 github.com/hashicorp/memberlist v0.5.0 github.com/mailgun/errors v0.1.5 - github.com/mailgun/holster/v4 v4.16.2-0.20231121154636-69040cb71a3b + github.com/mailgun/holster/v4 v4.16.3 github.com/miekg/dns v1.1.50 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 diff --git a/go.sum b/go.sum index 3ff2e93e..2972b6fb 100644 --- a/go.sum +++ b/go.sum @@ -291,8 +291,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailgun/errors v0.1.5 h1:riRpZqfUKTdc8saXvoEg2tYkbRyZESU1KvQ3UxPbdus= github.com/mailgun/errors v0.1.5/go.mod h1:lw+Nh4r/aoUTz6uK915FdfZJo3yq60gPiflFHNpK4NQ= -github.com/mailgun/holster/v4 v4.16.2-0.20231121154636-69040cb71a3b h1:ohMhrwmmA4JbXNukFpriztFWEVLlMuL90Cssg2Vl2TU= -github.com/mailgun/holster/v4 v4.16.2-0.20231121154636-69040cb71a3b/go.mod h1:phAg61z7LZ1PBfedyt2GXkGSlHhuVKK9AcVJO+Cm0/U= +github.com/mailgun/holster/v4 v4.16.3 h1:YMTkDoaFV83ViSaFuAfiyIvzrHJD1UNw7RjNv6J3Kfg= +github.com/mailgun/holster/v4 v4.16.3/go.mod h1:phAg61z7LZ1PBfedyt2GXkGSlHhuVKK9AcVJO+Cm0/U= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= diff --git a/gubernator.go b/gubernator.go index 59c26eca..89e875fb 100644 --- a/gubernator.go +++ b/gubernator.go @@ -396,21 +396,23 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) tracing.EndScope(ctx, err) }() - item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) - if err != nil { - countError(err, "Error in workerPool.GetCacheItem") - return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") - } - if ok { - // Global rate limits are always stored as RateLimitResp regardless of algorithm - rl, ok := item.Value.(*RateLimitResp) - if ok { - return rl, nil + /* + item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) + if err != nil { + countError(err, "Error in workerPool.GetCacheItem") + return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") } - // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and - // our cache still holds the rate limit we created on the first hit. - } + if ok { + // Global rate limits are always stored as RateLimitResp regardless of algorithm + rl, ok := item.Value.(*RateLimitResp) + if ok { + return rl, nil + } + // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and + // our cache still holds the rate limit we created on the first hit. + } + */ cpy := proto.Clone(req).(*RateLimitReq) cpy.Behavior = Behavior_NO_BATCHING @@ -427,13 +429,29 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) // UpdatePeerGlobals updates the local cache with a list of global rate limits. This method should only // be called by a peer who is the owner of a global rate limit. func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) (*UpdatePeerGlobalsResp, error) { + now := MillisecondNow() for _, g := range r.Globals { item := &CacheItem{ - ExpireAt: g.Status.ResetTime, + ExpireAt: g.Status.ResetTime + 1000, // account for clock drift from owner where `ResetTime` might already be less than current time of the local machine. Algorithm: g.Algorithm, - Value: g.Status, Key: g.Key, } + switch g.Algorithm { + case Algorithm_LEAKY_BUCKET: + item.Value = &LeakyBucketItem{ + Remaining: float64(g.Status.Remaining), + Limit: g.Status.Limit, + Burst: g.Status.Limit, + UpdatedAt: now, + } + case Algorithm_TOKEN_BUCKET: + item.Value = &TokenBucketItem{ + Status: g.Status.Status, + Limit: g.Status.Limit, + Remaining: g.Status.Remaining, + CreatedAt: now, + } + } err := s.workerPool.AddCacheItem(ctx, g.Key, item) if err != nil { return nil, errors.Wrap(err, "Error in workerPool.AddCacheItem") @@ -569,11 +587,9 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ } metricGetRateLimitCounter.WithLabelValues("local").Inc() - - // If global behavior and owning peer, broadcast update to all peers. - // Assuming that this peer does not own the ratelimit. + // If global behavior, then broadcast update to all peers. if HasBehavior(r.Behavior, Behavior_GLOBAL) { - s.global.QueueUpdate(r) + s.global.QueueUpdate(r, resp) } return resp, nil