From aaf4aff741c8cc0eae45526c8176bcf89aa15995 Mon Sep 17 00:00:00 2001 From: eran shmuely Date: Tue, 3 Sep 2024 15:03:04 +0300 Subject: [PATCH] fix duplicate key deletion when forget called --- singleflight.go | 11 ++++++- singleflight_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/singleflight.go b/singleflight.go index d8c074d..b55c23d 100644 --- a/singleflight.go +++ b/singleflight.go @@ -113,7 +113,9 @@ func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared c.counter-- if c.counter == 0 { c.cancel() - delete(g.calls, key) + if !c.forgotten { + delete(g.calls, key) + } } shared = c.shared g.mu.Unlock() @@ -130,6 +132,9 @@ func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared // an earlier call to complete. func (g *Group[K, V]) Forget(key K) { g.mu.Lock() + if c, ok := g.calls[key]; ok { + c.forgotten = true + } delete(g.calls, key) g.mu.Unlock() } @@ -155,4 +160,8 @@ type call[V any] struct { // shared indicates if results val and err are passed to multiple callers. shared bool + + // forgotten indicates whether Forget was called with this call's key + // while the call was still in flight. + forgotten bool } diff --git a/singleflight_test.go b/singleflight_test.go index d371e3c..a4ec1d3 100644 --- a/singleflight_test.go +++ b/singleflight_test.go @@ -370,6 +370,76 @@ func TestForget(t *testing.T) { } } +// Test that singleflight behaves correctly after Forget called. +// See https://github.com/golang/go/issues/31420 +func TestForgetMisbehaving(t *testing.T) { + var g singleflight.Group[string, int] + + var firstStarted, firstFinished sync.WaitGroup + + firstStarted.Add(1) + firstFinished.Add(1) + + firstCh := make(chan struct{}) + go func() { + g.Do(context.Background(), "key", func(ctx context.Context) (i int, e error) { + firstStarted.Done() + <-firstCh + firstFinished.Done() + return + }) + }() + + firstStarted.Wait() + g.Forget("key") // from this point no two function using same key should be executed concurrently + + var secondStarted int32 + var secondFinished int32 + var thirdStarted int32 + + secondCh := make(chan struct{}) + secondRunning := make(chan struct{}) + go func() { + g.Do(context.Background(), "key", func(ctx context.Context) (i int, e error) { + atomic.AddInt32(&secondStarted, 1) + // Notify that we started + secondCh <- struct{}{} + // Wait other get above signal + <-secondRunning + <-secondCh + atomic.AddInt32(&secondFinished, 1) + return 2, nil + }) + }() + + close(firstCh) + firstFinished.Wait() // wait for first execution (which should not affect execution after Forget) + + <-secondCh + // Notify second that we got the signal that it started + secondRunning <- struct{}{} + if atomic.LoadInt32(&secondStarted) != 1 { + t.Fatal("Second execution should be executed due to usage of forget") + } + + if atomic.LoadInt32(&secondFinished) == 1 { + t.Fatal("Second execution should be still active") + } + + close(secondCh) + result, _, _ := g.Do(context.Background(), "key", func(ctx context.Context) (i int, e error) { + atomic.AddInt32(&thirdStarted, 1) + return 3, nil + }) + + if atomic.LoadInt32(&thirdStarted) != 0 { + t.Error("Third call should not be started because was started during second execution") + } + if result != 2 { + t.Errorf("We should receive result produced by second call, expected: 2, got %d", result) + } +} + func TestDo_multipleCallsCanceled(t *testing.T) { const n = 5