From 0bc8ad136c6e59df6f2fae7ea09e3abaa30de890 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Wed, 18 Jun 2025 17:37:31 +0800 Subject: [PATCH] fix and add test --- .../aggfuncs/spill_deserialize_helper.go | 11 +++++++++-- pkg/executor/aggfuncs/spill_helper_test.go | 17 +++++++++++------ pkg/executor/aggfuncs/spill_serialize_helper.go | 8 ++++++-- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/pkg/executor/aggfuncs/spill_deserialize_helper.go b/pkg/executor/aggfuncs/spill_deserialize_helper.go index add6aeaa532ae..b2568ec663df5 100644 --- a/pkg/executor/aggfuncs/spill_deserialize_helper.go +++ b/pkg/executor/aggfuncs/spill_deserialize_helper.go @@ -15,6 +15,8 @@ package aggfuncs import ( + "bytes" + "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/hack" util "github.com/pingcap/tidb/pkg/util/serialization" @@ -216,8 +218,13 @@ func (s *deserializeHelper) deserializePartialResult4SumFloat64(dst *partialResu func (s *deserializeHelper) deserializeBasePartialResult4GroupConcat(dst *basePartialResult4GroupConcat) bool { if s.readRowIndex < s.totalRowCnt { s.pab.Reset(s.column, s.readRowIndex) - dst.valsBuf = util.DeserializeBytesBuffer(s.pab) - dst.buffer = util.DeserializeBytesBuffer(s.pab) + dst.valsBuf = &bytes.Buffer{} + hasBuffer := util.DeserializeBool(s.pab) + if hasBuffer { + dst.buffer = util.DeserializeBytesBuffer(s.pab) + } else { + dst.buffer = nil + } s.readRowIndex++ return true } diff --git a/pkg/executor/aggfuncs/spill_helper_test.go b/pkg/executor/aggfuncs/spill_helper_test.go index fbc1b6a5954b3..012e9cfa46b91 100644 --- a/pkg/executor/aggfuncs/spill_helper_test.go +++ b/pkg/executor/aggfuncs/spill_helper_test.go @@ -25,8 +25,8 @@ import ( "github.com/stretchr/testify/require" ) -var testLongStr1 string = getLongString("平p凯k星x辰c") -var testLongStr2 string = getLongString("123aa啊啊aa") +var testLongStr1 string = getLongString("平352p凯额6辰c") +var testLongStr2 string = getLongString("123a啊f24f去rsgvsfg") func getChunk() *chunk.Chunk { fieldTypes := make([]*types.FieldType, 1) @@ -746,13 +746,15 @@ func TestPartialResult4SumFloat64(t *testing.T) { func TestBasePartialResult4GroupConcat(t *testing.T) { var serializeHelper = NewSerializeHelper() + serializeHelper.buf = make([]byte, 0) bufSizeChecker := newBufferSizeChecker() // Initialize test data expectData := []basePartialResult4GroupConcat{ + {valsBuf: bytes.NewBufferString("123"), buffer: nil}, {valsBuf: bytes.NewBufferString(""), buffer: bytes.NewBufferString("")}, - {valsBuf: bytes.NewBufferString("xzxx"), buffer: bytes.NewBufferString(testLongStr2)}, - {valsBuf: bytes.NewBufferString(testLongStr1), buffer: bytes.NewBufferString(testLongStr2)}, + {valsBuf: bytes.NewBufferString(""), buffer: bytes.NewBufferString(testLongStr1)}, + {valsBuf: bytes.NewBufferString(""), buffer: bytes.NewBufferString(testLongStr2)}, } serializedPartialResults := make([]PartialResult, len(expectData)) testDataNum := len(serializedPartialResults) @@ -787,8 +789,11 @@ func TestBasePartialResult4GroupConcat(t *testing.T) { // Check some results require.Equal(t, testDataNum, index) for i := range testDataNum { - require.Equal(t, (*basePartialResult4GroupConcat)(serializedPartialResults[i]).valsBuf.String(), deserializedPartialResults[i].valsBuf.String()) - require.Equal(t, (*basePartialResult4GroupConcat)(serializedPartialResults[i]).buffer.String(), deserializedPartialResults[i].buffer.String()) + if (*basePartialResult4GroupConcat)(serializedPartialResults[i]).buffer != nil { + require.Equal(t, (*basePartialResult4GroupConcat)(serializedPartialResults[i]).buffer.String(), deserializedPartialResults[i].buffer.String()) + } else { + require.Equal(t, (*bytes.Buffer)(nil), deserializedPartialResults[i].buffer) + } } } diff --git a/pkg/executor/aggfuncs/spill_serialize_helper.go b/pkg/executor/aggfuncs/spill_serialize_helper.go index 7a5aef465b80c..21e97668d997f 100644 --- a/pkg/executor/aggfuncs/spill_serialize_helper.go +++ b/pkg/executor/aggfuncs/spill_serialize_helper.go @@ -131,8 +131,12 @@ func (s *SerializeHelper) serializePartialResult4SumFloat64(value partialResult4 func (s *SerializeHelper) serializeBasePartialResult4GroupConcat(value basePartialResult4GroupConcat) []byte { s.buf = s.buf[:0] - s.buf = util.SerializeBytesBuffer(value.valsBuf, s.buf) - s.buf = util.SerializeBytesBuffer(value.buffer, s.buf) + if value.buffer != nil { + s.buf = util.SerializeBool(true, s.buf) + s.buf = util.SerializeBytesBuffer(value.buffer, s.buf) + } else { + s.buf = util.SerializeBool(false, s.buf) + } return s.buf }