Skip to content

Commit 4ab1765

Browse files
authored
planner: introduce hashEquals interface for expression.Expression (#55793)
ref #51664
1 parent b5ec2e3 commit 4ab1765

File tree

10 files changed

+250
-7
lines changed

10 files changed

+250
-7
lines changed

pkg/expression/column.go

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import (
3636

3737
var (
3838
_ base.HashEquals = &Column{}
39+
_ base.HashEquals = &CorrelatedColumn{}
3940
)
4041

4142
// CorrelatedColumn stands for a column in a correlated sub query.
@@ -246,6 +247,31 @@ func (col *CorrelatedColumn) RemapColumn(m map[int64]*Column) (Expression, error
246247
}, nil
247248
}
248249

250+
// Hash64 implements HashEquals.<0th> interface.
251+
func (col *CorrelatedColumn) Hash64(h base.Hasher) {
252+
// correlatedColumn flag here is used to distinguish correlatedColumn and Column.
253+
h.HashByte(correlatedColumn)
254+
col.Column.Hash64(h)
255+
// since col.Datum is filled in the runtime, we can't use it to calculate hash now, correlatedColumn flag + column is enough.
256+
}
257+
258+
// Equals implements HashEquals.<1st> interface.
259+
func (col *CorrelatedColumn) Equals(other any) bool {
260+
if other == nil {
261+
return false
262+
}
263+
var col2 *CorrelatedColumn
264+
switch x := other.(type) {
265+
case CorrelatedColumn:
266+
col2 = &x
267+
case *CorrelatedColumn:
268+
col2 = x
269+
default:
270+
return false
271+
}
272+
return col.Column.Equals(&col2.Column)
273+
}
274+
249275
// Column represents a column.
250276
type Column struct {
251277
RetType *types.FieldType `plan-cache-clone:"shallow"`
@@ -458,11 +484,11 @@ func (col *Column) Hash64(h base.Hasher) {
458484
h.HashInt64(col.ID)
459485
h.HashInt64(col.UniqueID)
460486
h.HashInt(col.Index)
461-
if col.VirtualExpr != nil {
487+
if col.VirtualExpr == nil {
462488
h.HashByte(base.NilFlag)
463489
} else {
464490
h.HashByte(base.NotNilFlag)
465-
//col.VirtualExpr.Hash64(h)
491+
col.VirtualExpr.Hash64(h)
466492
}
467493
h.HashString(col.OrigName)
468494
h.HashBool(col.IsHidden)
@@ -488,12 +514,12 @@ func (col *Column) Equals(other any) bool {
488514
}
489515
// when step into here, we could ensure that col1.RetType and col2.RetType are same type.
490516
// and we should ensure col1.RetType and col2.RetType is not nil ourselves.
491-
ftEqual := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
492-
return ftEqual &&
517+
ok := col.RetType == nil && col2.RetType == nil || col.RetType != nil && col2.RetType != nil && col.RetType.Equal(col2.RetType)
518+
ok = ok && (col.VirtualExpr == nil && col2.VirtualExpr == nil || col.VirtualExpr != nil && col2.VirtualExpr != nil && col.VirtualExpr.Equals(col2.VirtualExpr))
519+
return ok &&
493520
col.ID == col2.ID &&
494521
col.UniqueID == col2.UniqueID &&
495522
col.Index == col2.Index &&
496-
//col.VirtualExpr.Equals(col2.VirtualExpr) &&
497523
col.OrigName == col2.OrigName &&
498524
col.IsHidden == col2.IsHidden &&
499525
col.IsPrefix == col2.IsPrefix &&

pkg/expression/column_test.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,7 @@ func TestColumnHashEquals(t *testing.T) {
415415
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
416416
require.False(t, col1.Equals(col2))
417417

418-
// diff VirtualExpr
419-
// TODO: add HashEquals for VirtualExpr
418+
// diff VirtualExpr see TestColumnHashEuqals4VirtualExpr
420419

421420
// diff OrigName
422421
col2.Index = col1.Index
@@ -468,3 +467,29 @@ func TestColumnHashEquals(t *testing.T) {
468467
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
469468
require.False(t, col1.Equals(col2))
470469
}
470+
471+
func TestColumnHashEuqals4VirtualExpr(t *testing.T) {
472+
col1 := &Column{UniqueID: 1, VirtualExpr: NewZero()}
473+
col2 := &Column{UniqueID: 1, VirtualExpr: nil}
474+
hasher1 := base.NewHashEqualer()
475+
hasher2 := base.NewHashEqualer()
476+
col1.Hash64(hasher1)
477+
col2.Hash64(hasher2)
478+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
479+
require.False(t, col1.Equals(col2))
480+
481+
col2.VirtualExpr = NewZero()
482+
hasher2.Reset()
483+
col2.Hash64(hasher2)
484+
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
485+
require.True(t, col1.Equals(col2))
486+
487+
col1.VirtualExpr = nil
488+
col2.VirtualExpr = nil
489+
hasher1.Reset()
490+
hasher2.Reset()
491+
col1.Hash64(hasher1)
492+
col2.Hash64(hasher2)
493+
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
494+
require.True(t, col1.Equals(col2))
495+
}

pkg/expression/constant.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
perrors "github.com/pingcap/errors"
2222
"github.com/pingcap/tidb/pkg/parser/mysql"
23+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2324
"github.com/pingcap/tidb/pkg/types"
2425
"github.com/pingcap/tidb/pkg/util/chunk"
2526
"github.com/pingcap/tidb/pkg/util/codec"
@@ -29,6 +30,8 @@ import (
2930
"go.uber.org/zap"
3031
)
3132

33+
var _ base.HashEquals = &Constant{}
34+
3235
// NewOne stands for a number 1.
3336
func NewOne() *Constant {
3437
retT := types.NewFieldType(mysql.TypeTiny)
@@ -502,6 +505,50 @@ func (c *Constant) CanonicalHashCode() []byte {
502505
return c.getHashCode(true)
503506
}
504507

508+
// Hash64 implements HashEquals.<0th> interface.
509+
func (c *Constant) Hash64(h base.Hasher) {
510+
if c.RetType == nil {
511+
h.HashByte(base.NilFlag)
512+
} else {
513+
h.HashByte(base.NotNilFlag)
514+
c.RetType.Hash64(h)
515+
}
516+
c.collationInfo.Hash64(h)
517+
if c.DeferredExpr != nil {
518+
c.DeferredExpr.Hash64(h)
519+
return
520+
}
521+
if c.ParamMarker != nil {
522+
h.HashByte(parameterFlag)
523+
h.HashInt64(int64(c.ParamMarker.order))
524+
return
525+
}
526+
intest.Assert(c.DeferredExpr == nil && c.ParamMarker == nil)
527+
h.HashByte(constantFlag)
528+
c.Value.Hash64(h)
529+
}
530+
531+
// Equals implements HashEquals.<1st> interface.
532+
func (c *Constant) Equals(other any) bool {
533+
if other == nil {
534+
return false
535+
}
536+
var c2 *Constant
537+
switch x := other.(type) {
538+
case *Constant:
539+
c2 = x
540+
case Constant:
541+
c2 = &x
542+
default:
543+
return false
544+
}
545+
ok := c.RetType == nil && c2.RetType == nil || c.RetType != nil && c2.RetType != nil && c.RetType.Equals(c2.RetType)
546+
ok = ok && c.collationInfo.Equals(c2.collationInfo)
547+
ok = ok && (c.DeferredExpr == nil && c2.DeferredExpr == nil || c.DeferredExpr != nil && c2.DeferredExpr != nil && c.DeferredExpr.Equals(c2.DeferredExpr))
548+
ok = ok && (c.ParamMarker == nil && c2.ParamMarker == nil || c.ParamMarker != nil && c2.ParamMarker != nil && c.ParamMarker.order == c2.ParamMarker.order)
549+
return ok && c.Value.Equals(c2.Value)
550+
}
551+
505552
func (c *Constant) getHashCode(canonical bool) []byte {
506553
if len(c.hashcode) > 0 {
507554
return c.hashcode

pkg/expression/constant_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
exprctx "github.com/pingcap/tidb/pkg/expression/context"
2626
"github.com/pingcap/tidb/pkg/parser/ast"
2727
"github.com/pingcap/tidb/pkg/parser/mysql"
28+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2829
"github.com/pingcap/tidb/pkg/types"
2930
"github.com/pingcap/tidb/pkg/util/chunk"
3031
"github.com/pingcap/tidb/pkg/util/mock"
@@ -545,3 +546,30 @@ func TestSpecificConstant(t *testing.T) {
545546
require.Equal(t, null.RetType.GetFlen(), 1)
546547
require.Equal(t, null.RetType.GetDecimal(), 0)
547548
}
549+
550+
func TestConstantHashEquals(t *testing.T) {
551+
// Test for Hash64 interface
552+
cst1 := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()}
553+
cst2 := &Constant{Value: types.NewIntDatum(2333), RetType: newIntFieldType()}
554+
hasher1 := base.NewHashEqualer()
555+
hasher2 := base.NewHashEqualer()
556+
cst1.Hash64(hasher1)
557+
cst2.Hash64(hasher2)
558+
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
559+
require.True(t, cst1.Equals(cst2))
560+
561+
// test cst2 datum changes.
562+
cst2.Value = types.NewIntDatum(2334)
563+
hasher2.Reset()
564+
cst2.Hash64(hasher2)
565+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
566+
require.False(t, cst1.Equals(cst2))
567+
568+
// test cst2 type changes.
569+
cst2.Value = types.NewIntDatum(2333)
570+
cst2.RetType = newStringFieldType()
571+
hasher2.Reset()
572+
cst2.Hash64(hasher2)
573+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
574+
require.False(t, cst1.Equals(cst2))
575+
}

pkg/expression/expression.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/pingcap/tidb/pkg/parser/mysql"
2828
"github.com/pingcap/tidb/pkg/parser/opcode"
2929
"github.com/pingcap/tidb/pkg/parser/terror"
30+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
3031
"github.com/pingcap/tidb/pkg/types"
3132
"github.com/pingcap/tidb/pkg/util/chunk"
3233
"github.com/pingcap/tidb/pkg/util/generatedexpr"
@@ -42,6 +43,7 @@ const (
4243
scalarFunctionFlag byte = 3
4344
parameterFlag byte = 4
4445
ScalarSubQFlag byte = 5
46+
correlatedColumn byte = 6
4547
)
4648

4749
// EvalSimpleAst evaluates a simple ast expression directly.
@@ -170,6 +172,7 @@ const (
170172
type Expression interface {
171173
VecExpr
172174
CollationInfo
175+
base.HashEquals
173176

174177
Traverse(TraverseAction) Expression
175178

pkg/expression/scalar_function.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/pingcap/tidb/pkg/parser/model"
2727
"github.com/pingcap/tidb/pkg/parser/mysql"
2828
"github.com/pingcap/tidb/pkg/parser/terror"
29+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2930
"github.com/pingcap/tidb/pkg/sessionctx/variable"
3031
"github.com/pingcap/tidb/pkg/types"
3132
"github.com/pingcap/tidb/pkg/util/chunk"
@@ -35,6 +36,8 @@ import (
3536
"github.com/pingcap/tidb/pkg/util/intest"
3637
)
3738

39+
var _ base.HashEquals = &ScalarFunction{}
40+
3841
// ScalarFunction is the function that returns a value.
3942
type ScalarFunction struct {
4043
FuncName model.CIStr
@@ -673,6 +676,51 @@ func simpleCanonicalizedHashCode(sf *ScalarFunction) {
673676
}
674677
}
675678

679+
// Hash64 implements HashEquals.<0th> interface.
680+
func (sf *ScalarFunction) Hash64(h base.Hasher) {
681+
h.HashByte(scalarFunctionFlag)
682+
h.HashString(sf.FuncName.L)
683+
if sf.RetType == nil {
684+
h.HashByte(base.NilFlag)
685+
} else {
686+
h.HashByte(base.NotNilFlag)
687+
sf.RetType.Hash64(h)
688+
}
689+
// hash the arg length to avoid hash collision.
690+
h.HashInt(len(sf.GetArgs()))
691+
for _, arg := range sf.GetArgs() {
692+
arg.Hash64(h)
693+
}
694+
}
695+
696+
// Equals implements HashEquals.<1th> interface.
697+
func (sf *ScalarFunction) Equals(other any) bool {
698+
if other == nil {
699+
return false
700+
}
701+
var sf2 *ScalarFunction
702+
switch x := other.(type) {
703+
case *ScalarFunction:
704+
sf2 = x
705+
case ScalarFunction:
706+
sf2 = &x
707+
default:
708+
return false
709+
}
710+
ok := sf.FuncName.L == sf2.FuncName.L
711+
ok = ok && (sf.RetType == nil && sf2.RetType == nil || sf.RetType != nil && sf2.RetType != nil && sf.RetType.Equals(sf2.RetType))
712+
if len(sf.GetArgs()) != len(sf2.GetArgs()) {
713+
return false
714+
}
715+
for i, arg := range sf.GetArgs() {
716+
ok = ok && arg.Equals(sf2.GetArgs()[i])
717+
if !ok {
718+
return false
719+
}
720+
}
721+
return ok
722+
}
723+
676724
// ReHashCode is used after we change the argument in place.
677725
func ReHashCode(sf *ScalarFunction) {
678726
sf.hashcode = sf.hashcode[:0]

pkg/expression/scalar_function_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919

2020
"github.com/pingcap/tidb/pkg/parser/ast"
2121
"github.com/pingcap/tidb/pkg/parser/mysql"
22+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2223
"github.com/pingcap/tidb/pkg/types"
2324
"github.com/pingcap/tidb/pkg/util/chunk"
2425
"github.com/pingcap/tidb/pkg/util/mock"
@@ -147,3 +148,40 @@ func TestScalarFuncs2Exprs(t *testing.T) {
147148
require.True(t, exprs[i].Equal(ctx, funcs[i]))
148149
}
149150
}
151+
152+
func TestScalarFunctionHash64Equals(t *testing.T) {
153+
a := &Column{
154+
UniqueID: 1,
155+
RetType: types.NewFieldType(mysql.TypeDouble),
156+
}
157+
sf0, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
158+
sf1, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
159+
hasher1 := base.NewHashEqualer()
160+
hasher2 := base.NewHashEqualer()
161+
sf0.Hash64(hasher1)
162+
sf1.Hash64(hasher2)
163+
require.Equal(t, hasher1.Sum64(), hasher2.Sum64())
164+
require.True(t, sf0.Equals(sf1))
165+
166+
// change the func name
167+
sf2, _ := newFunctionWithMockCtx(ast.GT, a, NewZero()).(*ScalarFunction)
168+
hasher2.Reset()
169+
sf2.Hash64(hasher2)
170+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
171+
require.False(t, sf0.Equals(sf2))
172+
173+
// change the args
174+
sf3, _ := newFunctionWithMockCtx(ast.LT, a, NewOne()).(*ScalarFunction)
175+
hasher2.Reset()
176+
sf3.Hash64(hasher2)
177+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
178+
require.False(t, sf0.Equals(sf3))
179+
180+
// change the ret type
181+
sf4, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction)
182+
sf4.RetType = types.NewFieldType(mysql.TypeLong)
183+
hasher2.Reset()
184+
sf4.Hash64(hasher2)
185+
require.NotEqual(t, hasher1.Sum64(), hasher2.Sum64())
186+
require.False(t, sf0.Equals(sf4))
187+
}

pkg/expression/util_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/pingcap/tidb/pkg/parser/ast"
2323
"github.com/pingcap/tidb/pkg/parser/model"
2424
"github.com/pingcap/tidb/pkg/parser/mysql"
25+
"github.com/pingcap/tidb/pkg/planner/cascades/base"
2526
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
2627
"github.com/pingcap/tidb/pkg/types"
2728
"github.com/pingcap/tidb/pkg/util/chunk"
@@ -661,3 +662,5 @@ func (m *MockExpr) MemoryUsage() (sum int64) {
661662
func (m *MockExpr) Traverse(action TraverseAction) Expression {
662663
return action.Transform(m)
663664
}
665+
// Hash64 is a no-op stub: the mock writes nothing into the hasher.
func (m *MockExpr) Hash64(_ base.Hasher) {
}
666+
// Equals is a stub that reports every value, whatever it is, as unequal.
func (m *MockExpr) Equals(_ any) bool {
	return false
}

pkg/planner/core/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ go_library(
126126
"//pkg/parser/terror",
127127
"//pkg/parser/types",
128128
"//pkg/planner/cardinality",
129+
"//pkg/planner/cascades/base",
129130
"//pkg/planner/context",
130131
"//pkg/planner/core/base",
131132
"//pkg/planner/core/constraint",

0 commit comments

Comments
 (0)