@@ -11,28 +11,28 @@ import (
11
11
// time, Wait can be used to block until all goroutines have finished or the
12
12
// given context is done.
13
13
type WaitGroupContext struct {
14
- ctx context.Context
15
- done chan struct {}
16
- counter atomic.Int32
17
- state atomic.Int32
14
+ ctx context.Context
15
+ sem chan struct {}
16
+ state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
18
17
}
19
18
20
19
// NewWaitGroupContext returns a new WaitGroupContext with Context ctx.
21
20
func NewWaitGroupContext (ctx context.Context ) * WaitGroupContext {
22
21
return & WaitGroupContext {
23
- ctx : ctx ,
24
- done : make (chan struct {}),
22
+ ctx : ctx ,
23
+ sem : make (chan struct {}),
25
24
}
26
25
}
27
26
28
27
// Add adds delta, which may be negative, to the WaitGroupContext counter.
29
28
// If the counter becomes zero, all goroutines blocked on Wait are released.
30
29
// If the counter goes negative, Add panics.
31
30
func (wgc * WaitGroupContext ) Add (delta int ) {
32
- counter := wgc .counter .Add (int32 (delta ))
33
- if counter == 0 && wgc .state .CompareAndSwap (0 , 1 ) {
34
- wgc .release ()
35
- } else if counter < 0 && wgc .state .Load () == 0 {
31
+ state := wgc .state .Add (uint64 (delta ) << 32 )
32
+ counter := int32 (state >> 32 )
33
+ if counter == 0 {
34
+ wgc .notifyAll ()
35
+ } else if counter < 0 {
36
36
panic ("async: negative WaitGroupContext counter" )
37
37
}
38
38
}
@@ -44,12 +44,36 @@ func (wgc *WaitGroupContext) Done() {
44
44
45
45
// Wait blocks until the wait group counter is zero or ctx is done.
46
46
func (wgc * WaitGroupContext ) Wait () {
47
- select {
48
- case <- wgc .ctx .Done ():
49
- case <- wgc .done :
47
+ for {
48
+ state := wgc .state .Load ()
49
+ counter := int32 (state >> 32 )
50
+ if counter == 0 {
51
+ return
52
+ }
53
+ if wgc .state .CompareAndSwap (state , state + 1 ) {
54
+ select {
55
+ case <- wgc .sem :
56
+ if wgc .state .Load () != 0 {
57
+ panic ("async: WaitGroupContext is reused before " +
58
+ "previous Wait has returned" )
59
+ }
60
+ case <- wgc .ctx .Done ():
61
+ }
62
+ return
63
+ }
50
64
}
51
65
}
52
66
53
- func (wgc * WaitGroupContext ) release () {
54
- close (wgc .done )
67
+ // notifyAll releases all goroutines blocked in Wait and resets
68
+ // the wait group state.
69
+ func (wgc * WaitGroupContext ) notifyAll () {
70
+ state := wgc .state .Load ()
71
+ waiting := uint32 (state )
72
+ wgc .state .Store (0 )
73
+ for ; waiting != 0 ; waiting -- {
74
+ select {
75
+ case wgc .sem <- struct {}{}:
76
+ case <- wgc .ctx .Done ():
77
+ }
78
+ }
55
79
}
0 commit comments