diff --git a/functional_test.go b/functional_test.go index f604e26d..2ffe6447 100644 --- a/functional_test.go +++ b/functional_test.go @@ -864,7 +864,7 @@ func TestGlobalRateLimits(t *testing.T) { }) } -func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { +func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { owner := cluster.PeerAt(2).GRPCAddress peer := cluster.PeerAt(0).GRPCAddress assert.NotEqual(t, owner, peer) @@ -872,24 +872,23 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { dialOpts := []grpc.DialOption{ grpc.WithResolvers(newStaticBuilder()), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), } - address := fmt.Sprintf("static:///%s,%s", owner, peer) + address := fmt.Sprintf("static:///%s", peer) conn, err := grpc.DialContext(context.Background(), address, dialOpts...) require.NoError(t, err) client := guber.NewV1Client(conn) - sendHit := func(status guber.Status, assertion func(resp *guber.RateLimitResp), i int) string { + sendHit := func(status guber.Status, i int) string { ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { - Name: "test_global", + Name: "test_global_token_limit", UniqueKey: "account:12345", - Algorithm: guber.Algorithm_LEAKY_BUCKET, + Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, Hits: 1, @@ -902,22 +901,74 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { assert.Equal(t, "", gotResp.GetError(), i) assert.Equal(t, status, gotResp.GetStatus(), i) - if assertion != nil { - assertion(gotResp) - } - return gotResp.GetMetadata()["owner"] } - // Send two hits that should be processed by the owner and the peer and deplete the limit - sendHit(guber.Status_UNDER_LIMIT, nil, 1) - sendHit(guber.Status_UNDER_LIMIT, nil, 2) - // sleep to ensure the async forward has occurred and state should be shared - time.Sleep(time.Second * 5) + // Send two hits that should be processed by the owner and the peer and deplete the remaining + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner + // to respond with OVER_LIMIT. + sendHit(guber.Status_OVER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // The status should still be OVER_LIMIT + sendHit(guber.Status_OVER_LIMIT, 0) +} - for i := 0; i < 10; i++ { - sendHit(guber.Status_OVER_LIMIT, nil, i+2) +func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { + owner := cluster.PeerAt(2).GRPCAddress + peer := cluster.PeerAt(0).GRPCAddress + assert.NotEqual(t, owner, peer) + + dialOpts := []grpc.DialOption{ + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials()), } + + address := fmt.Sprintf("static:///%s", peer) + conn, err := grpc.DialContext(context.Background(), address, dialOpts...) + require.NoError(t, err) + + client := guber.NewV1Client(conn) + + sendHit := func(status guber.Status, i int) string { + ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_global_leaky_limit", + UniqueKey: "account:12345", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: 1, + Limit: 2, + }, + }, + }) + require.NoError(t, err, i) + gotResp := resp.Responses[0] + assert.Equal(t, "", gotResp.GetError(), i) + assert.Equal(t, status, gotResp.GetStatus(), i) + + return gotResp.GetMetadata()["owner"] + } + + // Send two hits that should be processed by the owner and the peer and deplete the remaining + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // Since the peer must wait for the owner to say it's over the limit, this will return under the limit. + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // The status should now be OVER_LIMIT + sendHit(guber.Status_OVER_LIMIT, 0) } func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { diff --git a/global.go b/global.go index f5dfada9..568c6f03 100644 --- a/global.go +++ b/global.go @@ -21,14 +21,13 @@ import ( "github.com/mailgun/holster/v4/syncutil" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/protobuf/proto" ) // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. type globalManager struct { asyncQueue chan *RateLimitReq - broadcastQueue chan *RateLimitReq + broadcastQueue chan *UpdatePeerGlobal wg syncutil.WaitGroup conf BehaviorConfig log FieldLogger @@ -43,7 +42,7 @@ func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager gm := globalManager{ log: instance.log, asyncQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - broadcastQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + broadcastQueue: make(chan *UpdatePeerGlobal, conf.GlobalBatchLimit), instance: instance, conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ @@ -74,8 +73,12 @@ func (gm *globalManager) QueueHit(r *RateLimitReq) { gm.asyncQueue <- r } -func (gm *globalManager) QueueUpdate(r *RateLimitReq) { - gm.broadcastQueue <- r +func (gm *globalManager) QueueUpdate(req *RateLimitReq, resp *RateLimitResp) { + gm.broadcastQueue <- &UpdatePeerGlobal{ + Key: req.HashKey(), + Algorithm: req.Algorithm, + Status: resp, + } } // runAsyncHits collects async hit requests and queues them to @@ -173,18 +176,18 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { // runBroadcasts collects status changes for global rate limits and broadcasts the changes to each peer in the cluster. func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) - updates := make(map[string]*RateLimitReq) + updates := make(map[string]*UpdatePeerGlobal) gm.wg.Until(func(done chan struct{}) bool { select { - case r := <-gm.broadcastQueue: - updates[r.HashKey()] = r + case updateReq := <-gm.broadcastQueue: + updates[updateReq.Key] = updateReq // Send the hits if we reached our batch limit if len(updates) >= gm.conf.GlobalBatchLimit { gm.metricBroadcastCounter.WithLabelValues("queue_full").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) return true } @@ -198,7 +201,7 @@ func (gm *globalManager) runBroadcasts() { if len(updates) != 0 { gm.metricBroadcastCounter.WithLabelValues("timer").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) } else { gm.metricGlobalQueueLength.Set(0) } @@ -210,35 +213,14 @@ func (gm *globalManager) runBroadcasts() { } // broadcastPeers broadcasts global rate limit statuses to all other peers -func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) { +func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*UpdatePeerGlobal) { defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration() var req UpdatePeerGlobalsReq gm.metricGlobalQueueLength.Set(float64(len(updates))) for _, r := range updates { - // Copy the original since we are removing the GLOBAL behavior - rl := proto.Clone(r).(*RateLimitReq) - // We are only sending the status of the rate limit so, we - // clear the behavior flag, so we don't get queued for update again. - SetBehavior(&rl.Behavior, Behavior_GLOBAL, false) - rl.Hits = 0 - - misleadingStatus, err := gm.instance.getLocalRateLimit(ctx, rl) - if err != nil { - gm.log.WithError(err).Errorf("while broadcasting update to peers for: '%s'", rl.HashKey()) - continue - } - status := misleadingStatus - if misleadingStatus.Remaining == 0 { - status.Status = Status_OVER_LIMIT - } - // Build an UpdatePeerGlobalsReq - req.Globals = append(req.Globals, &UpdatePeerGlobal{ - Algorithm: rl.Algorithm, - Key: rl.HashKey(), - Status: status, - }) + req.Globals = append(req.Globals, r) } fan := syncutil.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) diff --git a/go.mod b/go.mod index 0f5275c6..7398f859 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 github.com/hashicorp/memberlist v0.5.0 github.com/mailgun/errors v0.1.5 - github.com/mailgun/holster/v4 v4.16.2-0.20231121154636-69040cb71a3b + github.com/mailgun/holster/v4 v4.16.3 github.com/miekg/dns v1.1.50 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 diff --git a/go.sum b/go.sum index 3ff2e93e..2972b6fb 100644 --- a/go.sum +++ b/go.sum @@ -291,8 +291,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mailgun/errors v0.1.5 h1:riRpZqfUKTdc8saXvoEg2tYkbRyZESU1KvQ3UxPbdus= github.com/mailgun/errors v0.1.5/go.mod h1:lw+Nh4r/aoUTz6uK915FdfZJo3yq60gPiflFHNpK4NQ= -github.com/mailgun/holster/v4 v4.16.2-0.20231121154636-69040cb71a3b h1:ohMhrwmmA4JbXNukFpriztFWEVLlMuL90Cssg2Vl2TU= -github.com/mailgun/holster/v4 v4.16.2-0.20231121154636-69040cb71a3b/go.mod h1:phAg61z7LZ1PBfedyt2GXkGSlHhuVKK9AcVJO+Cm0/U= +github.com/mailgun/holster/v4 v4.16.3 h1:YMTkDoaFV83ViSaFuAfiyIvzrHJD1UNw7RjNv6J3Kfg= +github.com/mailgun/holster/v4 v4.16.3/go.mod h1:phAg61z7LZ1PBfedyt2GXkGSlHhuVKK9AcVJO+Cm0/U= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= diff --git a/gubernator.go b/gubernator.go index 59c26eca..0d094a7c 100644 --- a/gubernator.go +++ b/gubernator.go @@ -405,9 +405,23 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) // Global rate limits are always stored as RateLimitResp regardless of algorithm rl, ok := item.Value.(*RateLimitResp) if ok { + // In the case we are not the owner, global behavior dictates that we respond with + // what ever the owner has broadcast to use as the response. However, in the case + // of TOKEN_BUCKET it makes little sense to wait for the owner to respond with OVER_LIMIT + // if we already know the remainder is 0. So we check for a remainder of 0 here and set + // OVER_LIMIT only if there are actual hits and this is not a RESET_REMAINING request and + // it's a TOKEN_BUCKET. + // + // We cannot preform this for LEAKY_BUCKET as we don't know how much time or what other requests + // might have influenced the leak rate at the owning peer. + // (Maybe we should preform the leak calculation here?????) + if rl.Remaining == 0 && req.Hits > 0 && !HasBehavior(req.Behavior, Behavior_RESET_REMAINING) && + req.Algorithm == Algorithm_TOKEN_BUCKET { + rl.Status = Status_OVER_LIMIT + } return rl, nil } - // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and + // We get here if the owning node hasn't asynchronously forwarded its updates to us yet and // our cache still holds the rate limit we created on the first hit. } @@ -569,11 +583,9 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ } metricGetRateLimitCounter.WithLabelValues("local").Inc() - - // If global behavior and owning peer, broadcast update to all peers. - // Assuming that this peer does not own the ratelimit. + // If global behavior, then broadcast update to all peers. if HasBehavior(r.Behavior, Behavior_GLOBAL) { - s.global.QueueUpdate(r) + s.global.QueueUpdate(r, resp) } return resp, nil