Skip to content

Commit d4ef9af

Browse files
winorosAilinKid
andauthored
planner: maintain functional dependency for joins (pingcap#5)
Co-authored-by: ailinkid <[email protected]>
1 parent 716584a commit d4ef9af

File tree

6 files changed

+260
-19
lines changed

6 files changed

+260
-19
lines changed

planner/core/logical_plan_builder.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4269,10 +4269,10 @@ func (ds *DataSource) ExtractFD() *fd.FDSet {
42694269
notnullColsUniqueIDs := extractNotNullFromConds(ds.allConds, ds)
42704270

42714271
// extract the constant cols from selection conditions.
4272-
constUniqueIDs := extractConstantCols(ds.allConds, ds, fds)
4272+
constUniqueIDs := extractConstantCols(ds.allConds, ds.SCtx(), fds)
42734273

42744274
// extract equivalence cols.
4275-
equivUniqueIDs := extractEquivalenceCols(ds.allConds, ds, fds)
4275+
equivUniqueIDs := extractEquivalenceCols(ds.allConds, ds.SCtx(), fds)
42764276

42774277
// apply conditions to FD.
42784278
fds.MakeNotNull(notnullColsUniqueIDs)

planner/core/logical_plans.go

Lines changed: 107 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,99 @@ func (p *LogicalJoin) Shallow() *LogicalJoin {
181181
return join.Init(p.ctx, p.blockOffset)
182182
}
183183

184+
// ExtractFD implements the interface LogicalPlan.
185+
func (p *LogicalJoin) ExtractFD() *fd.FDSet {
186+
switch p.JoinType {
187+
case InnerJoin:
188+
return p.extractFDForInnerJoin()
189+
case LeftOuterJoin, RightOuterJoin:
190+
return p.extractFDForOuterJoin()
191+
case SemiJoin:
192+
return p.extractFDForSemiJoin()
193+
default:
194+
return &fd.FDSet{HashCodeToUniqueID: make(map[string]int)}
195+
}
196+
}
197+
198+
func (p *LogicalJoin) extractFDForSemiJoin() *fd.FDSet {
199+
// 1: since semi join will keep the part or all rows of the outer table, it's outer FD can be saved.
200+
// 2: the un-projected column will be left for the upper layer projection or already be pruned from bottom up.
201+
outerFD, _ := p.children[0].ExtractFD(), p.children[1].ExtractFD()
202+
fds := outerFD
203+
204+
eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions)
205+
allConds := append(eqCondSlice, p.OtherConditions...)
206+
notNullColsFromFilters := extractNotNullFromConds(allConds, p)
207+
208+
constUniqueIDs := extractConstantCols(p.LeftConditions, p.SCtx(), fds)
209+
210+
fds.MakeNotNull(notNullColsFromFilters)
211+
fds.AddConstants(constUniqueIDs)
212+
p.fdSet = fds
213+
return fds
214+
}
215+
216+
func (p *LogicalJoin) extractFDForInnerJoin() *fd.FDSet {
217+
leftFD, rightFD := p.children[0].ExtractFD(), p.children[1].ExtractFD()
218+
fds := leftFD
219+
fds.MakeCartesianProduct(rightFD)
220+
221+
eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions)
222+
allConds := append(eqCondSlice, p.OtherConditions...)
223+
notNullColsFromFilters := extractNotNullFromConds(allConds, p)
224+
225+
constUniqueIDs := extractConstantCols(eqCondSlice, p.SCtx(), fds)
226+
227+
equivUniqueIDs := extractEquivalenceCols(eqCondSlice, p.SCtx(), fds)
228+
229+
fds.MakeNotNull(notNullColsFromFilters)
230+
fds.AddConstants(constUniqueIDs)
231+
for _, equiv := range equivUniqueIDs {
232+
fds.AddEquivalence(equiv[0], equiv[1])
233+
}
234+
p.fdSet = fds
235+
return fds
236+
}
237+
238+
func (p *LogicalJoin) extractFDForOuterJoin() *fd.FDSet {
239+
outerFD, innerFD := p.children[0].ExtractFD(), p.children[1].ExtractFD()
240+
innerCondition := p.RightConditions
241+
outerCols, innerCols := fd.NewFastIntSet(), fd.NewFastIntSet()
242+
for _, col := range p.children[0].Schema().Columns {
243+
outerCols.Insert(int(col.UniqueID))
244+
}
245+
for _, col := range p.children[1].Schema().Columns {
246+
innerCols.Insert(int(col.UniqueID))
247+
}
248+
if p.JoinType == RightOuterJoin {
249+
innerFD, outerFD = outerFD, innerFD
250+
innerCondition = p.LeftConditions
251+
innerCols, outerCols = outerCols, innerCols
252+
}
253+
254+
eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions)
255+
allConds := append(eqCondSlice, p.OtherConditions...)
256+
allConds = append(allConds, innerCondition...)
257+
notNullColsFromFilters := extractNotNullFromConds(allConds, p)
258+
259+
filterFD := &fd.FDSet{HashCodeToUniqueID: make(map[string]int)}
260+
261+
constUniqueIDs := extractConstantCols(eqCondSlice, p.SCtx(), filterFD)
262+
263+
equivUniqueIDs := extractEquivalenceCols(eqCondSlice, p.SCtx(), filterFD)
264+
265+
filterFD.AddConstants(constUniqueIDs)
266+
for _, equiv := range equivUniqueIDs {
267+
filterFD.AddEquivalence(equiv[0], equiv[1])
268+
}
269+
filterFD.MakeNotNull(notNullColsFromFilters)
270+
271+
fds := outerFD
272+
fds.MakeOuterJoin(innerFD, filterFD, outerCols, innerCols)
273+
p.fdSet = fds
274+
return fds
275+
}
276+
184277
// GetJoinKeys extracts join keys(columns) from EqualConditions. It returns left join keys, right
185278
// join keys and an `isNullEQ` array which means the `joinKey[i]` is a `NullEQ` function. The `hasNullEQ`
186279
// means whether there is a `NullEQ` of a join key.
@@ -657,34 +750,34 @@ func extractNotNullFromConds(Conditions []expression.Expression, p LogicalPlan)
657750
return notnullColsUniqueIDs
658751
}
659752

660-
func extractConstantCols(Conditions []expression.Expression, p LogicalPlan, fds *fd.FDSet) fd.FastIntSet {
753+
func extractConstantCols(Conditions []expression.Expression, sctx sessionctx.Context, fds *fd.FDSet) fd.FastIntSet {
661754
// extract constant cols
662755
// eg: where a=1 and b is null and (1+c)=5.
663756
// TODO: Some columns can only be determined to be constant from multiple constraints (e.g. x <= 1 AND x >= 1)
664757
var (
665758
constObjs []expression.Expression
666759
constUniqueIDs = fd.NewFastIntSet()
667760
)
668-
constObjs = expression.ExtractConstantEqColumnsOrScalar(p.SCtx(), constObjs, Conditions)
761+
constObjs = expression.ExtractConstantEqColumnsOrScalar(sctx, constObjs, Conditions)
669762
for _, constObj := range constObjs {
670763
switch x := constObj.(type) {
671764
case *expression.Column:
672765
constUniqueIDs.Insert(int(x.UniqueID))
673766
case *expression.ScalarFunction:
674-
hashCode := string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx))
767+
hashCode := string(x.HashCode(sctx.GetSessionVars().StmtCtx))
675768
if uniqueID, ok := fds.IsHashCodeRegistered(hashCode); ok {
676769
constUniqueIDs.Insert(uniqueID)
677770
} else {
678-
scalarUniqueID := int(p.SCtx().GetSessionVars().AllocPlanColumnID())
679-
fds.RegisterUniqueID(string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx)), scalarUniqueID)
771+
scalarUniqueID := int(sctx.GetSessionVars().AllocPlanColumnID())
772+
fds.RegisterUniqueID(string(x.HashCode(sctx.GetSessionVars().StmtCtx)), scalarUniqueID)
680773
constUniqueIDs.Insert(scalarUniqueID)
681774
}
682775
}
683776
}
684777
return constUniqueIDs
685778
}
686779

687-
func extractEquivalenceCols(Conditions []expression.Expression, p LogicalPlan, fds *fd.FDSet) [][]fd.FastIntSet {
780+
func extractEquivalenceCols(Conditions []expression.Expression, sctx sessionctx.Context, fds *fd.FDSet) [][]fd.FastIntSet {
688781
var equivObjsPair [][]expression.Expression
689782
equivObjsPair = expression.ExtractEquivalenceColumns(equivObjsPair, Conditions)
690783
equivUniqueIDs := make([][]fd.FastIntSet, 0, len(equivObjsPair))
@@ -698,12 +791,12 @@ func extractEquivalenceCols(Conditions []expression.Expression, p LogicalPlan, f
698791
case *expression.Column:
699792
lhsUniqueID = int(x.UniqueID)
700793
case *expression.ScalarFunction:
701-
hashCode := string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx))
794+
hashCode := string(x.HashCode(sctx.GetSessionVars().StmtCtx))
702795
if uniqueID, ok := fds.IsHashCodeRegistered(hashCode); ok {
703796
lhsUniqueID = uniqueID
704797
} else {
705-
scalarUniqueID := int(p.SCtx().GetSessionVars().AllocPlanColumnID())
706-
fds.RegisterUniqueID(string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx)), scalarUniqueID)
798+
scalarUniqueID := int(sctx.GetSessionVars().AllocPlanColumnID())
799+
fds.RegisterUniqueID(string(x.HashCode(sctx.GetSessionVars().StmtCtx)), scalarUniqueID)
707800
lhsUniqueID = scalarUniqueID
708801
}
709802
}
@@ -712,12 +805,12 @@ func extractEquivalenceCols(Conditions []expression.Expression, p LogicalPlan, f
712805
case *expression.Column:
713806
rhsUniqueID = int(x.UniqueID)
714807
case *expression.ScalarFunction:
715-
hashCode := string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx))
808+
hashCode := string(x.HashCode(sctx.GetSessionVars().StmtCtx))
716809
if uniqueID, ok := fds.IsHashCodeRegistered(hashCode); ok {
717810
rhsUniqueID = uniqueID
718811
} else {
719-
scalarUniqueID := int(p.SCtx().GetSessionVars().AllocPlanColumnID())
720-
fds.RegisterUniqueID(string(x.HashCode(p.SCtx().GetSessionVars().StmtCtx)), scalarUniqueID)
812+
scalarUniqueID := int(sctx.GetSessionVars().AllocPlanColumnID())
813+
fds.RegisterUniqueID(string(x.HashCode(sctx.GetSessionVars().StmtCtx)), scalarUniqueID)
721814
rhsUniqueID = scalarUniqueID
722815
}
723816
}
@@ -743,10 +836,10 @@ func (p *LogicalSelection) ExtractFD() *fd.FDSet {
743836
notnullColsUniqueIDs.UnionWith(extractNotNullFromConds(p.Conditions, p))
744837

745838
// extract the constant cols from selection conditions.
746-
constUniqueIDs := extractConstantCols(p.Conditions, p, fds)
839+
constUniqueIDs := extractConstantCols(p.Conditions, p.SCtx(), fds)
747840

748841
// extract equivalence cols.
749-
equivUniqueIDs := extractEquivalenceCols(p.Conditions, p, fds)
842+
equivUniqueIDs := extractEquivalenceCols(p.Conditions, p.SCtx(), fds)
750843

751844
// apply operator's characteristic's FD setting.
752845
fds.MakeNotNull(notnullColsUniqueIDs)

planner/core/stringer.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ func fdToString(in LogicalPlan, strs []string, idxs []int) ([]string, []int) {
6767
}
6868
case *DataSource:
6969
strs = append(strs, "{"+x.fdSet.String()+"}")
70+
case *LogicalApply:
71+
strs = append(strs, "{"+x.fdSet.String()+"}")
72+
case *LogicalJoin:
73+
strs = append(strs, "{"+x.fdSet.String()+"}")
7074
default:
7175
}
7276
return strs, idxs

planner/functional_dependency/extract_fd_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,102 @@ func TestFDSet_ExtractFD(t *testing.T) {
218218
ass.Equal(tt.fd, plannercore.FDToString(p.(plannercore.LogicalPlan)), comment)
219219
}
220220
}
221+
222+
func TestFDSet_ExtractFDForApply(t *testing.T) {
223+
t.Parallel()
224+
ass := assert.New(t)
225+
226+
store, clean := testkit.CreateMockStore(t)
227+
defer clean()
228+
par := parser.New()
229+
par.SetParserConfig(parser.ParserConfig{EnableWindowFunction: true, EnableStrictDoubleTypeCheck: true})
230+
231+
tk := testkit.NewTestKit(t, store)
232+
tk.MustExec("use test")
233+
tk.MustExec("CREATE TABLE X (a INT PRIMARY KEY, b INT, c INT, d INT, e INT)")
234+
tk.MustExec("CREATE UNIQUE INDEX uni ON X (b, c)")
235+
tk.MustExec("CREATE TABLE Y (m INT, n INT, p INT, q INT, PRIMARY KEY (m, n))")
236+
237+
tests := []struct {
238+
sql string
239+
best string
240+
fd string
241+
}{
242+
{
243+
sql: "select * from X where exists (select * from Y where m=a limit 1)",
244+
// For this Apply, it's essentially a semi join, for every `a` in X, do the inner loop once.
245+
// +- datasource(x)
246+
// +- limit
247+
// +- datasource(Y)
248+
best: "Apply{DataScan(X)->DataScan(Y)->Limit}->Projection",
249+
// Since semi join will keep the **all** rows of the outer table, it's FD can be derived.
250+
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
251+
},
252+
{
253+
sql: "select a, b from X where exists (select * from Y where m=a limit 1)",
254+
// For this Apply, it's essentially a semi join, for every `a` in X, do the inner loop once.
255+
// +- datasource(x)
256+
// +- limit
257+
// +- datasource(Y)
258+
best: "Apply{DataScan(X)->DataScan(Y)->Limit}->Projection", // semi join
259+
// Since semi join will keep the **part** rows of the outer table, it's FD can be derived.
260+
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2)}",
261+
},
262+
{
263+
// Limit will refuse to de-correlate apply to join while this won't.
264+
sql: "select * from X where exists (select * from Y where m=a and p=1)",
265+
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.m)->Projection", // semi join
266+
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
267+
},
268+
{
269+
sql: "select * from X where exists (select * from Y where m=a and p=q)",
270+
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.m)->Projection", // semi join
271+
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
272+
},
273+
{
274+
sql: "select * from X where exists (select * from Y where m=a and b=1)",
275+
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.m)->Projection", // semi join
276+
// b=1 is semi join's left condition which should be conserved.
277+
fd: "{(1)-->(3-5), (2,3)~~>(1,4,5), ()-->(2)} >>> {(1)-->(3-5), (2,3)~~>(1,4,5), ()-->(2)}",
278+
},
279+
{
280+
sql: "select * from (select b,c,d,e from X) X1 where exists (select * from Y where p=q and n=1) ",
281+
best: "Dual->Projection",
282+
fd: "{}",
283+
},
284+
{
285+
sql: "select * from (select b,c,d,e from X) X1 where exists (select * from Y where p=b and n=1) ",
286+
best: "Join{DataScan(X)->DataScan(Y)}(test.x.b,test.y.p)->Projection", // semi join
287+
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(2,3)~~>(4,5)}",
288+
},
289+
{
290+
sql: "select * from X where exists (select m, p, q from Y where n=a and p=1)",
291+
best: "Join{DataScan(X)->DataScan(Y)}(test.x.a,test.y.n)->Projection",
292+
// p=1 is semi join's right condition which should **NOT** be conserved.
293+
fd: "{(1)-->(2-5), (2,3)~~>(1,4,5)} >>> {(1)-->(2-5), (2,3)~~>(1,4,5)}",
294+
},
295+
}
296+
297+
ctx := context.TODO()
298+
is := testGetIS(ass, tk.Session())
299+
for i, tt := range tests {
300+
comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql)
301+
stmt, err := par.ParseOneStmt(tt.sql, "", "")
302+
ass.Nil(err, comment)
303+
tk.Session().GetSessionVars().PlanID = 0
304+
tk.Session().GetSessionVars().PlanColumnID = 0
305+
err = plannercore.Preprocess(tk.Session(), stmt, plannercore.WithPreprocessorReturn(&plannercore.PreprocessorReturn{InfoSchema: is}))
306+
ass.Nil(err)
307+
tk.Session().PrepareTSFuture(ctx)
308+
builder, _ := plannercore.NewPlanBuilder().Init(tk.Session(), is, &hint.BlockHintProcessor{})
309+
// extract FD to every OP
310+
p, err := builder.Build(ctx, stmt)
311+
ass.Nil(err)
312+
p, err = plannercore.LogicalOptimizeTest(ctx, builder.GetOptFlag(), p.(plannercore.LogicalPlan))
313+
ass.Nil(err)
314+
ass.Equal(tt.best, plannercore.ToString(p), comment)
315+
// extract FD to every OP
316+
p.(plannercore.LogicalPlan).ExtractFD()
317+
ass.Equal(tt.fd, plannercore.FDToString(p.(plannercore.LogicalPlan)), comment)
318+
}
319+
}

planner/functional_dependency/fd_graph.go

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,49 @@ func (s *FDSet) MakeCartesianProduct(rhs *FDSet) {
500500
}
501501
}
502502

503-
func (s *FDSet) MakeLeftOuter(lhs, filterFDs *FDSet, lCols, rCols, notNullCols FastIntSet) {
504-
// TODO:
503+
// MakeApply maintain the FD relationship between outer and inner table after Apply OP is done.
504+
// Since Apply is implemented by join, it seems the fd can be extracted through its inner join directly.
505+
func (s *FDSet) MakeApply(inner *FDSet) {
506+
}
507+
508+
// MakeOuterJoin generates the records the fdSet of the outer join.
509+
// As we know, the outer join would generate null extended rows compared with inner join.
510+
// So we cannot directly do the same thing with the inner join. This function deals with the special cases of the outer join.
511+
func (s *FDSet) MakeOuterJoin(innerFDs, filterFDs *FDSet, outerCols, innerCols FastIntSet) {
512+
for _, edge := range innerFDs.fdEdges {
513+
// We don't maintain the equiv edges and lax edges currently.
514+
if edge.equiv || !edge.strict {
515+
continue
516+
}
517+
// If the one of the column from the inner child's functional dependency's left side is not null, this FD
518+
// can be remained.
519+
// This is because that the outer join would generate null-extended rows. So if at least one row from the left side
520+
// is not null. We can guarantee that the there's no same part between the original rows and the generated rows.
521+
// So the null extended rows would not break the original functional dependency.
522+
if edge.from.SubsetOf(innerFDs.NotNullCols) {
523+
s.addFunctionalDependency(edge.from, edge.to, edge.strict, edge.equiv)
524+
} else if edge.from.SubsetOf(filterFDs.NotNullCols) {
525+
// If we can make sure the filters of the join would filter out all nulls of this FD's left side
526+
// and this FD is from the join's inner child. This FD can be remained.
527+
// This is because the outer join filters out the null values. The generated null-extended rows would not
528+
// find the same row from the original rows of the inner child. So it won't break the original functional dependency.
529+
s.addFunctionalDependency(edge.from, edge.to, edge.strict, edge.equiv)
530+
}
531+
}
532+
for _, edge := range filterFDs.fdEdges {
533+
// We don't maintain the equiv edges and the lax edges currently.
534+
if edge.equiv || !edge.strict {
535+
continue
536+
}
537+
if edge.from.SubsetOf(innerCols) && edge.to.SubsetOf(innerCols) && edge.from.SubsetOf(filterFDs.NotNullCols) {
538+
// The functional dependency generated from the join filter would be reserved if it meets the following conditions:
539+
// 1. All columns from this functional dependency are the columns from the inner side.
540+
// 2. The join keys can filter out the null values from the left side of the FD.
541+
// This is the same with the above cases. If the join filters can filter out the null values of the FD's left side,
542+
// We won't find a same row between the original rows of the inner side and the generated null-extended rows.
543+
s.addFunctionalDependency(edge.from, edge.to, edge.strict, edge.equiv)
544+
}
545+
}
505546
}
506547

507548
func (s FDSet) AllCols() FastIntSet {

planner/functional_dependency/fd_graph_ported_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func TestFuncDeps_ColsAreKey(t *testing.T) {
6565
loj = *abcde
6666
loj.MakeCartesianProduct(mnpq)
6767
loj.AddConstants(NewFastIntSet(3))
68-
loj.MakeLeftOuter(abcde, &FDSet{}, preservedCols, nullExtendedCols, NewFastIntSet(1, 10, 11))
68+
loj.MakeOuterJoin(nil, &FDSet{}, preservedCols, nullExtendedCols)
6969
loj.AddEquivalence(NewFastIntSet(1), NewFastIntSet(10))
7070

7171
testcases := []struct {
@@ -330,3 +330,7 @@ func makeJoinFD(ass *assert.Assertions) *FDSet {
330330
testColsAreLaxKey(ass, join, NewFastIntSet(2, 3, 11), join.AllCols(), false)
331331
return join
332332
}
333+
334+
func TestFuncDeps_OuterJoin(t *testing.T) {
335+
336+
}

0 commit comments

Comments
 (0)