Skip to content

Commit ed9a909

Browse files
authored
planner: add hash64 and equals for logical aggregation (#56750)
ref #51664
1 parent b759f33 commit ed9a909

File tree

9 files changed

+279
-3
lines changed

9 files changed

+279
-3
lines changed

pkg/expression/aggregation/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ go_library(
4040
"//pkg/parser/charset",
4141
"//pkg/parser/mysql",
4242
"//pkg/parser/terror",
43+
"//pkg/planner/cascades/base",
4344
"//pkg/planner/util",
4445
"//pkg/sessionctx/stmtctx",
4546
"//pkg/types",
@@ -68,12 +69,14 @@ go_test(
6869
],
6970
embed = [":aggregation"],
7071
flaky = True,
71-
shard_count = 14,
72+
shard_count = 15,
7273
deps = [
7374
"//pkg/expression",
7475
"//pkg/kv",
7576
"//pkg/parser/ast",
7677
"//pkg/parser/mysql",
78+
"//pkg/planner/cascades/base",
79+
"//pkg/planner/util",
7780
"//pkg/sessionctx/variable",
7881
"//pkg/testkit/testsetup",
7982
"//pkg/types",

pkg/expression/aggregation/aggregation_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"github.com/pingcap/tidb/pkg/expression"
2222
"github.com/pingcap/tidb/pkg/parser/ast"
2323
"github.com/pingcap/tidb/pkg/parser/mysql"
24+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
25+
"github.com/pingcap/tidb/pkg/planner/util"
2426
"github.com/pingcap/tidb/pkg/sessionctx/variable"
2527
"github.com/pingcap/tidb/pkg/types"
2628
"github.com/pingcap/tidb/pkg/util/chunk"
@@ -600,3 +602,55 @@ func TestMaxMin(t *testing.T) {
600602
partialResult := minFunc.GetPartialResult(minEvalCtx)
601603
require.Equal(t, int64(1), partialResult[0].GetInt64())
602604
}
605+
606+
func TestAggFuncDesc(t *testing.T) {
607+
s := createAggFuncSuite()
608+
col := &expression.Column{
609+
Index: 0,
610+
RetType: types.NewFieldType(mysql.TypeLonglong),
611+
}
612+
desc1, err := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false)
613+
require.NoError(t, err)
614+
desc2, err := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false)
615+
require.NoError(t, err)
616+
hasher1 := base.NewHashEqualer()
617+
hasher2 := base.NewHashEqualer()
618+
desc1.Hash64(hasher1)
619+
desc2.Hash64(hasher2)
620+
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
621+
622+
desc2.HasDistinct = true
623+
hasher2.Reset()
624+
desc2.Hash64(hasher2)
625+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
626+
627+
desc2.HasDistinct = false
628+
desc2.Mode = FinalMode
629+
hasher2.Reset()
630+
desc2.Hash64(hasher2)
631+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
632+
633+
desc2.Mode = CompleteMode
634+
desc2.Name = "whatever"
635+
hasher2.Reset()
636+
desc2.Hash64(hasher2)
637+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
638+
639+
desc2.Name = ast.AggFuncSum
640+
desc2.Args = []expression.Expression{}
641+
hasher2.Reset()
642+
desc2.Hash64(hasher2)
643+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
644+
645+
desc2.Args = []expression.Expression{col}
646+
desc2.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
647+
hasher2.Reset()
648+
desc2.Hash64(hasher2)
649+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
650+
651+
desc2.RetTp = types.NewFieldType(mysql.TypeLonglong)
652+
desc2.OrderByItems = []*util.ByItems{{Expr: col, Desc: true}}
653+
hasher2.Reset()
654+
desc2.Hash64(hasher2)
655+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
656+
}

pkg/expression/aggregation/base_func.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/pingcap/tidb/pkg/parser/ast"
2626
"github.com/pingcap/tidb/pkg/parser/charset"
2727
"github.com/pingcap/tidb/pkg/parser/mysql"
28+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2829
"github.com/pingcap/tidb/pkg/types"
2930
"github.com/pingcap/tidb/pkg/util/chunk"
3031
"github.com/pingcap/tidb/pkg/util/mathutil"
@@ -47,6 +48,42 @@ func newBaseFuncDesc(ctx expression.BuildContext, name string, args []expression
4748
return b, err
4849
}
4950

51+
// Hash64 implements the base.Hasher interface.
52+
func (a *baseFuncDesc) Hash64(h base.Hasher) {
53+
h.HashString(a.Name)
54+
h.HashInt(len(a.Args))
55+
for _, arg := range a.Args {
56+
arg.Hash64(h)
57+
}
58+
if a.RetTp != nil {
59+
h.HashByte(base.NotNilFlag)
60+
a.RetTp.Hash64(h)
61+
} else {
62+
h.HashByte(base.NilFlag)
63+
}
64+
}
65+
66+
// Equals implements the base.Equals interface.
67+
func (a *baseFuncDesc) Equals(other any) bool {
68+
if other == nil {
69+
return false
70+
}
71+
a2, ok := other.(*baseFuncDesc)
72+
if !ok {
73+
return false
74+
}
75+
ok = a.Name == a2.Name && len(a.Args) == len(a2.Args) && ((a.RetTp == nil && a2.RetTp == nil) || (a.RetTp != nil && a2.RetTp != nil && a.RetTp.Equals(a2.RetTp)))
76+
if !ok {
77+
return false
78+
}
79+
for i, arg := range a.Args {
80+
if !arg.Equals(a2.Args[i]) {
81+
return false
82+
}
83+
}
84+
return true
85+
}
86+
5087
func (a *baseFuncDesc) equal(ctx expression.EvalContext, other *baseFuncDesc) bool {
5188
if a.Name != other.Name || len(a.Args) != len(other.Args) {
5289
return false

pkg/expression/aggregation/descriptor.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/pingcap/tidb/pkg/expression"
2323
"github.com/pingcap/tidb/pkg/parser/ast"
2424
"github.com/pingcap/tidb/pkg/parser/mysql"
25+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2526
"github.com/pingcap/tidb/pkg/planner/util"
2627
"github.com/pingcap/tidb/pkg/types"
2728
"github.com/pingcap/tidb/pkg/util/collate"
@@ -59,6 +60,38 @@ func NewAggFuncDescForWindowFunc(ctx expression.BuildContext, desc *WindowFuncDe
5960
return &AggFuncDesc{baseFuncDesc: baseFuncDesc{desc.Name, desc.Args, desc.RetTp}, HasDistinct: hasDistinct}, nil
6061
}
6162

63+
// Hash64 returns the hash64 for the aggregation function signature.
64+
func (a *AggFuncDesc) Hash64(h base.Hasher) {
65+
a.baseFuncDesc.Hash64(h)
66+
h.HashInt(int(a.Mode))
67+
h.HashBool(a.HasDistinct)
68+
h.HashInt(len(a.OrderByItems))
69+
for _, item := range a.OrderByItems {
70+
item.Hash64(h)
71+
}
72+
// groupingID will be deprecated soon.
73+
}
74+
75+
// Equals checks whether two aggregation function signatures are equal.
76+
func (a *AggFuncDesc) Equals(other any) bool {
77+
if other == nil {
78+
return false
79+
}
80+
otherAgg, ok := other.(*AggFuncDesc)
81+
if !ok {
82+
return false
83+
}
84+
if a.Mode != otherAgg.Mode || a.HasDistinct != otherAgg.HasDistinct || len(a.OrderByItems) != len(otherAgg.OrderByItems) {
85+
return false
86+
}
87+
for i := range a.OrderByItems {
88+
if !a.OrderByItems[i].Equals(otherAgg.OrderByItems[i]) {
89+
return false
90+
}
91+
}
92+
return a.baseFuncDesc.Equals(otherAgg.baseFuncDesc)
93+
}
94+
6295
// StringWithCtx returns the string representation within given ctx.
6396
func (a *AggFuncDesc) StringWithCtx(ctx expression.ParamValues, redact string) string {
6497
buffer := bytes.NewBufferString(a.Name)

pkg/planner/core/operator/logicalop/logical_aggregation.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/pingcap/tidb/pkg/kv"
2424
"github.com/pingcap/tidb/pkg/parser/ast"
2525
"github.com/pingcap/tidb/pkg/planner/cardinality"
26+
base2 "github.com/pingcap/tidb/pkg/planner/cascades/base"
2627
"github.com/pingcap/tidb/pkg/planner/core/base"
2728
ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util"
2829
fd "github.com/pingcap/tidb/pkg/planner/funcdep"
@@ -61,6 +62,62 @@ func (la LogicalAggregation) Init(ctx base.PlanContext, offset int) *LogicalAggr
6162
return &la
6263
}
6364

65+
// *************************** start implementation of HashEquals interface ****************************
66+
67+
// Hash64 implements the base.Hash64.<0th> interface.
68+
func (la *LogicalAggregation) Hash64(h base2.Hasher) {
69+
h.HashInt(len(la.AggFuncs))
70+
for _, one := range la.AggFuncs {
71+
one.Hash64(h)
72+
}
73+
h.HashInt(len(la.GroupByItems))
74+
for _, one := range la.GroupByItems {
75+
one.Hash64(h)
76+
}
77+
h.HashInt(len(la.PossibleProperties))
78+
for _, one := range la.PossibleProperties {
79+
h.HashInt(len(one))
80+
for _, col := range one {
81+
col.Hash64(h)
82+
}
83+
}
84+
}
85+
86+
// Equals implements the base.HashEquals.<1st> interface.
87+
func (la *LogicalAggregation) Equals(other any) bool {
88+
if other == nil {
89+
return false
90+
}
91+
la2, ok := other.(*LogicalAggregation)
92+
if !ok {
93+
return false
94+
}
95+
if len(la.AggFuncs) != len(la2.AggFuncs) || len(la.GroupByItems) != len(la2.GroupByItems) || len(la.PossibleProperties) != len(la2.PossibleProperties) {
96+
return false
97+
}
98+
for i := range la.AggFuncs {
99+
if !la.AggFuncs[i].Equals(la2.AggFuncs[i]) {
100+
return false
101+
}
102+
}
103+
for i := range la.GroupByItems {
104+
if !la.GroupByItems[i].Equals(la2.GroupByItems[i]) {
105+
return false
106+
}
107+
}
108+
for i := range la.PossibleProperties {
109+
if len(la.PossibleProperties[i]) != len(la2.PossibleProperties[i]) {
110+
return false
111+
}
112+
for j := range la.PossibleProperties[i] {
113+
if !la.PossibleProperties[i][j].Equals(la2.PossibleProperties[i][j]) {
114+
return false
115+
}
116+
}
117+
}
118+
return true
119+
}
120+
64121
// *************************** start implementation of Plan interface ***************************
65122

66123
// ExplainInfo implements base.Plan.<4th> interface.

pkg/planner/core/operator/logicalop/logicalop_test/BUILD.bazel

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,21 @@ load("@io_bazel_rules_go//go:def.bzl", "go_test")
33
go_test(
44
name = "logicalop_test_test",
55
timeout = "short",
6-
srcs = ["logical_mem_table_predicate_extractor_test.go"],
6+
srcs = [
7+
"hash64_equals_test.go",
8+
"logical_mem_table_predicate_extractor_test.go",
9+
],
710
flaky = True,
8-
shard_count = 13,
11+
shard_count = 14,
912
deps = [
1013
"//pkg/domain",
1114
"//pkg/expression",
15+
"//pkg/expression/aggregation",
1216
"//pkg/parser",
1317
"//pkg/parser/ast",
18+
"//pkg/parser/mysql",
1419
"//pkg/planner",
20+
"//pkg/planner/cascades/base",
1521
"//pkg/planner/core",
1622
"//pkg/planner/core/base",
1723
"//pkg/planner/core/operator/logicalop",
@@ -22,6 +28,7 @@ go_test(
2228
"//pkg/testkit",
2329
"//pkg/types",
2430
"//pkg/util/hint",
31+
"//pkg/util/mock",
2532
"//pkg/util/set",
2633
"@com_github_stretchr_testify//require",
2734
],
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package logicalop
16+
17+
import (
18+
"testing"
19+
20+
"github.com/pingcap/tidb/pkg/expression"
21+
"github.com/pingcap/tidb/pkg/expression/aggregation"
22+
"github.com/pingcap/tidb/pkg/parser/ast"
23+
"github.com/pingcap/tidb/pkg/parser/mysql"
24+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
25+
"github.com/pingcap/tidb/pkg/planner/core/operator/logicalop"
26+
"github.com/pingcap/tidb/pkg/types"
27+
"github.com/pingcap/tidb/pkg/util/mock"
28+
"github.com/stretchr/testify/require"
29+
)
30+
31+
func TestLogicalAggregationHash64Equals(t *testing.T) {
32+
col := &expression.Column{
33+
Index: 0,
34+
RetType: types.NewFieldType(mysql.TypeLonglong),
35+
}
36+
ctx := mock.NewContext()
37+
desc, err := aggregation.NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true)
38+
require.Nil(t, err)
39+
la1 := &logicalop.LogicalAggregation{
40+
AggFuncs: []*aggregation.AggFuncDesc{desc},
41+
GroupByItems: []expression.Expression{col},
42+
PossibleProperties: [][]*expression.Column{{col}},
43+
}
44+
la2 := &logicalop.LogicalAggregation{
45+
AggFuncs: []*aggregation.AggFuncDesc{desc},
46+
GroupByItems: []expression.Expression{col},
47+
PossibleProperties: [][]*expression.Column{{col}},
48+
}
49+
hasher1 := base.NewHashEqualer()
50+
hasher2 := base.NewHashEqualer()
51+
la1.Hash64(hasher1)
52+
la2.Hash64(hasher2)
53+
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
54+
55+
la2.GroupByItems = []expression.Expression{}
56+
hasher2.Reset()
57+
la2.Hash64(hasher2)
58+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
59+
60+
la2.GroupByItems = []expression.Expression{col}
61+
la2.PossibleProperties = [][]*expression.Column{{}}
62+
hasher2.Reset()
63+
la2.Hash64(hasher2)
64+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
65+
}

pkg/planner/util/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ go_library(
2121
"//pkg/parser/ast",
2222
"//pkg/parser/model",
2323
"//pkg/parser/mysql",
24+
"//pkg/planner/cascades/base",
2425
"//pkg/planner/core/base",
2526
"//pkg/planner/funcdep",
2627
"//pkg/planner/planctx",

pkg/planner/util/byitem.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
perrors "github.com/pingcap/errors"
2222
"github.com/pingcap/tidb/pkg/expression"
23+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2324
"github.com/pingcap/tidb/pkg/util/size"
2425
)
2526

@@ -29,6 +30,24 @@ type ByItems struct {
2930
Desc bool
3031
}
3132

33+
// Hash64 implements the base.Hasher interface.
34+
func (by *ByItems) Hash64(h base.Hasher) {
35+
by.Expr.Hash64(h)
36+
h.HashBool(by.Desc)
37+
}
38+
39+
// Equals implements the base.Equaler interface.
40+
func (by *ByItems) Equals(other any) bool {
41+
if other == nil {
42+
return false
43+
}
44+
otherBy, ok := other.(*ByItems)
45+
if !ok {
46+
return false
47+
}
48+
return by.Desc == otherBy.Desc && by.Expr.Equals(otherBy.Expr)
49+
}
50+
3251
// StringWithCtx implements expression.StringerWithCtx interface.
3352
func (by *ByItems) StringWithCtx(ctx expression.ParamValues, redact string) string {
3453
if by.Desc {

0 commit comments

Comments
 (0)