Skip to content

Commit

Permalink
Generic singleflight
Browse files Browse the repository at this point in the history
  • Loading branch information
janos committed Dec 18, 2021
1 parent f856ab8 commit 8feb798
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 37 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.16
stable: 'false'
go-version: '1.18.0-beta1'

- name: Checkout
uses: actions/checkout@v1
Expand All @@ -31,11 +32,12 @@ jobs:
${{ runner.OS }}-build-
${{ runner.OS }}-
- name: Lint
uses: golangci/golangci-lint-action@v2
with:
version: v1.40.1
args: --timeout 10m
# todo: enable when go 1.18 is supported
# - name: Lint
# uses: golangci/golangci-lint-action@v2
# with:
# version: v1.43.0
# args: --timeout 10m

- name: Vet
if: matrix.os == 'ubuntu-latest'
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module resenje.org/singleflight

go 1.13
go 1.18
18 changes: 9 additions & 9 deletions singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (

// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
calls map[string]*call // lazily initialized
type Group[K comparable, V any] struct {
calls map[K]*call[V] // lazily initialized
mu sync.Mutex // protects calls
}

Expand All @@ -31,10 +31,10 @@ type Group struct {
// effect the execution and returned values of others.
//
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (v interface{}, shared bool, err error) {
func (g *Group[K, V]) Do(ctx context.Context, key K, fn func(ctx context.Context) (V, error)) (v V, shared bool, err error) {
g.mu.Lock()
if g.calls == nil {
g.calls = make(map[string]*call)
g.calls = make(map[K]*call[V])
}

if c, ok := g.calls[key]; ok {
Expand All @@ -47,7 +47,7 @@ func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context)

callCtx, cancel := context.WithCancel(context.Background())

c := &call{
c := &call[V]{
done: make(chan struct{}),
cancel: cancel,
counter: 1,
Expand All @@ -64,7 +64,7 @@ func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context)
}

// wait for function passed to Do to finish or context to be done.
func (g *Group) wait(ctx context.Context, key string, c *call) (v interface{}, shared bool, err error) {
func (g *Group[K, V]) wait(ctx context.Context, key K, c *call[V]) (v V, shared bool, err error) {
select {
case <-c.done:
v = c.val
Expand All @@ -87,7 +87,7 @@ func (g *Group) wait(ctx context.Context, key string, c *call) (v interface{}, s
// Forget tells the singleflight to forget about a key. Future calls
// to Do for this key will call the function rather than waiting for
// an earlier call to complete.
func (g *Group) Forget(key string) {
func (g *Group[K, V]) Forget(key K) {
g.mu.Lock()
if c, ok := g.calls[key]; ok {
c.forgotten = true
Expand All @@ -97,9 +97,9 @@ func (g *Group) Forget(key string) {
}

// call stores information about as single function call passed to Do function.
type call struct {
type call[V any] struct {
// val and err hold the state about results of the function call.
val interface{}
val V
err error

// done channel signals that the function call is done.
Expand Down
42 changes: 21 additions & 21 deletions singleflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ import (
)

func TestDo(t *testing.T) {
var g singleflight.Group
var g singleflight.Group[string, string]

want := "val"
got, shared, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
got, shared, err := g.Do(context.Background(), "key", func(_ context.Context) (string, error) {
return want, nil
})
if err != nil {
Expand All @@ -40,21 +40,21 @@ func TestDo(t *testing.T) {
}

func TestDo_error(t *testing.T) {
var g singleflight.Group
var g singleflight.Group[string, string]
wantErr := errors.New("test error")
got, _, err := g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
return nil, wantErr
got, _, err := g.Do(context.Background(), "key", func(_ context.Context) (string, error) {
return "", wantErr
})
if err != wantErr {
t.Errorf("got error %v, want %v", err, wantErr)
}
if got != nil {
if got != "" {
t.Errorf("unexpected value %#v", got)
}
}

func TestDo_multipleCalls(t *testing.T) {
var g singleflight.Group
var g singleflight.Group[string, string]

want := "val"
var counter int32
Expand All @@ -68,7 +68,7 @@ func TestDo_multipleCalls(t *testing.T) {
for i := 0; i < n; i++ {
go func(i int) {
defer wg.Done()
got[i], shared[i], err[i] = g.Do(context.Background(), "key", func(_ context.Context) (interface{}, error) {
got[i], shared[i], err[i] = g.Do(context.Background(), "key", func(_ context.Context) (string, error) {
atomic.AddInt32(&counter, 1)
time.Sleep(100 * time.Millisecond)
return want, nil
Expand All @@ -95,11 +95,11 @@ func TestDo_multipleCalls(t *testing.T) {
}

func TestDo_callRemoval(t *testing.T) {
var g singleflight.Group
var g singleflight.Group[string, string]

wantPrefix := "val"
counter := 0
fn := func(_ context.Context) (interface{}, error) {
fn := func(_ context.Context) (string, error) {
counter++
return wantPrefix + strconv.Itoa(counter), nil
}
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestDo_cancelContext(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group
var g singleflight.Group[string, string]

want := "val"
ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -140,7 +140,7 @@ func TestDo_cancelContext(t *testing.T) {
cancel()
}()
start := time.Now()
got, shared, err := g.Do(ctx, "key", func(_ context.Context) (interface{}, error) {
got, shared, err := g.Do(ctx, "key", func(_ context.Context) (string, error) {
select {
case <-time.After(time.Second):
case <-done:
Expand All @@ -156,7 +156,7 @@ func TestDo_cancelContext(t *testing.T) {
if shared {
t.Error("the value should not be shared")
}
if got != nil {
if got != "" {
t.Errorf("unexpected value %#v", got)
}
}
Expand All @@ -165,10 +165,10 @@ func TestDo_cancelContextSecond(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group
var g singleflight.Group[string, string]

want := "val"
fn := func(_ context.Context) (interface{}, error) {
fn := func(_ context.Context) (string, error) {
select {
case <-time.After(time.Second):
case <-done:
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestDo_cancelContextSecond(t *testing.T) {
if !shared {
t.Error("the value should be shared")
}
if got != nil {
if got != "" {
t.Errorf("unexpected value %#v", got)
}
}
Expand All @@ -205,12 +205,12 @@ func TestForget(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group
var g singleflight.Group[string, string]

wantPrefix := "val"
var counter uint64
firstCall := make(chan struct{})
fn := func(_ context.Context) (interface{}, error) {
fn := func(_ context.Context) (string, error) {
c := atomic.AddUint64(&counter, 1)
if c == 1 {
close(firstCall)
Expand Down Expand Up @@ -252,7 +252,7 @@ func TestDo_multipleCallsCanceled(t *testing.T) {
done := make(chan struct{})
defer close(done)

var g singleflight.Group
var g singleflight.Group[string, string]

var counter int32

Expand All @@ -271,7 +271,7 @@ func TestDo_multipleCallsCanceled(t *testing.T) {
contexts[i] = ctx
cancelFuncs[i] = cancel
mu.Unlock()
_, _, _ = g.Do(ctx, "key", func(ctx context.Context) (interface{}, error) {
_, _, _ = g.Do(ctx, "key", func(ctx context.Context) (string, error) {
atomic.AddInt32(&counter, 1)
close(fnCalled)
var err error
Expand All @@ -288,7 +288,7 @@ func TestDo_multipleCallsCanceled(t *testing.T) {

fnErrChan <- err

return nil, nil
return "", nil
})
}(i)
}
Expand Down

0 comments on commit 8feb798

Please sign in to comment.