Skip to content

Commit 61d7f3c

Browse files
authored
executor: fix goroutine leak when exceed quota in hash agg (#58078) (#61803)
close #58004
1 parent ce3b858 commit 61d7f3c

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

pkg/executor/aggregate/agg_hash_executor.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ type HashAggExec struct {
154154
spillHelper *parallelHashAggSpillHelper
155155
// isChildDrained indicates whether the all data from child has been taken out.
156156
isChildDrained bool
157+
158+
invalidMemoryUsageForTrackingTest bool
157159
}
158160

159161
// Close implements the Executor Close interface.
@@ -204,7 +206,11 @@ func (e *HashAggExec) Close() error {
204206
channel.Clear(e.finalOutputCh)
205207
e.executed.Store(false)
206208
if e.memTracker != nil {
207-
e.memTracker.ReplaceBytesUsed(0)
209+
if e.memTracker.BytesConsumed() < 0 {
210+
e.invalidMemoryUsageForTrackingTest = true
211+
} else {
212+
e.memTracker.ReplaceBytesUsed(0)
213+
}
208214
}
209215
e.parallelExecValid = false
210216
if e.parallelAggSpillAction != nil {
@@ -285,6 +291,8 @@ func (e *HashAggExec) initForUnparallelExec() {
285291
}
286292

287293
func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrency int, ctx sessionctx.Context) {
294+
memUsage := int64(0)
295+
288296
for i := 0; i < partialConcurrency; i++ {
289297
partialResultsMap := make([]aggfuncs.AggPartialResultMapper, finalConcurrency)
290298
for i := 0; i < finalConcurrency; i++ {
@@ -311,6 +319,8 @@ func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrenc
311319
inflightChunkSync: e.inflightChunkSync,
312320
}
313321

322+
memUsage += e.partialWorkers[i].chk.MemoryUsage()
323+
314324
e.partialWorkers[i].partialResultNumInRow = e.partialWorkers[i].getPartialResultSliceLenConsiderByteAlign()
315325
for j := 0; j < finalConcurrency; j++ {
316326
e.partialWorkers[i].BInMaps[j] = 0
@@ -328,9 +338,11 @@ func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrenc
328338
chk: exec.NewFirstChunk(e.Children(0)),
329339
giveBackCh: e.partialWorkers[i].inputCh,
330340
}
331-
e.memTracker.Consume(input.chk.MemoryUsage())
341+
memUsage += input.chk.MemoryUsage()
332342
e.inputCh <- input
333343
}
344+
345+
e.memTracker.Consume(memUsage)
334346
}
335347

336348
func (e *HashAggExec) initFinalWorkers(finalConcurrency int) {
@@ -442,6 +454,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context, waitGroup *sync.WaitGr
442454
ok bool
443455
err error
444456
)
457+
445458
defer func() {
446459
if r := recover(); r != nil {
447460
recoveryHashAgg(e.finalOutputCh, r)
@@ -494,6 +507,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context, waitGroup *sync.WaitGr
494507
input.giveBackCh <- chk
495508

496509
if hasError := e.spillIfNeed(); hasError {
510+
e.memTracker.Consume(-mSize)
497511
return
498512
}
499513
}
@@ -857,3 +871,8 @@ func (e *HashAggExec) IsSpillTriggeredForTest() bool {
857871
}
858872
return false
859873
}
874+
875+
// IsInvalidMemoryUsageTrackingForTest is for test
876+
func (e *HashAggExec) IsInvalidMemoryUsageTrackingForTest() bool {
877+
return e.invalidMemoryUsageForTrackingTest
878+
}

pkg/executor/aggregate/agg_spill_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ func generateResult(t *testing.T, ctx *mock.Context, dataSource *testutil.MockDa
150150
resultRows = append(resultRows, chk.GetRow(i))
151151
}
152152
}
153+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
153154
aggExec.Close()
154155

155156
require.False(t, aggExec.IsSpillTriggeredForTest())
@@ -315,6 +316,7 @@ func executeCorrecResultTest(t *testing.T, ctx *mock.Context, aggExec *aggregate
315316
resultRows = append(resultRows, chk.GetRow(i))
316317
}
317318
}
319+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
318320
aggExec.Close()
319321

320322
require.True(t, aggExec.IsSpillTriggeredForTest())
@@ -351,6 +353,7 @@ func fallBackActionTest(t *testing.T) {
351353
}
352354
chk.Reset()
353355
}
356+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
354357
aggExec.Close()
355358
require.Less(t, 0, newRootExceedAction.GetTriggeredNum())
356359
}
@@ -373,6 +376,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
373376
go func() {
374377
time.Sleep(time.Duration(rand.Int31n(300)) * time.Millisecond)
375378
once.Do(func() {
379+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
376380
aggExec.Close()
377381
})
378382
goRoutineWaiter.Done()
@@ -382,6 +386,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
382386
err := aggExec.Next(tmpCtx, chk)
383387
if err != nil {
384388
once.Do(func() {
389+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
385390
err = aggExec.Close()
386391
require.Equal(t, nil, err)
387392
})
@@ -393,6 +398,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
393398
chk.Reset()
394399
}
395400
once.Do(func() {
401+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
396402
aggExec.Close()
397403
})
398404
}

0 commit comments

Comments
 (0)