Skip to content

Commit 4fc3123

Browse files
authored
executor: fix goroutine leak when exceed quota in hash agg (#58078) (#58462)
close #58004
1 parent 11ca251 commit 4fc3123

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pkg/executor/aggregate/agg_hash_executor.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ import (
3434
"github.com/pingcap/tidb/pkg/util/chunk"
3535
"github.com/pingcap/tidb/pkg/util/disk"
3636
"github.com/pingcap/tidb/pkg/util/hack"
37+
"github.com/pingcap/tidb/pkg/util/logutil"
3738
"github.com/pingcap/tidb/pkg/util/memory"
3839
"github.com/pingcap/tidb/pkg/util/set"
40+
"go.uber.org/zap"
3941
)
4042

4143
// HashAggInput indicates the input of hash agg exec.
@@ -154,6 +156,8 @@ type HashAggExec struct {
154156
spillHelper *parallelHashAggSpillHelper
155157
// isChildDrained indicates whether the all data from child has been taken out.
156158
isChildDrained bool
159+
160+
invalidMemoryUsageForTrackingTest bool
157161
}
158162

159163
// Close implements the Executor Close interface.
@@ -204,6 +208,10 @@ func (e *HashAggExec) Close() error {
204208
channel.Clear(e.finalOutputCh)
205209
e.executed.Store(false)
206210
if e.memTracker != nil {
211+
if e.memTracker.BytesConsumed() < 0 {
212+
logutil.BgLogger().Warn("Memory tracker's counter is invalid", zap.Int64("counter", e.memTracker.BytesConsumed()))
213+
e.invalidMemoryUsageForTrackingTest = true
214+
}
207215
e.memTracker.ReplaceBytesUsed(0)
208216
}
209217
e.parallelExecValid = false
@@ -289,6 +297,8 @@ func (e *HashAggExec) initForUnparallelExec() {
289297
}
290298

291299
func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrency int, ctx sessionctx.Context) {
300+
memUsage := int64(0)
301+
292302
for i := 0; i < partialConcurrency; i++ {
293303
partialResultsMap := make([]aggfuncs.AggPartialResultMapper, finalConcurrency)
294304
for i := 0; i < finalConcurrency; i++ {
@@ -316,6 +326,8 @@ func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrenc
316326
inflightChunkSync: e.inflightChunkSync,
317327
}
318328

329+
memUsage += e.partialWorkers[i].chk.MemoryUsage()
330+
319331
e.partialWorkers[i].partialResultNumInRow = e.partialWorkers[i].getPartialResultSliceLenConsiderByteAlign()
320332
for j := 0; j < finalConcurrency; j++ {
321333
e.partialWorkers[i].BInMaps[j] = 0
@@ -332,8 +344,11 @@ func (e *HashAggExec) initPartialWorkers(partialConcurrency int, finalConcurrenc
332344
chk: chunk.New(e.Children(0).RetFieldTypes(), 0, e.MaxChunkSize()),
333345
giveBackCh: e.partialWorkers[i].inputCh,
334346
}
347+
memUsage += input.chk.MemoryUsage()
335348
e.inputCh <- input
336349
}
350+
351+
e.memTracker.Consume(memUsage)
337352
}
338353

339354
func (e *HashAggExec) initFinalWorkers(finalConcurrency int) {
@@ -442,6 +457,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context, waitGroup *sync.WaitGr
442457
ok bool
443458
err error
444459
)
460+
445461
defer func() {
446462
if r := recover(); r != nil {
447463
recoveryHashAgg(e.finalOutputCh, r)
@@ -494,6 +510,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context, waitGroup *sync.WaitGr
494510
input.giveBackCh <- chk
495511

496512
if hasError := e.spillIfNeed(); hasError {
513+
e.memTracker.Consume(-mSize)
497514
return
498515
}
499516
}
@@ -857,3 +874,8 @@ func (e *HashAggExec) IsSpillTriggeredForTest() bool {
857874
}
858875
return false
859876
}
877+
878+
// IsInvalidMemoryUsageTrackingForTest is for test
879+
func (e *HashAggExec) IsInvalidMemoryUsageTrackingForTest() bool {
880+
return e.invalidMemoryUsageForTrackingTest
881+
}

pkg/executor/aggregate/agg_spill_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ func generateResult(t *testing.T, ctx *mock.Context, dataSource *testutil.MockDa
150150
resultRows = append(resultRows, chk.GetRow(i))
151151
}
152152
}
153+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
153154
aggExec.Close()
154155

155156
require.False(t, aggExec.IsSpillTriggeredForTest())
@@ -315,6 +316,7 @@ func executeCorrecResultTest(t *testing.T, ctx *mock.Context, aggExec *aggregate
315316
resultRows = append(resultRows, chk.GetRow(i))
316317
}
317318
}
319+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
318320
aggExec.Close()
319321

320322
require.True(t, aggExec.IsSpillTriggeredForTest())
@@ -351,6 +353,7 @@ func fallBackActionTest(t *testing.T) {
351353
}
352354
chk.Reset()
353355
}
356+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
354357
aggExec.Close()
355358
require.Less(t, 0, newRootExceedAction.GetTriggeredNum())
356359
}
@@ -373,6 +376,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
373376
go func() {
374377
time.Sleep(time.Duration(rand.Int31n(300)) * time.Millisecond)
375378
once.Do(func() {
379+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
376380
aggExec.Close()
377381
})
378382
goRoutineWaiter.Done()
@@ -382,6 +386,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
382386
err := aggExec.Next(tmpCtx, chk)
383387
if err != nil {
384388
once.Do(func() {
389+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
385390
err = aggExec.Close()
386391
require.Equal(t, nil, err)
387392
})
@@ -393,6 +398,7 @@ func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggE
393398
chk.Reset()
394399
}
395400
once.Do(func() {
401+
require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
396402
aggExec.Close()
397403
})
398404
}

0 commit comments

Comments
 (0)