Skip to content

Commit 35d5739

Browse files
authored
expression: Fix optimizer panic in evaluate expr with null (#57403)
close #55886
1 parent ecca340 commit 35d5739

File tree

12 files changed

+173
-49
lines changed

12 files changed

+173
-49
lines changed

pkg/expression/aggregation/descriptor.go

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -228,15 +228,15 @@ func (a *AggFuncDesc) Split(ordinal []int) (partialAggDesc, finalAggDesc *AggFun
228228
// +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+
229229
// | 1 | 1 | 95 | 95.0000 | 95 | 95 | 95 | 95 | 95 | NULL | NULL |
230230
// +------+-----------+---------+---------+------------+-------------+------------+---------+---------+------+----------+
231-
func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
231+
func (a *AggFuncDesc) EvalNullValueInOuterJoin(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool, error) {
232232
switch a.Name {
233233
case ast.AggFuncCount:
234234
return a.evalNullValueInOuterJoin4Count(ctx, schema)
235235
case ast.AggFuncSum, ast.AggFuncMax, ast.AggFuncMin,
236236
ast.AggFuncFirstRow:
237237
return a.evalNullValueInOuterJoin4Sum(ctx, schema)
238238
case ast.AggFuncAvg, ast.AggFuncGroupConcat:
239-
return types.Datum{}, false
239+
return types.Datum{}, false, nil
240240
case ast.AggFuncBitAnd:
241241
return a.evalNullValueInOuterJoin4BitAnd(ctx, schema)
242242
case ast.AggFuncBitOr, ast.AggFuncBitXor:
@@ -275,42 +275,54 @@ func (a *AggFuncDesc) GetAggFunc(ctx expression.AggFuncBuildContext) Aggregation
275275
}
276276
}
277277

278-
func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
278+
func (a *AggFuncDesc) evalNullValueInOuterJoin4Count(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool, error) {
279279
for _, arg := range a.Args {
280-
result := expression.EvaluateExprWithNull(ctx, schema, arg)
280+
result, err := expression.EvaluateExprWithNull(ctx, schema, arg)
281+
if err != nil {
282+
return types.Datum{}, false, err
283+
}
281284
con, ok := result.(*expression.Constant)
282285
if !ok || con.Value.IsNull() {
283-
return types.Datum{}, ok
286+
return types.Datum{}, ok, nil
284287
}
285288
}
286-
return types.NewDatum(1), true
289+
return types.NewDatum(1), true, nil
287290
}
288291

289-
func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
290-
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
292+
func (a *AggFuncDesc) evalNullValueInOuterJoin4Sum(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool, error) {
293+
result, err := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
294+
if err != nil {
295+
return types.Datum{}, false, err
296+
}
291297
con, ok := result.(*expression.Constant)
292298
if !ok || con.Value.IsNull() {
293-
return types.Datum{}, ok
299+
return types.Datum{}, ok, nil
294300
}
295-
return con.Value, true
301+
return con.Value, true, nil
296302
}
297303

298-
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
299-
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
304+
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitAnd(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool, error) {
305+
result, err := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
306+
if err != nil {
307+
return types.Datum{}, false, err
308+
}
300309
con, ok := result.(*expression.Constant)
301310
if !ok || con.Value.IsNull() {
302-
return types.NewDatum(uint64(math.MaxUint64)), true
311+
return types.NewDatum(uint64(math.MaxUint64)), true, nil
303312
}
304-
return con.Value, true
313+
return con.Value, true, nil
305314
}
306315

307-
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool) {
308-
result := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
316+
func (a *AggFuncDesc) evalNullValueInOuterJoin4BitOr(ctx expression.BuildContext, schema *expression.Schema) (types.Datum, bool, error) {
317+
result, err := expression.EvaluateExprWithNull(ctx, schema, a.Args[0])
318+
if err != nil {
319+
return types.Datum{}, false, err
320+
}
309321
con, ok := result.(*expression.Constant)
310322
if !ok || con.Value.IsNull() {
311-
return types.NewDatum(0), true
323+
return types.NewDatum(0), true, nil
312324
}
313-
return con.Value, true
325+
return con.Value, true, nil
314326
}
315327

316328
// UpdateNotNullFlag4RetType checks if we should remove the NotNull flag for the return type of the agg.

pkg/expression/expression.go

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -920,49 +920,57 @@ func SplitDNFItems(onExpr Expression) []Expression {
920920

921921
// EvaluateExprWithNull sets the columns in schema to null and calculates the final result of the scalar function.
922922
// If the Expression is a non-constant value, it means the result is unknown.
923-
func EvaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Expression {
923+
func EvaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) (Expression, error) {
924924
if MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) {
925925
ctx.SetSkipPlanCache(fmt.Sprintf("%v affects null check", expr.StringWithCtx(ctx.GetEvalCtx(), errors.RedactLogDisable)))
926926
}
927927
if ctx.IsInNullRejectCheck() {
928-
expr, _ = evaluateExprWithNullInNullRejectCheck(ctx, schema, expr)
929-
return expr
928+
res, _, err := evaluateExprWithNullInNullRejectCheck(ctx, schema, expr)
929+
return res, err
930930
}
931931
return evaluateExprWithNull(ctx, schema, expr)
932932
}
933933

934-
func evaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) Expression {
934+
func evaluateExprWithNull(ctx BuildContext, schema *Schema, expr Expression) (Expression, error) {
935935
switch x := expr.(type) {
936936
case *ScalarFunction:
937937
args := make([]Expression, len(x.GetArgs()))
938938
for i, arg := range x.GetArgs() {
939-
args[i] = evaluateExprWithNull(ctx, schema, arg)
939+
res, err := EvaluateExprWithNull(ctx, schema, arg)
940+
if err != nil {
941+
return nil, err
942+
}
943+
args[i] = res
940944
}
941-
return NewFunctionInternal(ctx, x.FuncName.L, x.RetType.Clone(), args...)
945+
return NewFunction(ctx, x.FuncName.L, x.RetType.Clone(), args...)
942946
case *Column:
943947
if !schema.Contains(x) {
944-
return x
948+
return x, nil
945949
}
946-
return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)}
950+
return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)}, nil
947951
case *Constant:
948952
if x.DeferredExpr != nil {
949-
return FoldConstant(ctx, x)
953+
return FoldConstant(ctx, x), nil
950954
}
951955
}
952-
return expr
956+
return expr, nil
953957
}
954958

955959
// evaluateExprWithNullInNullRejectCheck sets the columns in schema to null and calculates the final result of the scalar function.
956960
// If the Expression is a non-constant value, it means the result is unknown.
957961
// The returned bool value indicates whether the value is influenced by the Null Constant transformed from schema column
958962
// when the value is Null Constant.
959-
func evaluateExprWithNullInNullRejectCheck(ctx BuildContext, schema *Schema, expr Expression) (Expression, bool) {
963+
func evaluateExprWithNullInNullRejectCheck(ctx BuildContext, schema *Schema, expr Expression) (Expression, bool, error) {
960964
switch x := expr.(type) {
961965
case *ScalarFunction:
962966
args := make([]Expression, len(x.GetArgs()))
963967
nullFromSets := make([]bool, len(x.GetArgs()))
964968
for i, arg := range x.GetArgs() {
965-
args[i], nullFromSets[i] = evaluateExprWithNullInNullRejectCheck(ctx, schema, arg)
969+
res, nullFromSet, err := evaluateExprWithNullInNullRejectCheck(ctx, schema, arg)
970+
if err != nil {
971+
return nil, false, err
972+
}
973+
args[i], nullFromSets[i] = res, nullFromSet
966974
}
967975
allArgsNullFromSet := true
968976
for i := range args {
@@ -999,22 +1007,25 @@ func evaluateExprWithNullInNullRejectCheck(ctx BuildContext, schema *Schema, exp
9991007
}
10001008
}
10011009

1002-
c := NewFunctionInternal(ctx, x.FuncName.L, x.RetType.Clone(), args...)
1010+
c, err := NewFunction(ctx, x.FuncName.L, x.RetType.Clone(), args...)
1011+
if err != nil {
1012+
return nil, false, err
1013+
}
10031014
cons, ok := c.(*Constant)
10041015
// If the return expr is Null Constant, and all the Null Constant arguments are affected by column schema,
10051016
// then we think the result Null Constant is also affected by the column schema
1006-
return c, ok && cons.Value.IsNull() && allArgsNullFromSet
1017+
return c, ok && cons.Value.IsNull() && allArgsNullFromSet, nil
10071018
case *Column:
10081019
if !schema.Contains(x) {
1009-
return x, false
1020+
return x, false, nil
10101021
}
1011-
return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)}, true
1022+
return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)}, true, nil
10121023
case *Constant:
10131024
if x.DeferredExpr != nil {
1014-
return FoldConstant(ctx, x), false
1025+
return FoldConstant(ctx, x), false, nil
10151026
}
10161027
}
1017-
return expr, false
1028+
return expr, false, nil
10181029
}
10191030

10201031
// TableInfo2SchemaAndNames converts the TableInfo to the schema and name slice.

pkg/expression/expression_test.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,40 @@ func TestEvaluateExprWithNull(t *testing.T) {
4949
outerIfNull, err := newFunctionForTest(ctx, ast.Ifnull, col0, innerIfNull)
5050
require.NoError(t, err)
5151

52-
res := EvaluateExprWithNull(ctx, schema, outerIfNull)
52+
res, err := EvaluateExprWithNull(ctx, schema, outerIfNull)
53+
require.Nil(t, err)
5354
require.Equal(t, "ifnull(Column#1, 1)", res.StringWithCtx(ctx, errors.RedactLogDisable))
5455
require.Equal(t, "ifnull(Column#1, ?)", res.StringWithCtx(ctx, errors.RedactLogEnable))
5556
require.Equal(t, "ifnull(Column#1, ‹1›)", res.StringWithCtx(ctx, errors.RedactLogMarker))
5657
schema.Columns = append(schema.Columns, col1)
5758
// ifnull(null, ifnull(null, 1))
58-
res = EvaluateExprWithNull(ctx, schema, outerIfNull)
59+
res, err = EvaluateExprWithNull(ctx, schema, outerIfNull)
60+
require.Nil(t, err)
5961
require.True(t, res.Equal(ctx, NewOne()))
6062
}
6163

64+
func TestEvaluateExprWithNullMeetError(t *testing.T) {
65+
ctx := createContext(t)
66+
tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong, 0).add("col1", mysql.TypeLonglong, 0).build()
67+
schema := tableInfoToSchemaForTest(tblInfo)
68+
col0 := schema.Columns[0]
69+
col1 := schema.Columns[1]
70+
schema.Columns = schema.Columns[:1]
71+
innerFunc, err := newFunctionForTest(ctx, ast.Ifnull, col1, NewOne())
72+
require.NoError(t, err)
73+
// change the function name to an invalid one, so that the inner function will return an error
74+
innerFunc.(*ScalarFunction).FuncName.L = "invalid"
75+
outerIfNull, err := newFunctionForTest(ctx, ast.Ifnull, col0, innerFunc)
76+
require.NoError(t, err)
77+
78+
// the inner function has an error
79+
_, err = EvaluateExprWithNull(ctx, schema, outerIfNull)
80+
require.NotNil(t, err)
81+
// check in NullRejectCheck ctx
82+
_, err = EvaluateExprWithNull(ctx.GetNullRejectCheckExprCtx(), schema, outerIfNull)
83+
require.NotNil(t, err)
84+
}
85+
6286
func TestEvaluateExprWithNullAndParameters(t *testing.T) {
6387
ctx := createContext(t)
6488
tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong, 0).build()
@@ -70,14 +94,16 @@ func TestEvaluateExprWithNullAndParameters(t *testing.T) {
7094
// cases for parameters
7195
ltWithoutParam, err := newFunctionForTest(ctx, ast.LT, col0, NewOne())
7296
require.NoError(t, err)
73-
res := EvaluateExprWithNull(ctx, schema, ltWithoutParam)
97+
res, err := EvaluateExprWithNull(ctx, schema, ltWithoutParam)
98+
require.Nil(t, err)
7499
require.True(t, res.Equal(ctx, NewNull())) // the expression is evaluated to null
75100
param := NewOne()
76101
param.ParamMarker = &ParamMarker{order: 0}
77102
ctx.GetSessionVars().PlanCacheParams.Append(types.NewIntDatum(10))
78103
ltWithParam, err := newFunctionForTest(ctx, ast.LT, col0, param)
79104
require.NoError(t, err)
80-
res = EvaluateExprWithNull(ctx, schema, ltWithParam)
105+
res, err = EvaluateExprWithNull(ctx, schema, ltWithParam)
106+
require.Nil(t, err)
81107
_, isConst := res.(*Constant)
82108
require.True(t, isConst) // this expression is evaluated and skip-plan cache flag is set.
83109
require.True(t, !ctx.GetSessionVars().StmtCtx.UseCache())

pkg/expression/integration_test/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ go_test(
88
"main_test.go",
99
],
1010
flaky = True,
11-
shard_count = 47,
11+
shard_count = 48,
1212
deps = [
1313
"//pkg/config",
1414
"//pkg/domain",

pkg/expression/integration_test/integration_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3949,3 +3949,15 @@ func TestIssue55885(t *testing.T) {
39493949

39503950
tk.MustQuery("SELECT subq_0.c3 as c1 FROM (select c_a90ol as c3, c_a90ol as c4, var_pop(cast(c__qy as double)) over (partition by c_a90ol, c_s order by c_z) as c5 from t_jg8o limit 65) as subq_0 LIMIT 37")
39513951
}
3952+
3953+
func TestIssue55886(t *testing.T) {
3954+
store := testkit.CreateMockStore(t)
3955+
tk := testkit.NewTestKit(t, store)
3956+
tk.MustExec("use test")
3957+
tk.MustExec("create table t1(c_foveoe text, c_jbb text, c_cz text not null);")
3958+
tk.MustExec("create table t2(c_g7eofzlxn int);")
3959+
tk.MustExec("set collation_connection='latin1_bin';")
3960+
tk.MustQuery("with cte_0 AS (select 1 as c1, case when ref_0.c_jbb then inet6_aton(ref_0.c_foveoe) else ref_4.c_cz end as c5 from t1 as ref_0 join " +
3961+
" (t1 as ref_4 right outer join t2 as ref_5 on ref_5.c_g7eofzlxn != 1)), cte_4 as (select 1 as c1 from t2) select ref_34.c1 as c5 from" +
3962+
" cte_0 as ref_34 where exists (select 1 from cte_4 as ref_35 where ref_34.c1 <= case when ref_34.c5 then cast(1 as char) else ref_34.c5 end);")
3963+
}

pkg/planner/core/casetest/rule/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ go_test(
1313
],
1414
data = glob(["testdata/**"]),
1515
flaky = True,
16-
shard_count = 7,
16+
shard_count = 8,
1717
deps = [
1818
"//pkg/domain",
1919
"//pkg/expression",

pkg/planner/core/casetest/rule/rule_outer2inner_test.go

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,36 @@ func TestOuter2Inner(t *testing.T) {
4545
Plan []string
4646
}
4747
suiteData := GetOuter2InnerSuiteData()
48-
suiteData.LoadTestCases(t, &input, &output)
48+
suiteData.LoadTestCasesByName("TestOuter2Inner", t, &input, &output)
49+
for i, sql := range input {
50+
plan := tk.MustQuery("explain format = 'brief' " + sql)
51+
testdata.OnRecord(func() {
52+
output[i].SQL = sql
53+
output[i].Plan = testdata.ConvertRowsToStrings(plan.Rows())
54+
})
55+
plan.Check(testkit.Rows(output[i].Plan...))
56+
}
57+
}
58+
59+
// This test case cannot be added to TestOuter2Inner because the collation_connection is different.
60+
func TestOuter2InnerIssue55886(t *testing.T) {
61+
store := testkit.CreateMockStore(t)
62+
tk := testkit.NewTestKit(t, store)
63+
64+
tk.MustExec("use test")
65+
tk.MustExec("drop table if exists t1")
66+
tk.MustExec("drop table if exists t2")
67+
tk.MustExec("create table t1(c_foveoe text, c_jbb text, c_cz text not null)")
68+
tk.MustExec("create table t2(c_g7eofzlxn int)")
69+
tk.MustExec("set collation_connection = 'latin1_bin'")
70+
71+
var input Input
72+
var output []struct {
73+
SQL string
74+
Plan []string
75+
}
76+
suiteData := GetOuter2InnerSuiteData()
77+
suiteData.LoadTestCasesByName("TestOuter2InnerIssue55886", t, &input, &output)
4978
for i, sql := range input {
5079
plan := tk.MustQuery("explain format = 'brief' " + sql)
5180
testdata.OnRecord(func() {

pkg/planner/core/casetest/rule/testdata/outer2inner_in.json

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,13 @@
4646
"select * from t0 left outer join t11 on a0=a1 where '5' not in (t0.b0, t11.b1)",
4747
"select * from t0 left outer join t11 on a0=a1 where '1' in (t0.b0, t11.b1)",
4848
"select * from t0 left outer join t11 on a0=a1 where t0.b0 in ('5', t11.b1) -- some = in the in list is not null filtering",
49-
"select * from t0 left outer join t11 on a0=a1 where '5' in (t0.b0, t11.b1) -- some = in the in list is not null filtering",
50-
"select * from t1 left outer join t2 on a1=a2 where not (b2 is NOT NULL AND c2 = 5) -- NOT case "
49+
"select * from t0 left outer join t11 on a0=a1 where '5' in (t0.b0, t11.b1) -- some = in the in list is not null filtering"
50+
]
51+
},
52+
{
53+
"name": "TestOuter2InnerIssue55886",
54+
"cases": [
55+
"with cte_0 AS (select 1 as c1, case when ref_0.c_jbb then inet6_aton(ref_0.c_foveoe) else ref_4.c_cz end as c5 from t1 as ref_0 join (t1 as ref_4 right outer join t2 as ref_5 on ref_5.c_g7eofzlxn != 1)), cte_4 as (select 1 as c1 from t2) select ref_34.c1 as c5 from cte_0 as ref_34 where exists (select 1 from cte_4 as ref_35 where ref_34.c1 <= case when ref_34.c5 then cast(1 as char) else ref_34.c5 end)"
5156
]
5257
}
5358
]

pkg/planner/core/casetest/rule/testdata/outer2inner_out.json

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,5 +650,27 @@
650650
]
651651
}
652652
]
653+
},
654+
{
655+
"Name": "TestOuter2InnerIssue55886",
656+
"Cases": [
657+
{
658+
"SQL": "with cte_0 AS (select 1 as c1, case when ref_0.c_jbb then inet6_aton(ref_0.c_foveoe) else ref_4.c_cz end as c5 from t1 as ref_0 join (t1 as ref_4 right outer join t2 as ref_5 on ref_5.c_g7eofzlxn != 1)), cte_4 as (select 1 as c1 from t2) select ref_34.c1 as c5 from cte_0 as ref_34 where exists (select 1 from cte_4 as ref_35 where ref_34.c1 <= case when ref_34.c5 then cast(1 as char) else ref_34.c5 end)",
659+
"Plan": [
660+
"HashJoin 800000000000.00 root CARTESIAN semi join",
661+
"├─TableReader(Build) 10000.00 root data:TableFullScan",
662+
"│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
663+
"└─Projection(Probe) 1000000000000.00 root 1->Column#28",
664+
" └─HashJoin 1000000000000.00 root CARTESIAN inner join, other cond:le(1, cast(case(istrue_with_null(cast(case(istrue_with_null(cast(test.t1.c_jbb, double BINARY)), from_binary(inet6_aton(test.t1.c_foveoe)), test.t1.c_cz), double BINARY)), \"1\", case(istrue_with_null(cast(test.t1.c_jbb, double BINARY)), from_binary(inet6_aton(test.t1.c_foveoe)), test.t1.c_cz)), double BINARY))",
665+
" ├─TableReader(Build) 10000.00 root data:TableFullScan",
666+
" │ └─TableFullScan 10000.00 cop[tikv] table:ref_0 keep order:false, stats:pseudo",
667+
" └─HashJoin(Probe) 100000000.00 root CARTESIAN right outer join, right cond:ne(test.t2.c_g7eofzlxn, 1)",
668+
" ├─TableReader(Build) 10000.00 root data:TableFullScan",
669+
" │ └─TableFullScan 10000.00 cop[tikv] table:ref_5 keep order:false, stats:pseudo",
670+
" └─TableReader(Probe) 10000.00 root data:TableFullScan",
671+
" └─TableFullScan 10000.00 cop[tikv] table:ref_4 keep order:false, stats:pseudo"
672+
]
673+
}
674+
]
653675
}
654676
]

pkg/planner/core/operator/logicalop/logical_aggregation.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,10 @@ func (la *LogicalAggregation) CanPullUp() bool {
703703
}
704704
for _, f := range la.AggFuncs {
705705
for _, arg := range f.Args {
706-
expr := expression.EvaluateExprWithNull(la.SCtx().GetExprCtx(), la.Children()[0].Schema(), arg)
706+
expr, err := expression.EvaluateExprWithNull(la.SCtx().GetExprCtx(), la.Children()[0].Schema(), arg)
707+
if err != nil {
708+
return false
709+
}
707710
if con, ok := expr.(*expression.Constant); !ok || !con.Value.IsNull() {
708711
return false
709712
}

0 commit comments

Comments
 (0)