From e2b8853c048c861c6fe5a3568cfe83f2b368680a Mon Sep 17 00:00:00 2001 From: Shawn Poulson Date: Mon, 11 Mar 2024 11:17:31 -0400 Subject: [PATCH] Refactor global behavior and functional tests for stability. - Simplify passing of request time across layers. - Better handling of metrics in tests. - Better detection of global broadcasts, global updates, and idle. - Drop redundant metric `guberator_global_broadcast_counter`. - Fix metric `gubernator_global_queue_length` for global broadcast. - Add metric `gubernator_global_send_queue_length` for global send. --- algorithms.go | 54 ++-- functional_test.go | 769 ++++++++++++++++++--------------------------- global.go | 81 +++-- gubernator.go | 25 +- peer_client.go | 2 - workers.go | 12 +- 6 files changed, 380 insertions(+), 563 deletions(-) diff --git a/algorithms.go b/algorithms.go index a9937c59..8d49bb35 100644 --- a/algorithms.go +++ b/algorithms.go @@ -35,7 +35,7 @@ import ( // with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket -func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, requestTime time.Time) (resp *RateLimitResp, err error) { +func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket")) defer tokenBucketTimer.ObserveDuration() @@ -100,7 +100,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request s.Remove(ctx, hashKey) } - return tokenBucketNewItem(ctx, s, c, r, requestTime) + return tokenBucketNewItem(ctx, s, c, r) } // Update the limit if it changed. @@ -133,12 +133,12 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request } // If our new duration means we are currently expired. - now := EpochMillis(requestTime) - if expire <= now { + requestTime := *r.RequestTime + if expire <= requestTime { // Renew item. span.AddEvent("Limit has expired") - expire = now + r.Duration - t.CreatedAt = now + expire = requestTime + r.Duration + t.CreatedAt = requestTime t.Remaining = t.Limit } @@ -196,19 +196,19 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request } // Item is not found in cache or store, create new. - return tokenBucketNewItem(ctx, s, c, r, requestTime) + return tokenBucketNewItem(ctx, s, c, r) } // Called by tokenBucket() when adding a new item in the store. -func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, requestTime time.Time) (resp *RateLimitResp, err error) { - now := EpochMillis(requestTime) - expire := now + r.Duration +func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { + requestTime := *r.RequestTime + expire := requestTime + r.Duration t := &TokenBucketItem{ Limit: r.Limit, Duration: r.Duration, Remaining: r.Limit - r.Hits, - CreatedAt: now, + CreatedAt: requestTime, } // Add a new rate limit to the cache. 
@@ -252,7 +252,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, } // Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket -func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, requestTime time.Time) (resp *RateLimitResp, err error) { +func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket")) defer leakyBucketTimer.ObserveDuration() @@ -260,7 +260,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request r.Burst = r.Limit } - now := EpochMillis(requestTime) + requestTime := *r.RequestTime // Get rate limit from cache. hashKey := r.HashKey() @@ -309,7 +309,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request s.Remove(ctx, hashKey) } - return leakyBucketNewItem(ctx, s, c, r, requestTime) + return leakyBucketNewItem(ctx, s, c, r) } if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { @@ -349,16 +349,16 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request } if r.Hits != 0 { - c.UpdateExpiration(r.HashKey(), now+duration) + c.UpdateExpiration(r.HashKey(), requestTime+duration) } // Calculate how much leaked out of the bucket since the last time we leaked a hit - elapsed := now - b.UpdatedAt + elapsed := requestTime - b.UpdatedAt leak := float64(elapsed) / rate if int64(leak) > 0 { b.Remaining += leak - b.UpdatedAt = now + b.UpdatedAt = requestTime } if int64(b.Remaining) > b.Burst { @@ -369,7 +369,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request Limit: b.Limit, Remaining: int64(b.Remaining), Status: Status_UNDER_LIMIT, - ResetTime: now + (b.Limit-int64(b.Remaining))*int64(rate), + ResetTime: requestTime + (b.Limit-int64(b.Remaining))*int64(rate), } // TODO: Feature missing: check for Duration change between item/request. @@ -391,7 +391,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request if int64(b.Remaining) == r.Hits { b.Remaining = 0 rl.Remaining = int64(b.Remaining) - rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) + rl.ResetTime = requestTime + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } @@ -417,16 +417,16 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, request b.Remaining -= float64(r.Hits) rl.Remaining = int64(b.Remaining) - rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) + rl.ResetTime = requestTime + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } - return leakyBucketNewItem(ctx, s, c, r, requestTime) + return leakyBucketNewItem(ctx, s, c, r) } // Called by leakyBucket() when adding a new item in the store. 
-func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, requestTime time.Time) (resp *RateLimitResp, err error) { - now := EpochMillis(requestTime) +func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { + requestTime := *r.RequestTime duration := r.Duration rate := float64(duration) / float64(r.Limit) if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { @@ -445,7 +445,7 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, Remaining: float64(r.Burst - r.Hits), Limit: r.Limit, Duration: duration, - UpdatedAt: now, + UpdatedAt: requestTime, Burst: r.Burst, } @@ -453,7 +453,7 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, Status: Status_UNDER_LIMIT, Limit: b.Limit, Remaining: r.Burst - r.Hits, - ResetTime: now + (b.Limit-(r.Burst-r.Hits))*int64(rate), + ResetTime: requestTime + (b.Limit-(r.Burst-r.Hits))*int64(rate), } // Client could be requesting that we start with the bucket OVER_LIMIT @@ -461,12 +461,12 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, metricOverLimitCounter.Add(1) rl.Status = Status_OVER_LIMIT rl.Remaining = 0 - rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) + rl.ResetTime = requestTime + (rl.Limit-rl.Remaining)*int64(rate) b.Remaining = 0 } item := &CacheItem{ - ExpireAt: now + duration, + ExpireAt: requestTime + duration, Algorithm: r.Algorithm, Key: r.HashKey(), Value: &b, diff --git a/functional_test.go b/functional_test.go index 526f7209..e7b66ac1 100644 --- a/functional_test.go +++ b/functional_test.go @@ -34,6 +34,7 @@ import ( guber "github.com/mailgun/gubernator/v2" "github.com/mailgun/gubernator/v2/cluster" "github.com/mailgun/holster/v4/clock" + "github.com/mailgun/holster/v4/syncutil" "github.com/mailgun/holster/v4/testutil" "github.com/prometheus/common/expfmt" "github.com/prometheus/common/model" @@ -973,22 +974,22 @@ func TestMissingFields(t *testing.T) { } func TestGlobalRateLimits(t *testing.T) { - const ( - name = "test_global" - key = "account:12345" - ) - + name := t.Name() + key := randomKey() + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) + var resetTime int64 - sendHit := func(client guber.V1Client, status guber.Status, hits, expectRemaining, expectResetTime int64) int64 { + sendHit := func(client guber.V1Client, status guber.Status, hits, remain int64) { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { - Name: "test_global", - UniqueKey: "account:12345", + Name: name, + UniqueKey: key, Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 3, @@ -1000,19 +1001,27 @@ func TestGlobalRateLimits(t *testing.T) { require.NoError(t, err) item := resp.Responses[0] assert.Equal(t, "", item.Error) - assert.Equal(t, expectRemaining, item.Remaining) + assert.Equal(t, remain, item.Remaining) assert.Equal(t, status, item.Status) assert.Equal(t, int64(5), item.Limit) - if expectResetTime != 0 { - assert.Equal(t, expectResetTime, item.ResetTime) + + // ResetTime should not change during test. 
+ if resetTime == 0 { + resetTime = item.ResetTime } - return item.ResetTime + assert.Equal(t, resetTime, item.ResetTime) + + // ensure that we have a canonical host + assert.NotEmpty(t, item.Metadata["owner"]) } + + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) + // Our first hit should create the request on the peer and queue for async forward - _ = sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1, 4, 0) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1, 4) // Our second should be processed as if we own it since the async forward hasn't occurred yet - _ = sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 2, 2, 0) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 2, 2) testutil.UntilPass(t, 20, clock.Millisecond*200, func(t testutil.TestingT) { // Inspect peers metrics, ensure the peer sent the global rate limit to the owner @@ -1021,44 +1030,36 @@ func TestGlobalRateLimits(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, int(m.Value)) }) - owner, err := cluster.FindOwningDaemon(name, key) - require.NoError(t, err) - // Get the ResetTime from owner. - expectResetTime := sendHit(owner.MustClient(), guber.Status_UNDER_LIMIT, 0, 2, 0) require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) // Check different peers, they should have gotten the broadcast from the owner - sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, 0, 2, expectResetTime) - sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 0, 2, expectResetTime) + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) // Non owning peer should calculate the rate limit remaining before forwarding // to the owner. - sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 2, 0, expectResetTime) + sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 2, 0) require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) - sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0, expectResetTime) + sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0) } // Ensure global broadcast updates all peers when GetRateLimits is called on // either owner or non-owner peer. func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { ctx := context.Background() - const name = "test_global" - key := fmt.Sprintf("key:%016x", rand.Int()) + name := t.Name() + key := randomKey() // Determine owner and non-owner peers. - ownerPeerInfo, err := cluster.FindOwningPeer(name, key) + owner, err := cluster.FindOwningDaemon(name, key) require.NoError(t, err) - ownerDaemon, err := cluster.FindOwningDaemon(name, key) + // ownerAddr := owner.ownerPeerInfo.GRPCAddress + peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - owner := ownerPeerInfo.GRPCAddress - nonOwner := cluster.PeerAt(0).GRPCAddress - if nonOwner == owner { - nonOwner = cluster.PeerAt(1).GRPCAddress - } - require.NotEqual(t, owner, nonOwner) + nonOwner := peers[0] // Connect to owner and non-owner peers in round robin. dialOpts := []grpc.DialOption{ @@ -1066,22 +1067,22 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), } - address := fmt.Sprintf("static:///%s,%s", owner, nonOwner) + address := fmt.Sprintf("static:///%s,%s", owner.PeerInfo.GRPCAddress, nonOwner.PeerInfo.GRPCAddress) conn, err := grpc.DialContext(ctx, address, dialOpts...) 
require.NoError(t, err) client := guber.NewV1Client(conn) - sendHit := func(status guber.Status, i int) { - ctx, cancel := context.WithTimeout(ctx, 10*clock.Second) + sendHit := func(client guber.V1Client, status guber.Status, i int) { + ctx, cancel := context.WithTimeout(context.Background(), 10*clock.Second) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { Name: name, UniqueKey: key, - Algorithm: guber.Algorithm_LEAKY_BUCKET, + Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, - Duration: guber.Minute * 5, + Duration: 5 * guber.Minute, Hits: 1, Limit: 2, }, @@ -1089,319 +1090,73 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { }) require.NoError(t, err, i) item := resp.Responses[0] - assert.Equal(t, "", item.GetError(), fmt.Sprintf("mismatch error, iteration %d", i)) - assert.Equal(t, status, item.GetStatus(), fmt.Sprintf("mismatch status, iteration %d", i)) + assert.Equal(t, "", item.Error, fmt.Sprintf("unexpected error, iteration %d", i)) + assert.Equal(t, status, item.Status, fmt.Sprintf("mismatch status, iteration %d", i)) } + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) + // Send two hits that should be processed by the owner and non-owner and // deplete the limit consistently. - sendHit(guber.Status_UNDER_LIMIT, 1) - sendHit(guber.Status_UNDER_LIMIT, 2) - require.NoError(t, waitForBroadcast(clock.Second*3, ownerDaemon, 1)) + sendHit(client, guber.Status_UNDER_LIMIT, 1) + sendHit(client, guber.Status_UNDER_LIMIT, 2) + require.NoError(t, waitForBroadcast(3*clock.Second, owner, 1)) // All successive hits should return OVER_LIMIT. for i := 2; i <= 10; i++ { - sendHit(guber.Status_OVER_LIMIT, i) + sendHit(client, guber.Status_OVER_LIMIT, i) } } func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { - const ( - name = "test_global_token_limit" - key = "account:12345" - ) - - peers, err := cluster.ListNonOwningDaemons(name, key) - require.NoError(t, err) - - sendHit := func(expectedStatus guber.Status, hits int64) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) - defer cancel() - resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ - { - Name: name, - UniqueKey: key, - Algorithm: guber.Algorithm_TOKEN_BUCKET, - Behavior: guber.Behavior_GLOBAL, - Duration: guber.Minute * 5, - Hits: hits, - Limit: 2, - }, - }, - }) - assert.NoError(t, err) - assert.Equal(t, "", resp.Responses[0].GetError()) - assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) - } + name := t.Name() + key := randomKey() owner, err := cluster.FindOwningDaemon(name, key) require.NoError(t, err) - - // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining - sendHit(guber.Status_UNDER_LIMIT, 1) - sendHit(guber.Status_UNDER_LIMIT, 1) - // Wait for the broadcast from the owner to the peer - require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) - // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner - // to respond with OVER_LIMIT. 
- sendHit(guber.Status_OVER_LIMIT, 1) - // Wait for the broadcast from the owner to the peer - require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) - // The status should still be OVER_LIMIT - sendHit(guber.Status_OVER_LIMIT, 0) -} - -func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { - const ( - name = "test_global_token_limit_leaky" - key = "account:12345" - ) - peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + sendHit := func(expectedStatus guber.Status, hits, expectedRemaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), 10*clock.Second) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ - { - Name: name, - UniqueKey: key, - Algorithm: guber.Algorithm_LEAKY_BUCKET, - Behavior: guber.Behavior_GLOBAL, - Duration: guber.Minute * 5, - Hits: hits, - Limit: 2, - }, - }, - }) - assert.NoError(t, err) - assert.Equal(t, "", resp.Responses[0].GetError()) - assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) - } - owner, err := cluster.FindOwningDaemon(name, key) - require.NoError(t, err) - - // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining - sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) - sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) - // Wait for the broadcast from the owner to the peers - require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) - // Ask a different peer if the status is over the limit - sendHit(peers[1].MustClient(), guber.Status_OVER_LIMIT, 1) -} - -func TestGlobalRequestMoreThanAvailable(t *testing.T) { - const ( - name = "test_global_more_than_available" - key = "account:123456" - ) - - peers, err := cluster.ListNonOwningDaemons(name, key) - require.NoError(t, err) - - sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) - defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ - { - Name: name, - UniqueKey: key, - Algorithm: guber.Algorithm_LEAKY_BUCKET, - Behavior: guber.Behavior_GLOBAL, - Duration: guber.Minute * 1_000, - Hits: hits, - Limit: 100, - }, - }, - }) - assert.NoError(t, err) - assert.Equal(t, "", resp.Responses[0].GetError()) - assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) - } - owner, err := cluster.FindOwningDaemon(name, key) - require.NoError(t, err) - - prev, err := getBroadcastCount(owner) - require.NoError(t, err) - - // Ensure GRPC has connections to each peer before we start, as we want - // the actual test requests to happen quite fast. - for _, p := range peers { - sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 0, 100) - } - - // Send a request for 50 hits from each non owning peer in the cluster. These requests - // will be queued and sent to the owner as accumulated hits. As a result of the async nature - // of `Behavior_GLOBAL` rate limit requests spread across peers like this will be allowed to - // over-consume their resource within the rate limit window until the owner is updated and - // a broadcast to all peers is received. 
- // - // The maximum number of resources that can be over-consumed can be calculated by multiplying - // the remainder by the number of peers in the cluster. For example: If you have a remainder of 100 - // and a cluster of 10 instances, then the maximum over-consumed resource is 1,000. If you need - // a more accurate remaining calculation, and wish to avoid over consuming a resource, then do - // not use `Behavior_GLOBAL`. - for _, p := range peers { - sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) - } - - // Wait for the broadcast from the owner to the peers - require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) - - // We should be over the limit - sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) -} - -func TestGlobalNegativeHits(t *testing.T) { - const ( - name = "test_global_negative_hits" - key = "account:12345" - ) - - peers, err := cluster.ListNonOwningDaemons(name, key) - require.NoError(t, err) - - sendHit := func(client guber.V1Client, status guber.Status, hits int64, remaining int64) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) - defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { Name: name, UniqueKey: key, Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, - Duration: guber.Minute * 100, + Duration: 5 * guber.Minute, Hits: hits, Limit: 2, }, }, }) assert.NoError(t, err) - assert.Equal(t, "", resp.Responses[0].GetError()) - assert.Equal(t, status, resp.Responses[0].GetStatus()) - assert.Equal(t, remaining, resp.Responses[0].Remaining) - } - owner, err := cluster.FindOwningDaemon(name, key) - require.NoError(t, err) - prev, err := getBroadcastCount(owner) - require.NoError(t, err) - - // Send a negative hit on a rate limit with no hits - sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, -1, 3) - - // Wait for the negative remaining to propagate - require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) - - // Send another negative hit to a different peer - sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, -1, 4) - - require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) - - // Should have 4 in the remainder - sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 4, 0) - - require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+3)) - - sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 0, 0) -} - -func TestGlobalResetRemaining(t *testing.T) { - const ( - name = "test_global_reset" - key = "account:123456" - ) - - peers, err := cluster.ListNonOwningDaemons(name, key) - require.NoError(t, err) - - sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) - defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ - { - Name: name, - UniqueKey: key, - Algorithm: guber.Algorithm_LEAKY_BUCKET, - Behavior: guber.Behavior_GLOBAL, - Duration: guber.Minute * 1_000, - Hits: hits, - Limit: 100, - }, - }, - }) - assert.NoError(t, err) - assert.Equal(t, "", resp.Responses[0].GetError()) - assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) - assert.Equal(t, remaining, resp.Responses[0].Remaining) - } - owner, err := cluster.FindOwningDaemon(name, key) - require.NoError(t, err) - prev, err := 
getBroadcastCount(owner) - require.NoError(t, err) - - for _, p := range peers { - sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + item := resp.Responses[0] + assert.Equal(t, "", item.Error, "unexpected error") + assert.Equal(t, expectedStatus, item.Status, "mismatch status") + assert.Equal(t, expectedRemaining, item.Remaining, "mismatch remaining") } - // Wait for the broadcast from the owner to the peers - require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) - - // We should be over the limit and remaining should be zero - sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) - // Now reset the remaining - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) - defer cancel() - resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ - { - Name: name, - UniqueKey: key, - Algorithm: guber.Algorithm_LEAKY_BUCKET, - Behavior: guber.Behavior_GLOBAL | guber.Behavior_RESET_REMAINING, - Duration: guber.Minute * 1_000, - Hits: 0, - Limit: 100, - }, - }, - }) - require.NoError(t, err) - assert.NotEqual(t, 100, resp.Responses[0].Remaining) + // Send two hits that should be processed by the owner and the broadcast to + // peer, depleting the remaining. + sendHit(guber.Status_UNDER_LIMIT, 1, 1) + sendHit(guber.Status_UNDER_LIMIT, 1, 0) - // Wait for the reset to propagate. - require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + // Wait for the broadcast from the owner to the peer + require.NoError(t, waitForBroadcast(3*clock.Second, owner, 1)) - // Check a different peer to ensure remaining has been reset - resp, err = peers[1].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ - { - Name: name, - UniqueKey: key, - Algorithm: guber.Algorithm_LEAKY_BUCKET, - Behavior: guber.Behavior_GLOBAL, - Duration: guber.Minute * 1_000, - Hits: 0, - Limit: 100, - }, - }, - }) - require.NoError(t, err) - assert.NotEqual(t, 100, resp.Responses[0].Remaining) + // Since the remainder is 0, the peer should return OVER_LIMIT on next hit. + sendHit(guber.Status_OVER_LIMIT, 1, 0) -} + // Wait for the broadcast from the owner to the peer. + require.NoError(t, waitForBroadcast(3*clock.Second, owner, 2)) -func getMetricRequest(url string, name string) (*model.Sample, error) { - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - return getMetric(resp.Body, name) + // The status should still be OVER_LIMIT. 
+ sendHit(guber.Status_OVER_LIMIT, 0, 0) } func TestChangeLimit(t *testing.T) { @@ -1622,10 +1377,11 @@ func TestHealthCheck(t *testing.T) { testutil.UntilPass(t, 20, clock.Millisecond*300, func(t testutil.TestingT) { // Check the health again to get back the connection error - healthResp, err := client.HealthCheck(context.Background(), &guber.HealthCheckReq{}) - if !assert.NoError(t, err) { + healthResp, err = client.HealthCheck(context.Background(), &guber.HealthCheckReq{}) + if assert.Nil(t, err) { return } + assert.Equal(t, "unhealthy", healthResp.GetStatus()) assert.Contains(t, healthResp.GetMessage(), "connect: connection refused") }) @@ -1634,25 +1390,9 @@ func TestHealthCheck(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*15) defer cancel() require.NoError(t, cluster.Restart(ctx)) - - // wait for every peer instance to come back online - numPeers := int32(len(cluster.GetPeers())) - for _, peer := range cluster.GetPeers() { - peerClient, err := guber.DialV1Server(peer.GRPCAddress, nil) - require.NoError(t, err) - testutil.UntilPass(t, 10, 300*clock.Millisecond, func(t testutil.TestingT) { - healthResp, err := peerClient.HealthCheck(context.Background(), &guber.HealthCheckReq{}) - if !assert.NoError(t, err) { - return - } - assert.Equal(t, "healthy", healthResp.Status) - assert.Equal(t, numPeers, healthResp.PeerCount) - }) - } } func TestLeakyBucketDivBug(t *testing.T) { - // Freeze time so we don't leak during the test defer clock.Freeze(clock.Now()).Unfreeze() client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) @@ -1801,142 +1541,6 @@ func TestGetPeerRateLimits(t *testing.T) { // TODO: Add a test for sending no rate limits RateLimitReqList.RateLimits = nil -func getMetric(in io.Reader, name string) (*model.Sample, error) { - dec := expfmt.SampleDecoder{ - Dec: expfmt.NewDecoder(in, expfmt.FmtText), - Opts: &expfmt.DecodeOptions{ - Timestamp: model.Now(), - }, - } - - var all model.Vector - for { - var smpls model.Vector - err := dec.Decode(&smpls) - if err == io.EOF { - break - } - if err != nil { - return nil, err - } - all = append(all, smpls...) - } - - for _, s := range all { - if strings.Contains(s.Metric.String(), name) { - return s, nil - } - } - return nil, nil -} - -// getBroadcastCount returns the current broadcast count for use with waitForBroadcast() -// TODO: Replace this with something else, we can call and reset via HTTP/GRPC calls in gubernator v3 -func getBroadcastCount(d *guber.Daemon) (int, error) { - m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), - "gubernator_broadcast_duration_count") - if err != nil { - return 0, err - } - - return int(m.Value), nil -} - -// waitForBroadcast waits until the broadcast count for the daemon changes to -// the expected value. Returns an error if the expected value is not found -// before the context is cancelled. -func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - for { - m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), - "gubernator_broadcast_duration_count") - if err != nil { - return err - } - - // It's possible a broadcast occurred twice if waiting for multiple peer to - // forward updates to the owner. 
- if int(m.Value) >= expect { - return nil - } - - select { - case <-clock.After(time.Millisecond * 100): - case <-ctx.Done(): - return ctx.Err() - } - } -} - -// waitForUpdate waits until the global update count for the daemon changes to -// the expected value. Returns an error if the expected value is not found -// before the context is cancelled. -func waitForUpdate(timeout clock.Duration, d *guber.Daemon, expect int) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - for { - m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), - "gubernator_global_send_duration_count") - if err != nil { - return err - } - - // It's possible a broadcast occurred twice if waiting for multiple peer to - // forward updates to the owner. - if int(m.Value) >= expect { - return nil - } - - select { - case <-clock.After(time.Millisecond * 100): - case <-ctx.Done(): - return ctx.Err() - } - } -} - -func getMetricValue(t *testing.T, d *guber.Daemon, name string) float64 { - m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), - name) - require.NoError(t, err) - if m == nil { - return 0 - } - return float64(m.Value) -} - -// Get metric counter values on each peer. -func getPeerCounters(t *testing.T, peers []*guber.Daemon, name string) map[string]int { - counters := make(map[string]int) - for _, peer := range peers { - counters[peer.InstanceID] = int(getMetricValue(t, peer, name)) - } - return counters -} - -func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitReq, expectStatus guber.Status, expectRemaining int64) { - if req.Hits != 0 { - t.Logf("Sending %d hits to peer %s", req.Hits, d.InstanceID) - } - client := d.MustClient() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{req}, - }) - require.NoError(t, err) - item := resp.Responses[0] - assert.Equal(t, "", item.Error) - if expectRemaining >= 0 { - assert.Equal(t, expectRemaining, item.Remaining) - } - assert.Equal(t, expectStatus, item.Status) - assert.Equal(t, req.Limit, item.Limit) -} - func TestGlobalBehavior(t *testing.T) { const limit = 1000 broadcastTimeout := 400 * time.Millisecond @@ -1972,6 +1576,8 @@ func TestGlobalBehavior(t *testing.T) { require.NoError(t, err) t.Logf("Owner peer: %s", owner.InstanceID) + require.NoError(t, waitForIdle(1*time.Minute, cluster.GetDaemons()...)) + broadcastCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_broadcast_duration_count") updateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_global_send_duration_count") upgCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") @@ -2088,6 +1694,8 @@ func TestGlobalBehavior(t *testing.T) { require.NoError(t, err) t.Logf("Owner peer: %s", owner.InstanceID) + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) + broadcastCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_broadcast_duration_count") updateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_global_send_duration_count") upgCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") @@ -2189,7 +1797,6 @@ func TestGlobalBehavior(t *testing.T) { } }) - // Distribute hits across all 
non-owner peers. t.Run("Distributed hits", func(t *testing.T) { testCases := []struct { Name string @@ -2216,6 +1823,8 @@ func TestGlobalBehavior(t *testing.T) { } t.Logf("Owner peer: %s", owner.InstanceID) + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) + broadcastCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_broadcast_duration_count") updateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_global_send_duration_count") upgCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") @@ -2338,3 +1947,225 @@ func TestGlobalBehavior(t *testing.T) { } }) } + +// Request metrics and parse into map. +// Optionally pass names to filter metrics by name. +func getMetrics(HTTPAddr string, names ...string) (map[string]*model.Sample, error) { + url := fmt.Sprintf("http://%s/metrics", HTTPAddr) + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP error requesting metrics: %s", resp.Status) + } + decoder := expfmt.SampleDecoder{ + Dec: expfmt.NewDecoder(resp.Body, expfmt.FmtText), + Opts: &expfmt.DecodeOptions{ + Timestamp: model.Now(), + }, + } + nameSet := make(map[string]struct{}) + for _, name := range names { + nameSet[name] = struct{}{} + } + metrics := make(map[string]*model.Sample) + + for { + var smpls model.Vector + err := decoder.Decode(&smpls) + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + for _, smpl := range smpls { + name := smpl.Metric.String() + if _, ok := nameSet[name]; ok || len(nameSet) == 0 { + metrics[name] = smpl + } + } + } + + return metrics, nil +} + +func getMetricRequest(url string, name string) (*model.Sample, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return getMetric(resp.Body, name) +} + +func getMetric(in io.Reader, name string) (*model.Sample, error) { + dec := expfmt.SampleDecoder{ + Dec: expfmt.NewDecoder(in, expfmt.FmtText), + Opts: &expfmt.DecodeOptions{ + Timestamp: model.Now(), + }, + } + + var all model.Vector + for { + var smpls model.Vector + err := dec.Decode(&smpls) + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + all = append(all, smpls...) + } + + for _, s := range all { + if strings.Contains(s.Metric.String(), name) { + return s, nil + } + } + return nil, nil +} + +// waitForBroadcast waits until the broadcast count for the daemon changes to +// at least the expected value and the broadcast queue is empty. +// Returns an error if timeout waiting for conditions to be met. +func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + metrics, err := getMetrics(d.Config().HTTPListenAddress, + "gubernator_broadcast_duration_count", "gubernator_global_queue_length") + if err != nil { + return err + } + gbdc := metrics["gubernator_broadcast_duration_count"] + ggql := metrics["gubernator_global_queue_length"] + + // It's possible a broadcast occurred twice if waiting for multiple + // peers to forward updates to non-owners. 
+ if int(gbdc.Value) >= expect && ggql.Value == 0 { + return nil + } + + select { + case <-clock.After(100 * clock.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// waitForUpdate waits until the global hits update count for the daemon +// changes to at least the expected value and the global update queue is empty. +// Returns an error if timeout waiting for conditions to be met. +func waitForUpdate(timeout clock.Duration, d *guber.Daemon, expect int) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + metrics, err := getMetrics(d.Config().HTTPListenAddress, + "gubernator_global_send_duration_count", "gubernator_global_send_queue_length") + if err != nil { + return err + } + gsdc := metrics["gubernator_global_send_duration_count"] + gsql := metrics["gubernator_global_send_queue_length"] + + // It's possible a hit occurred twice if waiting for multiple peers to + // forward updates to the owner. + if int(gsdc.Value) >= expect && gsql.Value == 0 { + return nil + } + + select { + case <-clock.After(100 * clock.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// waitForIdle waits until both global broadcast and global hits queues are +// empty. +func waitForIdle(timeout clock.Duration, daemons ...*guber.Daemon) error { + var wg syncutil.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for _, d := range daemons { + wg.Run(func(raw any) error { + d := raw.(*guber.Daemon) + for { + metrics, err := getMetrics(d.Config().HTTPListenAddress, + "gubernator_global_queue_length", "gubernator_global_send_queue_length") + if err != nil { + return err + } + ggql := metrics["gubernator_global_queue_length"] + gsql := metrics["gubernator_global_send_queue_length"] + + if ggql.Value == 0 && gsql.Value == 0 { + return nil + } + + select { + case <-clock.After(100 * clock.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } + }, d) + } + errs := wg.Wait() + if len(errs) > 0 { + return errs[0] + } + return nil +} + +func getMetricValue(t *testing.T, d *guber.Daemon, name string) float64 { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + name) + require.NoError(t, err) + if m == nil { + return 0 + } + return float64(m.Value) +} + +// Get metric counter values on each peer. 
+func getPeerCounters(t *testing.T, peers []*guber.Daemon, name string) map[string]int { + counters := make(map[string]int) + for _, peer := range peers { + counters[peer.InstanceID] = int(getMetricValue(t, peer, name)) + } + return counters +} + +func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitReq, expectStatus guber.Status, expectRemaining int64) { + if req.Hits != 0 { + t.Logf("Sending %d hits to peer %s", req.Hits, d.InstanceID) + } + client := d.MustClient() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{req}, + }) + require.NoError(t, err) + item := resp.Responses[0] + assert.Equal(t, "", item.Error) + if expectRemaining >= 0 { + assert.Equal(t, expectRemaining, item.Remaining) + } + assert.Equal(t, expectStatus, item.Status) + assert.Equal(t, req.Limit, item.Limit) +} + +func randomKey() string { + return fmt.Sprintf("%016x", rand.Int()) +} diff --git a/global.go b/global.go index 5af33301..47703f6e 100644 --- a/global.go +++ b/global.go @@ -18,7 +18,6 @@ package gubernator import ( "context" - "time" "github.com/mailgun/holster/v4/syncutil" "github.com/pkg/errors" @@ -29,28 +28,23 @@ import ( // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. type globalManager struct { - hitsQueue chan *RateLimitReq - broadcastQueue chan broadcastItem - wg syncutil.WaitGroup - conf BehaviorConfig - log FieldLogger - instance *V1Instance // TODO circular import? V1Instance also holds a reference to globalManager - metricGlobalSendDuration prometheus.Summary - metricBroadcastDuration prometheus.Summary - metricBroadcastCounter *prometheus.CounterVec - metricGlobalQueueLength prometheus.Gauge -} - -type broadcastItem struct { - Request *RateLimitReq - RequestTime time.Time + hitsQueue chan *RateLimitReq + broadcastQueue chan *RateLimitReq + wg syncutil.WaitGroup + conf BehaviorConfig + log FieldLogger + instance *V1Instance // TODO circular import? V1Instance also holds a reference to globalManager + metricGlobalSendDuration prometheus.Summary + metricGlobalSendQueueLength prometheus.Gauge + metricBroadcastDuration prometheus.Summary + metricGlobalQueueLength prometheus.Gauge } func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager { gm := globalManager{ log: instance.log, hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - broadcastQueue: make(chan broadcastItem, conf.GlobalBatchLimit), + broadcastQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), instance: instance, conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ @@ -58,15 +52,15 @@ func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager Help: "The duration of GLOBAL async sends in seconds.", Objectives: map[float64]float64{0.5: 0.05, 0.99: 0.001}, }), + metricGlobalSendQueueLength: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gubernator_global_send_queue_length", + Help: "The count of requests queued up for global broadcast. 
This is only used for GetRateLimit requests using global behavior.", + }), metricBroadcastDuration: prometheus.NewSummary(prometheus.SummaryOpts{ Name: "gubernator_broadcast_duration", Help: "The duration of GLOBAL broadcasts to peers in seconds.", Objectives: map[float64]float64{0.5: 0.05, 0.99: 0.001}, }), - metricBroadcastCounter: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_broadcast_counter", - Help: "The count of broadcasts.", - }, []string{"condition"}), metricGlobalQueueLength: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "gubernator_global_queue_length", Help: "The count of requests queued up for global broadcast. This is only used for GetRateLimit requests using global behavior.", @@ -83,12 +77,9 @@ func (gm *globalManager) QueueHit(r *RateLimitReq) { } } -func (gm *globalManager) QueueUpdate(req *RateLimitReq, requestTime time.Time) { +func (gm *globalManager) QueueUpdate(req *RateLimitReq) { if req.Hits != 0 { - gm.broadcastQueue <- broadcastItem{ - Request: req, - RequestTime: requestTime, - } + gm.broadcastQueue <- req } } @@ -118,11 +109,13 @@ func (gm *globalManager) runAsyncHits() { } else { hits[key] = r } + gm.metricGlobalSendQueueLength.Set(float64(len(hits))) // Send the hits if we reached our batch limit if len(hits) == gm.conf.GlobalBatchLimit { gm.sendHits(hits) hits = make(map[string]*RateLimitReq) + gm.metricGlobalSendQueueLength.Set(0) return true } @@ -136,6 +129,7 @@ func (gm *globalManager) runAsyncHits() { if len(hits) != 0 { gm.sendHits(hits) hits = make(map[string]*RateLimitReq) + gm.metricGlobalSendQueueLength.Set(0) } case <-done: interval.Stop() @@ -198,18 +192,19 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { // and in a periodic frequency determined by GlobalSyncWait. 
func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) - updates := make(map[string]broadcastItem) + updates := make(map[string]*RateLimitReq) gm.wg.Until(func(done chan struct{}) bool { select { case update := <-gm.broadcastQueue: - updates[update.Request.HashKey()] = update + updates[update.HashKey()] = update + gm.metricGlobalQueueLength.Set(float64(len(updates))) // Send the hits if we reached our batch limit if len(updates) >= gm.conf.GlobalBatchLimit { - gm.metricBroadcastCounter.WithLabelValues("queue_full").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]broadcastItem) + updates = make(map[string]*RateLimitReq) + gm.metricGlobalQueueLength.Set(0) return true } @@ -220,13 +215,13 @@ func (gm *globalManager) runBroadcasts() { } case <-interval.C: - if len(updates) != 0 { - gm.metricBroadcastCounter.WithLabelValues("timer").Inc() - gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]broadcastItem) - } else { - gm.metricGlobalQueueLength.Set(0) + if len(updates) == 0 { + break } + gm.broadcastPeers(context.Background(), updates) + updates = make(map[string]*RateLimitReq) + gm.metricGlobalQueueLength.Set(0) + case <-done: interval.Stop() return false @@ -236,7 +231,7 @@ func (gm *globalManager) runBroadcasts() { } // broadcastPeers broadcasts global rate limit statuses to all other peers -func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]broadcastItem) { +func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) { defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration() var req UpdatePeerGlobalsReq @@ -244,19 +239,19 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string] for _, update := range updates { // Get current rate limit state. 
- grlReq := proto.Clone(update.Request).(*RateLimitReq) + grlReq := proto.Clone(update).(*RateLimitReq) grlReq.Hits = 0 - status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq, update.RequestTime) + status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq) if err != nil { gm.log.WithError(err).Error("while retrieving rate limit status") continue } updateReq := &UpdatePeerGlobal{ - Key: update.Request.HashKey(), - Algorithm: update.Request.Algorithm, - Duration: update.Request.Duration, + Key: update.HashKey(), + Algorithm: update.Algorithm, + Duration: update.Duration, Status: status, - RequestTime: EpochMillis(update.RequestTime), + RequestTime: *update.RequestTime, } req.Globals = append(req.Globals, updateReq) } diff --git a/gubernator.go b/gubernator.go index fda9f92a..87a0a04d 100644 --- a/gubernator.go +++ b/gubernator.go @@ -21,7 +21,6 @@ import ( "fmt" "strings" "sync" - "time" "github.com/mailgun/errors" "github.com/mailgun/holster/v4/clock" @@ -188,6 +187,7 @@ func (s *V1Instance) GetRateLimits(ctx context.Context, r *GetRateLimitsReq) (*G "Requests.RateLimits list too large; max size is '%d'", maxBatchSize) } + requestTime := EpochMillis(clock.Now()) resp := GetRateLimitsResp{ Responses: make([]*RateLimitResp, len(r.Requests)), } @@ -200,17 +200,19 @@ func (s *V1Instance) GetRateLimits(ctx context.Context, r *GetRateLimitsReq) (*G var peer *PeerClient var err error - if len(req.UniqueKey) == 0 { + if req.UniqueKey == "" { metricCheckErrorCounter.WithLabelValues("Invalid request").Inc() resp.Responses[i] = &RateLimitResp{Error: "field 'unique_key' cannot be empty"} continue } - - if len(req.Name) == 0 { + if req.Name == "" { metricCheckErrorCounter.WithLabelValues("Invalid request").Inc() resp.Responses[i] = &RateLimitResp{Error: "field 'namespace' cannot be empty"} continue } + if req.RequestTime == nil || *req.RequestTime == 0 { + req.RequestTime = &requestTime + } if ctx.Err() != nil { err = errors.Wrap(ctx.Err(), "Error while iterating request items") @@ -578,21 +580,14 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ defer func() { tracing.EndScope(ctx, err) }() defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getLocalRateLimit")).ObserveDuration() - var requestTime time.Time - if r.RequestTime != nil { - requestTime = time.UnixMilli(*r.RequestTime) - } - if requestTime.IsZero() { - requestTime = clock.Now() - } - resp, err := s.workerPool.GetRateLimit(ctx, r, requestTime) + resp, err := s.workerPool.GetRateLimit(ctx, r) if err != nil { return nil, errors.Wrap(err, "during workerPool.GetRateLimit") } // If global behavior, then broadcast update to all peers. 
if HasBehavior(r.Behavior, Behavior_GLOBAL) { - s.global.QueueUpdate(r, requestTime) + s.global.QueueUpdate(r) } metricGetRateLimitCounter.WithLabelValues("local").Inc() @@ -736,10 +731,10 @@ func (s *V1Instance) Describe(ch chan<- *prometheus.Desc) { metricGetRateLimitCounter.Describe(ch) metricOverLimitCounter.Describe(ch) metricWorkerQueue.Describe(ch) - s.global.metricBroadcastCounter.Describe(ch) s.global.metricBroadcastDuration.Describe(ch) s.global.metricGlobalQueueLength.Describe(ch) s.global.metricGlobalSendDuration.Describe(ch) + s.global.metricGlobalSendQueueLength.Describe(ch) } // Collect fetches metrics from the server for use by prometheus @@ -754,10 +749,10 @@ func (s *V1Instance) Collect(ch chan<- prometheus.Metric) { metricGetRateLimitCounter.Collect(ch) metricOverLimitCounter.Collect(ch) metricWorkerQueue.Collect(ch) - s.global.metricBroadcastCounter.Collect(ch) s.global.metricBroadcastDuration.Collect(ch) s.global.metricGlobalQueueLength.Collect(ch) s.global.metricGlobalSendDuration.Collect(ch) + s.global.metricGlobalSendQueueLength.Collect(ch) } // HasBehavior returns true if the provided behavior is set diff --git a/peer_client.go b/peer_client.go index 2f3c0905..794ebea7 100644 --- a/peer_client.go +++ b/peer_client.go @@ -22,7 +22,6 @@ import ( "fmt" "sync" "sync/atomic" - "time" "github.com/mailgun/holster/v4/clock" "github.com/mailgun/holster/v4/collections" @@ -70,7 +69,6 @@ type request struct { request *RateLimitReq resp chan *response ctx context.Context - requestTime time.Time } type PeerConfig struct { diff --git a/workers.go b/workers.go index 04557f76..76fa1e31 100644 --- a/workers.go +++ b/workers.go @@ -42,7 +42,6 @@ import ( "strconv" "sync" "sync/atomic" - "time" "github.com/OneOfOne/xxhash" "github.com/mailgun/holster/v4/errors" @@ -200,7 +199,7 @@ func (p *WorkerPool) dispatch(worker *Worker) { } resp := new(response) - resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, req.requestTime, worker.cache) + resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, worker.cache) select { case req.resp <- resp: // Success. @@ -259,7 +258,7 @@ func (p *WorkerPool) dispatch(worker *Worker) { } // GetRateLimit sends a GetRateLimit request to worker pool. -func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, requestTime time.Time) (*RateLimitResp, error) { +func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq) (*RateLimitResp, error) { // Delegate request to assigned channel based on request key. worker := p.getWorker(rlRequest.HashKey()) queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name) @@ -269,7 +268,6 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, ctx: ctx, resp: make(chan *response, 1), request: rlRequest, - requestTime: requestTime, } // Send request. @@ -291,14 +289,14 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, } // Handle request received by worker. 
-func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, requestTime time.Time, cache Cache) (*RateLimitResp, error) { +func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, cache Cache) (*RateLimitResp, error) { defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration() var rlResponse *RateLimitResp var err error switch req.Algorithm { case Algorithm_TOKEN_BUCKET: - rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, requestTime) + rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req) if err != nil { msg := "Error in tokenBucket" countError(err, msg) @@ -307,7 +305,7 @@ func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, } case Algorithm_LEAKY_BUCKET: - rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, requestTime) + rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req) if err != nil { msg := "Error in leakyBucket" countError(err, msg)
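
For reference, a minimal sketch (not part of this patch) of how request time now travels with the request itself after this refactor: callers may set RateLimitReq.RequestTime to epoch milliseconds, V1Instance.GetRateLimits fills it from clock.Now() when the field is nil or zero, and the token/leaky bucket algorithms and the global broadcast read *r.RequestTime directly instead of receiving a separate time.Time argument. The limit name, key, and numbers below are illustrative assumptions only; guber.EpochMillis, clock, and the RateLimitReq fields are the ones used elsewhere in this diff.

	package main

	import (
		guber "github.com/mailgun/gubernator/v2"
		"github.com/mailgun/holster/v4/clock"
	)

	func main() {
		// Epoch milliseconds, matching what GetRateLimits fills in by default.
		requestTime := guber.EpochMillis(clock.Now())

		// Hypothetical request for illustration; RequestTime is optional and is
		// defaulted server-side when nil or zero.
		_ = &guber.RateLimitReq{
			Name:        "example_limit",
			UniqueKey:   "account:example",
			Algorithm:   guber.Algorithm_TOKEN_BUCKET,
			Behavior:    guber.Behavior_GLOBAL,
			Duration:    guber.Minute,
			Hits:        1,
			Limit:       10,
			RequestTime: &requestTime,
		}
	}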