From fd28875c349c0de63dbb5c35fe0d6fb7437b7247 Mon Sep 17 00:00:00 2001
From: Shawn Poulson
Date: Thu, 7 Mar 2024 10:01:24 -0500
Subject: [PATCH] Add tests from PR219.

---
 functional_test.go | 427 ++++++++++++++++++++++++++++++---------------
 1 file changed, 290 insertions(+), 137 deletions(-)

diff --git a/functional_test.go b/functional_test.go
index 2a299387..4381bc82 100644
--- a/functional_test.go
+++ b/functional_test.go
@@ -31,8 +31,6 @@ import (
 	"testing"
 	"time"
 
-	"golang.org/x/exp/maps"
-
 	guber "github.com/mailgun/gubernator/v2"
 	"github.com/mailgun/gubernator/v2/cluster"
 	"github.com/mailgun/holster/v4/clock"
@@ -41,6 +39,9 @@ import (
 	"github.com/prometheus/common/model"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"golang.org/x/exp/maps"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
 	json "google.golang.org/protobuf/encoding/protojson"
 )
 
@@ -71,7 +72,6 @@ func TestMain(m *testing.M) {
 		fmt.Println(err)
 		os.Exit(1)
 	}
-	defer cluster.Stop()
 
 	// Populate peer clients. Avoids data races when goroutines conflict trying
 	// to instantiate client singletons.
@@ -79,7 +79,11 @@ func TestMain(m *testing.M) {
 		_ = peer.MustClient()
 	}
 
-	os.Exit(m.Run())
+	code := m.Run()
+	cluster.Stop()
+
+	// os.Exit doesn't run deferred functions
+	os.Exit(code)
 }
 
 func TestOverTheLimit(t *testing.T) {
@@ -956,6 +960,151 @@ func TestGlobalRateLimits(t *testing.T) {
 	})
 }
 
+// Ensure global broadcast updates all peers when GetRateLimits is called on
+// either owner or non-owner peer.
+func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) {
+	ctx := context.Background()
+	name := t.Name()
+	key := randomKey()
+
+	// Determine owner and non-owner peers.
+	ownerPeerInfo, err := cluster.FindOwningPeer(name, key)
+	require.NoError(t, err)
+	ownerDaemon, err := cluster.FindOwningDaemon(name, key)
+	require.NoError(t, err)
+	owner := ownerPeerInfo.GRPCAddress
+	nonOwner := cluster.PeerAt(0).GRPCAddress
+	if nonOwner == owner {
+		nonOwner = cluster.PeerAt(1).GRPCAddress
+	}
+	require.NotEqual(t, owner, nonOwner)
+
+	// Connect to owner and non-owner peers in round robin.
+	dialOpts := []grpc.DialOption{
+		grpc.WithResolvers(guber.NewStaticBuilder()),
+		grpc.WithTransportCredentials(insecure.NewCredentials()),
+		grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`),
+	}
+	address := fmt.Sprintf("static:///%s,%s", owner, nonOwner)
+	conn, err := grpc.DialContext(ctx, address, dialOpts...)
+	require.NoError(t, err)
+	client := guber.NewV1Client(conn)
+
+	sendHit := func(status guber.Status, i int) {
+		ctx, cancel := context.WithTimeout(ctx, 10*clock.Second)
+		defer cancel()
+		resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{
+			Requests: []*guber.RateLimitReq{
+				{
+					Name:      name,
+					UniqueKey: key,
+					Algorithm: guber.Algorithm_LEAKY_BUCKET,
+					Behavior:  guber.Behavior_GLOBAL,
+					Duration:  guber.Minute * 5,
+					Hits:      1,
+					Limit:     2,
+				},
+			},
+		})
+		require.NoError(t, err, i)
+		item := resp.Responses[0]
+		assert.Equal(t, "", item.GetError(), fmt.Sprintf("mismatch error, iteration %d", i))
+		assert.Equal(t, status, item.GetStatus(), fmt.Sprintf("mismatch status, iteration %d", i))
+	}
+
+	// Send two hits that should be processed by the owner and non-owner and
+	// deplete the limit consistently.
+	sendHit(guber.Status_UNDER_LIMIT, 1)
+	sendHit(guber.Status_UNDER_LIMIT, 2)
+	require.NoError(t, waitForBroadcast(clock.Second*3, ownerDaemon, 1))
+
+	// All successive hits should return OVER_LIMIT.
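+	// The limit of 2 is exhausted and the owner's state has been broadcast, so
+	// each peer can reject these hits locally without consulting the owner.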
+	for i := 2; i <= 10; i++ {
+		sendHit(guber.Status_OVER_LIMIT, i)
+	}
+}
+
+func TestGlobalRateLimitsPeerOverLimit(t *testing.T) {
+	name := t.Name()
+	key := randomKey()
+	peers, err := cluster.ListNonOwningDaemons(name, key)
+	require.NoError(t, err)
+
+	sendHit := func(expectedStatus guber.Status, hits int64) {
+		ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10)
+		defer cancel()
+		resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{
+			Requests: []*guber.RateLimitReq{
+				{
+					Name:      name,
+					UniqueKey: key,
+					Algorithm: guber.Algorithm_TOKEN_BUCKET,
+					Behavior:  guber.Behavior_GLOBAL,
+					Duration:  guber.Minute * 5,
+					Hits:      hits,
+					Limit:     2,
+				},
+			},
+		})
+		assert.NoError(t, err)
+		assert.Equal(t, "", resp.Responses[0].GetError())
+		assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus())
+	}
+	owner, err := cluster.FindOwningDaemon(name, key)
+	require.NoError(t, err)
+
+	// Send two hits that are processed by the owner and broadcast to the peer, depleting the remaining limit.
+	sendHit(guber.Status_UNDER_LIMIT, 1)
+	sendHit(guber.Status_UNDER_LIMIT, 1)
+	// Wait for the broadcast from the owner to the peer.
+	require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1))
+	// Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner
+	// to respond with OVER_LIMIT.
+	sendHit(guber.Status_OVER_LIMIT, 1)
+	// Wait for the broadcast from the owner to the peer.
+	require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2))
+	// The status should still be OVER_LIMIT.
+	sendHit(guber.Status_OVER_LIMIT, 0)
+}
+
+func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) {
+	name := t.Name()
+	key := randomKey()
+	peers, err := cluster.ListNonOwningDaemons(name, key)
+	require.NoError(t, err)
+
+	sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64) {
+		ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10)
+		defer cancel()
+		resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{
+			Requests: []*guber.RateLimitReq{
+				{
+					Name:      name,
+					UniqueKey: key,
+					Algorithm: guber.Algorithm_LEAKY_BUCKET,
+					Behavior:  guber.Behavior_GLOBAL,
+					Duration:  guber.Minute * 5,
+					Hits:      hits,
+					Limit:     2,
+				},
+			},
+		})
+		assert.NoError(t, err)
+		assert.Equal(t, "", resp.Responses[0].GetError())
+		assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus())
+	}
+	owner, err := cluster.FindOwningDaemon(name, key)
+	require.NoError(t, err)
+
+	// Send two hits that are processed by the owner and broadcast to the peer, depleting the remaining limit.
+	sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1)
+	sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1)
+	// Wait for the broadcast from the owner to the peers.
+	require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1))
+	// Ask a different peer whether the status is over the limit.
+	sendHit(peers[1].MustClient(), guber.Status_OVER_LIMIT, 1)
+}
+
 func getMetricRequest(url string, name string) (*model.Sample, error) {
 	resp, err := http.Get(url)
 	if err != nil {
@@ -1346,136 +1495,6 @@ func TestGetPeerRateLimits(t *testing.T) {
 	// TODO: Add a test for sending no rate limits RateLimitReqList.RateLimits = nil
 
-func getMetric(in io.Reader, name string) (*model.Sample, error) {
-	dec := expfmt.SampleDecoder{
-		Dec: expfmt.NewDecoder(in, expfmt.FmtText),
-		Opts: &expfmt.DecodeOptions{
-			Timestamp: model.Now(),
-		},
-	}
-
-	var all model.Vector
-	for {
-		var smpls model.Vector
-		err := dec.Decode(&smpls)
-		if err == io.EOF {
-			break
-		}
-		if err != nil {
-			return nil, err
-		}
-		all = append(all, smpls...)
-	}
-
-	for _, s := range all {
-		if strings.Contains(s.Metric.String(), name) {
-			return s, nil
-		}
-	}
-	return nil, nil
-}
-
-// waitForBroadcast waits until the broadcast count for the daemon changes to
-// the expected value. Returns an error if the expected value is not found
-// before the context is cancelled.
-func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error {
-	// fmt.Printf("waitForBroadcast() peer: %s\n", d.InstanceID)
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	defer cancel()
-
-	for {
-		m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress),
-			"gubernator_broadcast_duration_count")
-		if err != nil {
-			return err
-		}
-		// fmt.Printf("gubernator_broadcast_duration_count: %f\n", m.Value)
-
-		// It's possible a broadcast occurred twice if waiting for multiple peer to
-		// forward updates to the owner.
-		if int(m.Value) >= expect {
-			// Give the nodes some time to process the broadcasts
-			// clock.Sleep(clock.Millisecond * 500)
-			return nil
-		}
-
-		select {
-		case <-clock.After(time.Millisecond * 100):
-		case <-ctx.Done():
-			return ctx.Err()
-		}
-	}
-}
-
-// waitForUpdate waits until the global hits update count for the daemon
-// changes to the expected value. Returns an error if the expected value is not
-// found before the context is cancelled.
-func waitForUpdate(timeout clock.Duration, d *guber.Daemon, expect int) error {
-	// fmt.Printf("waitForUpdate() peer: %s\n", d.InstanceID)
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
-	defer cancel()
-
-	for {
-		m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress),
-			"gubernator_global_send_duration_count")
-		if err != nil {
-			return err
-		}
-		// fmt.Printf("gubernator_global_send_duration_count: %f\n", m.Value)
-
-		// It's possible a broadcast occurred twice if waiting for multiple peer to
-		// forward updates to the owner.
-		if int(m.Value) >= expect {
-			return nil
-		}
-
-		select {
-		case <-clock.After(time.Millisecond * 100):
-		case <-ctx.Done():
-			return ctx.Err()
-		}
-	}
-}
-
-func getMetricValue(t *testing.T, d *guber.Daemon, name string) float64 {
-	m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress),
-		name)
-	require.NoError(t, err)
-	if m == nil {
-		return 0
-	}
-	return float64(m.Value)
-}
-
-// Get metric counter values on each peer.
-func getPeerCounters(t *testing.T, peers []*guber.Daemon, name string) map[string]int {
-	counters := make(map[string]int)
-	for _, peer := range peers {
-		counters[peer.InstanceID] = int(getMetricValue(t, peer, name))
-	}
-	return counters
-}
-
-func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitReq, expectStatus guber.Status, expectRemaining int64) {
-	if req.Hits != 0 {
-		t.Logf("Sending %d hits to peer %s", req.Hits, d.InstanceID)
-	}
-	client := d.MustClient()
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-	defer cancel()
-	resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{
-		Requests: []*guber.RateLimitReq{req},
-	})
-	require.NoError(t, err)
-	item := resp.Responses[0]
-	assert.Equal(t, "", item.Error)
-	if expectRemaining >= 0 {
-		assert.Equal(t, expectRemaining, item.Remaining)
-	}
-	assert.Equal(t, expectStatus, item.Status)
-	assert.Equal(t, req.Limit, item.Limit)
-}
-
 func TestGlobalBehavior(t *testing.T) {
 	const limit = 1000
 	broadcastTimeout := 400 * time.Millisecond
@@ -1597,7 +1616,7 @@ func TestGlobalBehavior(t *testing.T) {
 			assert.Equal(t, expected, gprlCounters2[peer.InstanceID])
 		}
 
-		// Verify all peers report same remaining value.
+		// Verify all peers report a consistent remaining value.
 		for _, peer := range cluster.GetDaemons() {
 			if peer.PeerInfo.DataCenter != cluster.DataCenterNone {
 				continue
@@ -1717,7 +1736,7 @@ func TestGlobalBehavior(t *testing.T) {
 			assert.Equal(t, expected, gprlCounters2[peer.InstanceID], "gprlCounter %s", peer.InstanceID)
 		}
 
-		// Verify all peers report same remaining value.
+		// Verify all peers report a consistent remaining value.
 		for _, peer := range cluster.GetDaemons() {
 			if peer.PeerInfo.DataCenter != cluster.DataCenterNone {
 				continue
@@ -1865,7 +1884,7 @@ func TestGlobalBehavior(t *testing.T) {
 			assert.Equal(t, expected, gprlCounters2[peer.InstanceID], "gprlCounter %s", peer.InstanceID)
 		}
 
-		// Verify all peers report same remaining value.
+		// Verify all peers report a consistent remaining value.
 		for _, peer := range cluster.GetDaemons() {
 			if peer.PeerInfo.DataCenter != cluster.DataCenterNone {
 				continue
@@ -1876,3 +1895,137 @@ func TestGlobalBehavior(t *testing.T) {
 		}
 	})
 }
+
+func getMetric(in io.Reader, name string) (*model.Sample, error) {
+	dec := expfmt.SampleDecoder{
+		Dec: expfmt.NewDecoder(in, expfmt.FmtText),
+		Opts: &expfmt.DecodeOptions{
+			Timestamp: model.Now(),
+		},
+	}
+
+	var all model.Vector
+	for {
+		var smpls model.Vector
+		err := dec.Decode(&smpls)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, err
+		}
+		all = append(all, smpls...)
+	}
+
+	for _, s := range all {
+		if strings.Contains(s.Metric.String(), name) {
+			return s, nil
+		}
+	}
+	return nil, nil
+}
+
+// waitForBroadcast waits until the broadcast count for the daemon changes to
+// the expected value. Returns an error if the expected value is not found
+// before the context is cancelled.
+func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error {
+	// fmt.Printf("waitForBroadcast() peer: %s\n", d.InstanceID)
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	for {
+		m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress),
+			"gubernator_broadcast_duration_count")
+		if err != nil {
+			return err
+		}
+		// fmt.Printf("gubernator_broadcast_duration_count: %f\n", m.Value)
+
+		// It's possible a broadcast occurred twice if waiting for multiple peers to
+		// forward updates to the owner.
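+		// Checking >= rather than == tolerates that overshoot.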
+		if int(m.Value) >= expect {
+			// Give the nodes some time to process the broadcasts
+			// clock.Sleep(clock.Millisecond * 500)
+			return nil
+		}
+
+		select {
+		case <-clock.After(time.Millisecond * 100):
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+}
+
+// waitForUpdate waits until the global hits update count for the daemon
+// changes to the expected value. Returns an error if the expected value is not
+// found before the context is cancelled.
+func waitForUpdate(timeout clock.Duration, d *guber.Daemon, expect int) error {
+	// fmt.Printf("waitForUpdate() peer: %s\n", d.InstanceID)
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	for {
+		m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress),
+			"gubernator_global_send_duration_count")
+		if err != nil {
+			return err
+		}
+		// fmt.Printf("gubernator_global_send_duration_count: %f\n", m.Value)
+
+		// It's possible an update occurred twice if waiting for multiple peers to
+		// forward updates to the owner.
+		if int(m.Value) >= expect {
+			return nil
+		}
+
+		select {
+		case <-clock.After(time.Millisecond * 100):
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+}
+
+func getMetricValue(t *testing.T, d *guber.Daemon, name string) float64 {
+	m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress),
+		name)
+	require.NoError(t, err)
+	if m == nil {
+		return 0
+	}
+	return float64(m.Value)
+}
+
+// Get metric counter values on each peer.
+func getPeerCounters(t *testing.T, peers []*guber.Daemon, name string) map[string]int {
+	counters := make(map[string]int)
+	for _, peer := range peers {
+		counters[peer.InstanceID] = int(getMetricValue(t, peer, name))
+	}
+	return counters
+}
+
+func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitReq, expectStatus guber.Status, expectRemaining int64) {
+	if req.Hits != 0 {
+		t.Logf("Sending %d hits to peer %s", req.Hits, d.InstanceID)
+	}
+	client := d.MustClient()
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	defer cancel()
+	resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{
+		Requests: []*guber.RateLimitReq{req},
+	})
+	require.NoError(t, err)
+	item := resp.Responses[0]
+	assert.Equal(t, "", item.Error)
+	if expectRemaining >= 0 {
+		assert.Equal(t, expectRemaining, item.Remaining)
+	}
+	assert.Equal(t, expectStatus, item.Status)
+	assert.Equal(t, req.Limit, item.Limit)
+}
+
+func randomKey() string {
+	return fmt.Sprintf("%016x", rand.Int())
+}