Skip to content

Commit 5ebfed9

Browse files
authored
[opt](Nereids) use 1 as narrowest column when do column pruning on union (#41719)
just like previous PR #41548 this PR process union node to ensure not require any column from its children when it is required by its parent with empty slot set
1 parent 1813ec2 commit 5ebfed9

File tree

4 files changed

+57
-21
lines changed

4 files changed

+57
-21
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.doris.nereids.trees.expressions.Slot;
2727
import org.apache.doris.nereids.trees.expressions.SlotReference;
2828
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
29+
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
2930
import org.apache.doris.nereids.trees.plans.Plan;
3031
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
3132
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
@@ -42,6 +43,7 @@
4243
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
4344
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
4445
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
46+
import org.apache.doris.nereids.types.TinyIntType;
4547
import org.apache.doris.nereids.util.ExpressionUtils;
4648
import org.apache.doris.nereids.util.Utils;
4749
import org.apache.doris.qe.ConnectContext;
@@ -345,6 +347,8 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context)
345347
}
346348
List<NamedExpression> prunedOutputs = Lists.newArrayList();
347349
List<List<NamedExpression>> constantExprsList = union.getConstantExprsList();
350+
List<List<SlotReference>> regularChildrenOutputs = union.getRegularChildrenOutputs();
351+
List<Plan> children = union.children();
348352
List<Integer> extractColumnIndex = Lists.newArrayList();
349353
for (int i = 0; i < originOutput.size(); i++) {
350354
NamedExpression output = originOutput.get(i);
@@ -353,31 +357,41 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context)
353357
extractColumnIndex.add(i);
354358
}
355359
}
356-
if (prunedOutputs.isEmpty()) {
357-
List<NamedExpression> candidates = Lists.newArrayList(originOutput);
358-
candidates.retainAll(keys);
359-
if (candidates.isEmpty()) {
360-
candidates = originOutput;
361-
}
362-
NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates);
363-
prunedOutputs = ImmutableList.of(minimumColumn);
364-
extractColumnIndex.add(originOutput.indexOf(minimumColumn));
365-
}
366360

367-
int len = extractColumnIndex.size();
368361
ImmutableList.Builder<List<NamedExpression>> prunedConstantExprsList
369362
= ImmutableList.builderWithExpectedSize(constantExprsList.size());
370-
for (List<NamedExpression> row : constantExprsList) {
371-
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
372-
for (int idx : extractColumnIndex) {
373-
newRow.add(row.get(idx));
363+
if (prunedOutputs.isEmpty()) {
364+
// process prune all columns
365+
NamedExpression originSlot = originOutput.get(0);
366+
prunedOutputs = ImmutableList.of(new SlotReference(originSlot.getExprId(), originSlot.getName(),
367+
TinyIntType.INSTANCE, false, originSlot.getQualifier()));
368+
regularChildrenOutputs = Lists.newArrayListWithCapacity(regularChildrenOutputs.size());
369+
children = Lists.newArrayListWithCapacity(children.size());
370+
for (int i = 0; i < union.getArity(); i++) {
371+
LogicalProject<?> project = new LogicalProject<>(
372+
ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))), union.child(i));
373+
regularChildrenOutputs.add((List) project.getOutput());
374+
children.add(project);
375+
}
376+
for (int i = 0; i < constantExprsList.size(); i++) {
377+
prunedConstantExprsList.add(ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))));
378+
}
379+
} else {
380+
int len = extractColumnIndex.size();
381+
for (List<NamedExpression> row : constantExprsList) {
382+
ImmutableList.Builder<NamedExpression> newRow = ImmutableList.builderWithExpectedSize(len);
383+
for (int idx : extractColumnIndex) {
384+
newRow.add(row.get(idx));
385+
}
386+
prunedConstantExprsList.add(newRow.build());
374387
}
375-
prunedConstantExprsList.add(newRow.build());
376388
}
377-
if (prunedOutputs.equals(originOutput)) {
389+
390+
if (prunedOutputs.equals(originOutput) && !context.requiredSlots.isEmpty()) {
378391
return union;
379392
} else {
380-
return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList.build());
393+
return union.withNewOutputsChildrenAndConstExprsList(prunedOutputs, children,
394+
regularChildrenOutputs, prunedConstantExprsList.build());
381395
}
382396
}
383397

fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ColumnPruningTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.doris.nereids.trees.expressions.NamedExpression;
2121
import org.apache.doris.nereids.trees.expressions.SlotReference;
22+
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
2223
import org.apache.doris.nereids.trees.plans.Plan;
2324
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
2425
import org.apache.doris.nereids.types.DoubleType;
@@ -313,6 +314,21 @@ public void pruneAggregateOutput() {
313314
);
314315
}
315316

317+
@Test
318+
public void pruneUnionAllWithCount() {
319+
PlanChecker.from(connectContext)
320+
.analyze("select count() from (select 1, 2 union all select id, age from student) t")
321+
.customRewrite(new ColumnPruning())
322+
.matches(
323+
logicalProject(
324+
logicalUnion(
325+
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral),
326+
logicalProject().when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
327+
)
328+
).when(p -> p.getProjects().size() == 1 && p.getProjects().get(0).child(0) instanceof TinyIntLiteral)
329+
);
330+
}
331+
316332
private List<String> getOutputQualifiedNames(LogicalProject<? extends Plan> p) {
317333
return getOutputQualifiedNames(p.getOutputs());
318334
}

regression-test/suites/nereids_rules_p0/column_pruning/union_const_expr_column_pruning.groovy

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
suite("const_expr_column_pruning") {
1919
sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"""
2020
// should only keep one column in union
21-
sql "select count(1) from(select 3, 6 union all select 1, 3) t"
22-
sql "select count(a) from(select 3 a, 6 union all select 1, 3) t"
23-
}
21+
sql """select count(1) from(select 3, 6 union all select 1, 3) t"""
22+
sql """select count(1) from(select 3, 6 union all select "1", 3) t"""
23+
sql """select count(a) from(select 3 a, 6 union all select "1", 3) t"""
24+
}

regression-test/suites/nereids_rules_p0/column_pruning/window_column_pruning.groovy

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,10 @@ suite("window_column_pruning") {
5656
sql "select id from (select id, rank() over() px from window_column_pruning union all select id, rank() over() px from window_column_pruning) a"
5757
notContains "rank"
5858
}
59+
60+
explain {
61+
sql "select count() from (select row_number() over(partition by id) from window_column_pruning) tmp"
62+
notContains "row_number"
63+
}
5964
}
6065

0 commit comments

Comments
 (0)