Skip to content

Commit 986952c

Browse files
authored
fix: make WaitGroupContext reusable (#28)
1 parent c45fbe8 commit 986952c

File tree

2 files changed

+86
-16
lines changed

2 files changed

+86
-16
lines changed

wait_group_context.go

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,28 @@ import (
1111
// time, Wait can be used to block until all goroutines have finished or the
1212
// given context is done.
1313
type WaitGroupContext struct {
14-
ctx context.Context
15-
done chan struct{}
16-
counter atomic.Int32
17-
state atomic.Int32
14+
ctx context.Context
15+
sem chan struct{}
16+
state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
1817
}
1918

2019
// NewWaitGroupContext returns a new WaitGroupContext with Context ctx.
2120
func NewWaitGroupContext(ctx context.Context) *WaitGroupContext {
2221
return &WaitGroupContext{
23-
ctx: ctx,
24-
done: make(chan struct{}),
22+
ctx: ctx,
23+
sem: make(chan struct{}),
2524
}
2625
}
2726

2827
// Add adds delta, which may be negative, to the WaitGroupContext counter.
2928
// If the counter becomes zero, all goroutines blocked on Wait are released.
3029
// If the counter goes negative, Add panics.
3130
func (wgc *WaitGroupContext) Add(delta int) {
32-
counter := wgc.counter.Add(int32(delta))
33-
if counter == 0 && wgc.state.CompareAndSwap(0, 1) {
34-
wgc.release()
35-
} else if counter < 0 && wgc.state.Load() == 0 {
31+
state := wgc.state.Add(uint64(delta) << 32)
32+
counter := int32(state >> 32)
33+
if counter == 0 {
34+
wgc.notifyAll()
35+
} else if counter < 0 {
3636
panic("async: negative WaitGroupContext counter")
3737
}
3838
}
@@ -44,12 +44,36 @@ func (wgc *WaitGroupContext) Done() {
4444

4545
// Wait blocks until the wait group counter is zero or ctx is done.
4646
func (wgc *WaitGroupContext) Wait() {
47-
select {
48-
case <-wgc.ctx.Done():
49-
case <-wgc.done:
47+
for {
48+
state := wgc.state.Load()
49+
counter := int32(state >> 32)
50+
if counter == 0 {
51+
return
52+
}
53+
if wgc.state.CompareAndSwap(state, state+1) {
54+
select {
55+
case <-wgc.sem:
56+
if wgc.state.Load() != 0 {
57+
panic("async: WaitGroupContext is reused before " +
58+
"previous Wait has returned")
59+
}
60+
case <-wgc.ctx.Done():
61+
}
62+
return
63+
}
5064
}
5165
}
5266

53-
func (wgc *WaitGroupContext) release() {
54-
close(wgc.done)
67+
// notifyAll releases all goroutines blocked in Wait and resets
68+
// the wait group state.
69+
func (wgc *WaitGroupContext) notifyAll() {
70+
state := wgc.state.Load()
71+
waiting := uint32(state)
72+
wgc.state.Store(0)
73+
for ; waiting != 0; waiting-- {
74+
select {
75+
case wgc.sem <- struct{}{}:
76+
case <-wgc.ctx.Done():
77+
}
78+
}
5579
}

wait_group_context_test.go

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,56 @@ func TestWaitGroupContextCanceled(t *testing.T) {
6767
assert.Equal(t, int(result.Load()), 111)
6868
}
6969

70-
func TestWaitGroupContextPanic(t *testing.T) {
70+
func TestWaitGroupContextPanicNegativeCounter(t *testing.T) {
7171
negativeCounter := func() {
7272
wgc := NewWaitGroupContext(context.Background())
7373
wgc.Add(-2)
7474
}
7575
assert.Panic(t, negativeCounter)
7676
}
77+
78+
func TestWaitGroupContextPanicReused(t *testing.T) {
79+
reusedBeforeWaitReturned := func() {
80+
var result atomic.Int32
81+
wgc := NewWaitGroupContext(context.Background())
82+
83+
n := 10
84+
for i := 0; i < n; i++ {
85+
wgc.Add(1)
86+
go func() {
87+
defer wgc.Add(1)
88+
defer wgc.Done()
89+
result.Add(1)
90+
}()
91+
wgc.Wait()
92+
}
93+
}
94+
assert.Panic(t, reusedBeforeWaitReturned)
95+
}
96+
97+
func TestWaitGroupContextReused(t *testing.T) {
98+
var result atomic.Int32
99+
wgc := NewWaitGroupContext(context.Background())
100+
101+
n := 1000
102+
for i := 0; i < n; i++ {
103+
assert.Equal(t, int(result.Load()), i*3)
104+
wgc.Add(2)
105+
go func() {
106+
defer wgc.Done()
107+
result.Add(1)
108+
}()
109+
go func() {
110+
defer wgc.Done()
111+
result.Add(1)
112+
}()
113+
go func() {
114+
wgc.Wait()
115+
result.Add(1)
116+
}()
117+
wgc.Wait()
118+
time.Sleep(time.Millisecond)
119+
}
120+
121+
assert.Equal(t, int(result.Load()), n*3)
122+
}

0 commit comments

Comments
 (0)