diff --git a/pkg/planner/core/casetest/partition/partition_pruner_test.go b/pkg/planner/core/casetest/partition/partition_pruner_test.go index 9c14428f137fe..1d6814b7694e2 100644 --- a/pkg/planner/core/casetest/partition/partition_pruner_test.go +++ b/pkg/planner/core/casetest/partition/partition_pruner_test.go @@ -49,6 +49,7 @@ func TestHashPartitionPruner(t *testing.T) { tk.MustExec("create table t8(a int, b int) partition by hash(a) partitions 6;") tk.MustExec("create table t9(a bit(1) default null, b int(11) default null) partition by hash(a) partitions 3;") //issue #22619 tk.MustExec("create table t10(a bigint unsigned) partition BY hash (a);") + tk.MustExec("create table t11(a int, b int) partition by hash(a + a + a + b) partitions 5") var input []string var output []struct { diff --git a/pkg/planner/core/casetest/partition/testdata/partition_pruner_in.json b/pkg/planner/core/casetest/partition/testdata/partition_pruner_in.json index aa35b16978330..cf007fe394100 100644 --- a/pkg/planner/core/casetest/partition/testdata/partition_pruner_in.json +++ b/pkg/planner/core/casetest/partition/testdata/partition_pruner_in.json @@ -30,7 +30,10 @@ "explain format = 'brief' select * from t8 where (a <= 10 and a >= 8) or (a <= 13 and a >= 11) or (a <= 16 and a >= 14)", "explain format = 'brief' select * from t8 where a < 12 and a > 9", "explain format = 'brief' select * from t9", - "explain format = 'brief' select * from t10 where a between 0 AND 15218001646226433652" + "explain format = 'brief' select * from t10 where a between 0 AND 15218001646226433652", + "explain format = 'brief' select * from t11 where a is null", + "explain format = 'brief' select * from t11 where a is null and b = 2", + "explain format = 'brief' select * from t11 where a = 1 and b = 2" ] }, { diff --git a/pkg/planner/core/casetest/partition/testdata/partition_pruner_out.json b/pkg/planner/core/casetest/partition/testdata/partition_pruner_out.json index 11d9e2650f85a..e766878892d6c 100644 --- a/pkg/planner/core/casetest/partition/testdata/partition_pruner_out.json +++ b/pkg/planner/core/casetest/partition/testdata/partition_pruner_out.json @@ -215,6 +215,30 @@ "└─Selection 250.00 cop[tikv] ge(test_partition.t10.a, 0), le(test_partition.t10.a, 15218001646226433652)", " └─TableFullScan 10000.00 cop[tikv] table:t10 keep order:false, stats:pseudo" ] + }, + { + "SQL": "explain format = 'brief' select * from t11 where a is null", + "Result": [ + "TableReader 10.00 root partition:all data:Selection", + "└─Selection 10.00 cop[tikv] isnull(test_partition.t11.a)", + " └─TableFullScan 10000.00 cop[tikv] table:t11 keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain format = 'brief' select * from t11 where a is null and b = 2", + "Result": [ + "TableReader 0.01 root partition:p0 data:Selection", + "└─Selection 0.01 cop[tikv] eq(test_partition.t11.b, 2), isnull(test_partition.t11.a)", + " └─TableFullScan 10000.00 cop[tikv] table:t11 keep order:false, stats:pseudo" + ] + }, + { + "SQL": "explain format = 'brief' select * from t11 where a = 1 and b = 2", + "Result": [ + "TableReader 0.01 root partition:p0 data:Selection", + "└─Selection 0.01 cop[tikv] eq(test_partition.t11.a, 1), eq(test_partition.t11.b, 2)", + " └─TableFullScan 10000.00 cop[tikv] table:t11 keep order:false, stats:pseudo" + ] } ] }, diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index 9d80b20bf7ff5..5066d314f7ba2 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -148,11 +148,18 @@ func generateHashPartitionExpr(ctx base.PlanContext, pi *model.PartitionInfo, co func getPartColumnsForHashPartition(hashExpr expression.Expression) ([]*expression.Column, []int) { partCols := expression.ExtractColumns(hashExpr) colLen := make([]int, 0, len(partCols)) + retCols := make([]*expression.Column, 0, len(partCols)) + filled := make(map[int64]struct{}) for i := 0; i < len(partCols); i++ { - partCols[i].Index = i - colLen = append(colLen, types.UnspecifiedLength) + // Deal with same columns. + if _, done := filled[partCols[i].UniqueID]; !done { + partCols[i].Index = len(filled) + filled[partCols[i].UniqueID] = struct{}{} + colLen = append(colLen, types.UnspecifiedLength) + retCols = append(retCols, partCols[i]) + } } - return partCols, colLen + return retCols, colLen } func (s *PartitionProcessor) getUsedHashPartitions(ctx base.PlanContext, @@ -247,16 +254,15 @@ func (s *PartitionProcessor) getUsedHashPartitions(ctx base.PlanContext, used = []int{FullRange} break } - if !r.HighVal[0].IsNull() { - if len(r.HighVal) != len(partCols) { - used = []int{-1} - break - } + + // The code below is for the range `r` is a point. + if len(r.HighVal) != len(partCols) { + used = []int{FullRange} + break } - highLowVals := make([]types.Datum, 0, len(r.HighVal)+len(r.LowVal)) - highLowVals = append(highLowVals, r.HighVal...) - highLowVals = append(highLowVals, r.LowVal...) - pos, isNull, err := hashExpr.EvalInt(ctx.GetExprCtx().GetEvalCtx(), chunk.MutRowFromDatums(highLowVals).ToRow()) + vals := make([]types.Datum, 0, len(partCols)) + vals = append(vals, r.HighVal...) + pos, isNull, err := hashExpr.EvalInt(ctx.GetExprCtx().GetEvalCtx(), chunk.MutRowFromDatums(vals).ToRow()) if err != nil { // If we failed to get the point position, we can just skip and ignore it. continue