Skip to content

Commit 8e57797

Browse files
authored
planner: fix column evaluator can not detect input's column-ref and thus swapping and destroying later column ref projection logic (#53794) (#56199)
close #53713
1 parent e329890 commit 8e57797

File tree

6 files changed

+229
-2
lines changed

6 files changed

+229
-2
lines changed

pkg/expression/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ go_library(
9797
"//pkg/util/encrypt",
9898
"//pkg/util/generatedexpr",
9999
"//pkg/util/hack",
100+
"//pkg/util/intest",
100101
"//pkg/util/intset",
101102
"//pkg/util/logutil",
102103
"//pkg/util/mathutil",

pkg/expression/evaluator.go

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,30 @@
1515
package expression
1616

1717
import (
18+
"sync/atomic"
19+
1820
"github.com/pingcap/tidb/pkg/sessionctx"
1921
"github.com/pingcap/tidb/pkg/util/chunk"
22+
"github.com/pingcap/tidb/pkg/util/disjointset"
23+
"github.com/pingcap/tidb/pkg/util/intest"
2024
)
2125

2226
type columnEvaluator struct {
2327
inputIdxToOutputIdxes map[int][]int
28+
// mergedInputIdxToOutputIdxes can only be determined at runtime, once the input chunk has been seen.
29+
mergedInputIdxToOutputIdxes atomic.Pointer[map[int][]int]
2430
}
2531

2632
// run evaluates "Column" expressions.
2733
// NOTE: It should be called after all the other expressions are evaluated
2834
//
2935
// since it will change the content of the input Chunk.
3036
func (e *columnEvaluator) run(ctx sessionctx.Context, input, output *chunk.Chunk) error {
31-
for inputIdx, outputIdxes := range e.inputIdxToOutputIdxes {
37+
// mergedInputIdxToOutputIdxes can only be determined at runtime, once we have seen the input chunk structure.
38+
if e.mergedInputIdxToOutputIdxes.Load() == nil {
39+
e.mergeInputIdxToOutputIdxes(input, e.inputIdxToOutputIdxes)
40+
}
41+
for inputIdx, outputIdxes := range *e.mergedInputIdxToOutputIdxes.Load() {
3242
if err := output.SwapColumn(outputIdxes[0], input, inputIdx); err != nil {
3343
return err
3444
}
@@ -39,6 +49,93 @@ func (e *columnEvaluator) run(ctx sessionctx.Context, input, output *chunk.Chunk
3949
return nil
4050
}
4151

52+
// mergeInputIdxToOutputIdxes merges separate inputIdxToOutputIdxes entries when column references
53+
// are detected within the input chunk. This process ensures consistent handling of columns derived
54+
// from the same original source.
55+
//
56+
// Consider the following scenario:
57+
//
58+
// Initial scan operation produces a column 'a':
59+
//
60+
// scan: a (addr: ???)
61+
//
62+
// This column 'a' is used in the first projection (proj1) to create two columns a1 and a2, both referencing 'a':
63+
//
64+
// proj1
65+
// / \
66+
// / \
67+
// / \
68+
// a1 (addr: 0xe) a2 (addr: 0xe)
69+
// / \
70+
// / \
71+
// / \
72+
// proj2 proj2
73+
// / \ / \
74+
// / \ / \
75+
// a3 a4 a5 a6
76+
//
77+
// (addr: 0xe) (addr: 0xe) (addr: 0xe) (addr: 0xe)
78+
//
79+
// Here, a1 and a2 share the same address (0xe), indicating they reference the same data from the original 'a'.
80+
//
81+
// When moving to the second projection (proj2), the system tries to project these columns further:
82+
// - The first set (left side) consists of a3 and a4, derived from a1, both retaining the address (0xe).
83+
// - The second set (right side) consists of a5 and a6, derived from a2, also starting with address (0xe).
84+
//
85+
// When proj1 is complete, the output chunk contains two columns [a1, a2], both derived from the single column 'a' from the scan.
86+
// Since both a1 and a2 are column references with the same address (0xe), they are treated as referencing the same data.
87+
//
88+
// In proj2, two separate <inputIdx, []outputIdxes> items are created:
89+
// - <0, [0,1]>: This means the 0th input column (a1) is projected twice, into the 0th and 1st columns of the output chunk.
90+
// - <1, [2,3]>: This means the 1st input column (a2) is projected twice, into the 2nd and 3rd columns of the output chunk.
91+
//
92+
// Due to the column swapping logic in each projection, after applying the <0, [0,1]> projection,
93+
// the addresses for a1 and a2 may become swapped or invalid:
94+
//
95+
// proj1: a1 (addr: invalid) a2 (addr: invalid)
96+
//
97+
// This can lead to issues in proj2, where further operations on these columns may be unsafe:
98+
//
99+
// proj2: a3 (addr: 0xe) a4 (addr: 0xe) a5 (addr: ???) a6 (addr: ???)
100+
//
101+
// Therefore, it's crucial to identify and merge the original column references early, ensuring
102+
// the final inputIdxToOutputIdxes mapping accurately reflects the shared origins of the data.
103+
// For instance, <0, [0,1,2,3]> indicates that the 0th input column (original 'a') is referenced
104+
// by all four output columns in the final output.
105+
//
106+
// mergeInputIdxToOutputIdxes merges inputIdxToOutputIdxes based on detected column references.
107+
// This ensures that columns with the same reference are correctly handled in the output chunk.
108+
func (e *columnEvaluator) mergeInputIdxToOutputIdxes(input *chunk.Chunk, inputIdxToOutputIdxes map[int][]int) {
109+
originalDJSet := disjointset.NewSet[int](4)
110+
flag := make([]bool, input.NumCols())
111+
// Detect self column-references inside the input chunk by comparing column addresses
112+
for i := 0; i < input.NumCols(); i++ {
113+
if flag[i] {
114+
continue
115+
}
116+
for j := i + 1; j < input.NumCols(); j++ {
117+
if input.Column(i) == input.Column(j) {
118+
flag[j] = true
119+
originalDJSet.Union(i, j)
120+
}
121+
}
122+
}
123+
// Merge inputIdxToOutputIdxes based on the detected column references.
124+
newInputIdxToOutputIdxes := make(map[int][]int, len(inputIdxToOutputIdxes))
125+
for inputIdx := range inputIdxToOutputIdxes {
126+
// The root idx is the disjoint set's internal offset, not the column index itself; map it back with FindVal.
127+
originalRootIdx := originalDJSet.FindRoot(inputIdx)
128+
originalVal, ok := originalDJSet.FindVal(originalRootIdx)
129+
intest.Assert(ok)
130+
mergedOutputIdxes := newInputIdxToOutputIdxes[originalVal]
131+
mergedOutputIdxes = append(mergedOutputIdxes, inputIdxToOutputIdxes[inputIdx]...)
132+
newInputIdxToOutputIdxes[originalVal] = mergedOutputIdxes
133+
}
134+
// Update the merged inputIdxToOutputIdxes atomically.
135+
// If the compare-and-swap fails, another worker has already done this job in the meantime.
136+
e.mergedInputIdxToOutputIdxes.CompareAndSwap(nil, &newInputIdxToOutputIdxes)
137+
}
138+
42139
type defaultEvaluator struct {
43140
outputIdxes []int
44141
exprs []Expression

pkg/expression/evaluator_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package expression
1616

1717
import (
18+
"slices"
1819
"sync/atomic"
1920
"testing"
2021
"time"
@@ -593,3 +594,42 @@ func TestMod(t *testing.T) {
593594
require.NoError(t, err)
594595
require.Equal(t, types.NewDatum(1.5), r)
595596
}
597+
598+
func TestMergeInputIdxToOutputIdxes(t *testing.T) {
599+
ctx := createContext(t)
600+
inputIdxToOutputIdxes := make(map[int][]int)
601+
// The 0th input column should be referenced by the 0th and 1st output columns.
602+
inputIdxToOutputIdxes[0] = []int{0, 1}
603+
// The 1st input column should be referenced by the 2nd and 3rd output columns.
604+
inputIdxToOutputIdxes[1] = []int{2, 3}
605+
columnEval := columnEvaluator{inputIdxToOutputIdxes: inputIdxToOutputIdxes}
606+
607+
input := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, 2)
608+
input.AppendInt64(0, 99)
609+
// Make the input chunk's 0th and 1st columns refer to the same underlying column.
610+
input.MakeRef(0, 1)
611+
612+
// chunk: col1 <---(ref) col2
613+
// ____________/ \___________/ \___
614+
// proj: col1 col2 col3 col4
615+
//
616+
// Before the fix, once inputIdxToOutputIdxes[0] has been applied, the original col2 becomes a nil pointer,
617+
// which makes the subsequent col3 and col4 reference projections invalid.
618+
//
619+
// after fix, the new inputIdxToOutputIdxes should be: inputIdxToOutputIdxes[0]: {0, 1, 2, 3}
620+
621+
output := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong),
622+
types.NewFieldType(mysql.TypeLonglong), types.NewFieldType(mysql.TypeLonglong)}, 2)
623+
624+
err := columnEval.run(ctx, input, output)
625+
require.NoError(t, err)
626+
// all four columns are column-referred, pointing to the first one.
627+
require.Equal(t, output.Column(0), output.Column(1))
628+
require.Equal(t, output.Column(1), output.Column(2))
629+
require.Equal(t, output.Column(2), output.Column(3))
630+
require.Equal(t, output.GetRow(0).GetInt64(0), int64(99))
631+
632+
require.Equal(t, len(*columnEval.mergedInputIdxToOutputIdxes.Load()), 1)
633+
slices.Sort((*columnEval.mergedInputIdxToOutputIdxes.Load())[0])
634+
require.Equal(t, (*columnEval.mergedInputIdxToOutputIdxes.Load())[0], []int{0, 1, 2, 3})
635+
}

pkg/util/disjointset/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
22

33
go_library(
44
name = "disjointset",
5-
srcs = ["int_set.go"],
5+
srcs = [
6+
"int_set.go",
7+
"set.go",
8+
],
69
importpath = "github.com/pingcap/tidb/pkg/util/disjointset",
710
visibility = ["//visibility:public"],
811
)

pkg/util/disjointset/int_set.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func (m *IntSet) FindRoot(a int) int {
3838
if a == m.parent[a] {
3939
return a
4040
}
41+
// Path compression, which leads the time complexity to the inverse Ackermann function.
4142
m.parent[a] = m.FindRoot(m.parent[a])
4243
return m.parent[a]
4344
}

pkg/util/disjointset/set.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package disjointset
16+
// Set is a generic disjoint set (union-find) over any comparable type.
// It's designed for sparse cases or non-integer types: every distinct value is
// assigned a dense internal index the first time it is seen, and the core
// disjoint set algorithm then runs on those indexes.
// If you are dealing with continuous integers, you should use IntSet to avoid
// the cost of the hash maps.
//
// Time complexity: lookups use path compression. Union always keeps the root
// of its first argument as the root of the merged group (no rank/size
// balancing), so operations are amortized logarithmic in the worst case and
// close to O(1) in practice.
//
// A Set is not safe for concurrent use: even the query methods may register
// new values and rewrite parent links.
type Set[T comparable] struct {
	parent  []int     // parent[i] is the parent index of i; a root satisfies parent[i] == i.
	val2Idx map[T]int // original value -> internal index.
	idx2Val map[int]T // internal index -> original value.
	tailIdx int       // next internal index to hand out; always equals len(parent).
}

// NewSet creates a disjoint set. size is only a capacity hint for the number
// of distinct values; a negative hint is treated as zero.
func NewSet[T comparable](size int) *Set[T] {
	if size < 0 {
		// make([]int, 0, size) panics on a negative capacity.
		size = 0
	}
	return &Set[T]{
		parent:  make([]int, 0, size),
		val2Idx: make(map[T]int, size),
		idx2Val: make(map[int]T, size),
		tailIdx: 0,
	}
}

// findRootOriginalVal returns the internal root index of the group containing a.
// If a has not been seen before, it is registered as a new singleton group and
// its freshly assigned index is returned.
func (s *Set[T]) findRootOriginalVal(a T) int {
	if idx, ok := s.val2Idx[a]; ok {
		return s.findRootInternal(idx)
	}
	idx := s.tailIdx
	s.parent = append(s.parent, idx) // a new element is its own root.
	s.val2Idx[a] = idx
	s.idx2Val[idx] = a
	s.tailIdx++
	return idx
}

// findRootInternal returns the root of internal index a and compresses the
// path from a to that root. The caller must pass a valid internal index, i.e.
// 0 <= a < len(s.parent).
func (s *Set[T]) findRootInternal(a int) int {
	// First pass: locate the root.
	root := a
	for s.parent[root] != root {
		root = s.parent[root]
	}
	// Second pass: path compression, point every node on the path at the root.
	for s.parent[a] != root {
		next := s.parent[a]
		s.parent[a] = root
		a = next
	}
	return root
}

// InSameGroup checks whether a and b are in the same group.
// Values that have not been seen before are registered as singleton groups.
func (s *Set[T]) InSameGroup(a, b T) bool {
	return s.findRootOriginalVal(a) == s.findRootOriginalVal(b)
}

// Union joins the groups containing a and b.
// Values that have not been seen before are registered first.
func (s *Set[T]) Union(a, b T) {
	rootA := s.findRootOriginalVal(a)
	rootB := s.findRootOriginalVal(b)
	// Take b's group as the successor and keep rootA as the root of the merged set.
	if rootA != rootB {
		s.parent[rootB] = rootA
	}
}

// FindRoot returns the internal root index of the group that contains a.
// The result is an internal offset, not a value of type T; use FindVal to map
// it back. If a is not in the set, a new index is assigned to it.
func (s *Set[T]) FindRoot(a T) int {
	return s.findRootOriginalVal(a)
}

// FindVal returns the original value stored at the root of the group that
// contains internal index idx. The boolean is false when idx is not an index
// that this set has handed out (negative or not yet assigned).
func (s *Set[T]) FindVal(idx int) (T, bool) {
	if idx < 0 || idx >= len(s.parent) {
		var zero T
		return zero, false
	}
	v, ok := s.idx2Val[s.findRootInternal(idx)]
	return v, ok
}

0 commit comments

Comments
 (0)