Skip to content

Commit 2c484f6

Browse files
englefly, Your Name
authored and committed
[opt](nereids) optimize push limit to agg (#44042)
### What problem does this PR solve? PR #34853 introduced the PushTopnToAgg rule, but with a limitation: the TopN (limit) had to output all group-by keys. This PR removes that limitation by using the first group-by key as the order key.
1 parent 78556da commit 2c484f6

File tree

54 files changed

+1679
-1589
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1679
-1589
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import org.apache.doris.nereids.rules.Rule;
2222
import org.apache.doris.nereids.rules.RuleType;
2323
import org.apache.doris.nereids.trees.expressions.Expression;
24+
import org.apache.doris.nereids.trees.expressions.NamedExpression;
25+
import org.apache.doris.nereids.trees.expressions.SlotReference;
2426
import org.apache.doris.nereids.trees.plans.Plan;
2527
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
2628
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
@@ -32,6 +34,7 @@
3234
import com.google.common.collect.Lists;
3335

3436
import java.util.List;
37+
import java.util.Optional;
3538
import java.util.stream.Collectors;
3639

3740
/**
@@ -53,7 +56,11 @@ public List<Rule> buildRules() {
5356
>= limit.getLimit() + limit.getOffset())
5457
.then(limit -> {
5558
LogicalAggregate<? extends Plan> agg = limit.child();
56-
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
59+
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
60+
if (!orderKeysOpt.isPresent()) {
61+
return null;
62+
}
63+
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
5764
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg);
5865
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
5966
//limit->project->agg to topn->project->agg
@@ -62,12 +69,47 @@ public List<Rule> buildRules() {
6269
&& ConnectContext.get().getSessionVariable().pushTopnToAgg
6370
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
6471
>= limit.getLimit() + limit.getOffset())
65-
.when(limit -> outputAllGroupKeys(limit, limit.child().child()))
6672
.then(limit -> {
6773
LogicalProject<? extends Plan> project = limit.child();
68-
LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) project.child();
69-
List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg);
70-
return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), project);
74+
LogicalAggregate<? extends Plan> agg
75+
= (LogicalAggregate<? extends Plan>) project.child();
76+
Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg);
77+
if (!orderKeysOpt.isPresent()) {
78+
return null;
79+
}
80+
List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get());
81+
Plan result;
82+
83+
if (outputAllGroupKeys(limit, agg)) {
84+
result = new LogicalTopN<>(orderKeys, limit.getLimit(),
85+
limit.getOffset(), project);
86+
} else {
87+
// add the first group by key to topn, and prune this key by upper project
88+
// topn order keys are prefix of group by keys
89+
// refer to PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey()
90+
Expression firstGroupByKey = agg.getGroupByExpressions().get(0);
91+
if (!(firstGroupByKey instanceof SlotReference)) {
92+
return null;
93+
}
94+
boolean shouldPruneFirstGroupByKey = true;
95+
if (project.getOutputs().contains(firstGroupByKey)) {
96+
shouldPruneFirstGroupByKey = false;
97+
} else {
98+
List<NamedExpression> bottomProjections = Lists.newArrayList(project.getProjects());
99+
bottomProjections.add((SlotReference) firstGroupByKey);
100+
project = project.withProjects(bottomProjections);
101+
}
102+
LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(),
103+
limit.getOffset(), project);
104+
if (shouldPruneFirstGroupByKey) {
105+
List<NamedExpression> limitOutput = limit.getOutput().stream()
106+
.map(e -> (NamedExpression) e).collect(Collectors.toList());
107+
result = new LogicalProject<>(limitOutput, topn);
108+
} else {
109+
result = topn;
110+
}
111+
}
112+
return result;
71113
}).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
72114
// topn -> agg: add all group key to sort key, if sort key is prefix of group key
73115
logicalTopN(logicalAggregate())
@@ -111,9 +153,10 @@ private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) {
111153
return limit.getOutputSet().containsAll(agg.getGroupByExpressions());
112154
}
113155

114-
private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) {
115-
return agg.getGroupByExpressions().stream()
116-
.map(key -> new OrderKey(key, true, false))
117-
.collect(Collectors.toList());
156+
private Optional<OrderKey> tryGenerateOrderKeyByTheFirstGroupKey(LogicalAggregate<? extends Plan> agg) {
157+
if (agg.getGroupByExpressions().isEmpty()) {
158+
return Optional.empty();
159+
}
160+
return Optional.of(new OrderKey(agg.getGroupByExpressions().get(0), true, false));
118161
}
119162
}

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ public String toString() {
206206
"groupByExpr", groupByExpressions,
207207
"outputExpr", outputExpressions,
208208
"partitionExpr", partitionExpressions,
209-
"requireProperties", requireProperties,
210-
"topnOpt", topnPushInfo != null
209+
"topnFilter", topnPushInfo != null,
210+
"topnPushDown", getMutableState(MutableState.KEY_PUSH_TOPN_TO_AGG).isPresent()
211211
);
212212
}
213213

fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateSortTest.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,7 @@ void testSortLimit() {
165165
PlanChecker.from(connectContext).disableNereidsRules("PRUNE_EMPTY_PARTITION")
166166
.analyze("select count(*) from (select * from student order by id) t limit 1")
167167
.rewrite()
168-
// there is no topn below agg
169-
.matches(logicalTopN(logicalAggregate(logicalProject(logicalOlapScan()))));
168+
.nonMatch(logicalTopN());
170169
PlanChecker.from(connectContext)
171170
.disableNereidsRules("PRUNE_EMPTY_PARTITION")
172171
.analyze("select count(*) from (select * from student order by id limit 1) t")
@@ -184,8 +183,6 @@ void testSortLimit() {
184183
.analyze("select count(*) from "
185184
+ "(select * from student order by id) t1 left join student t2 on t1.id = t2.id limit 1")
186185
.rewrite()
187-
.matches(logicalTopN(logicalAggregate(logicalProject(logicalJoin(
188-
logicalProject(logicalOlapScan()),
189-
logicalProject(logicalOlapScan()))))));
186+
.nonMatch(logicalTopN());
190187
}
191188
}

regression-test/data/nereids_hint_tpcds_p0/shape/query23.out

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,35 +46,36 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
4646
--------------------------------filter(d_year IN (2000, 2001, 2002, 2003))
4747
----------------------------------PhysicalOlapScan[date_dim]
4848
----PhysicalResultSink
49-
------PhysicalTopN[GATHER_SORT]
50-
--------hashAgg[GLOBAL]
51-
----------PhysicalDistribute[DistributionSpecGather]
52-
------------hashAgg[LOCAL]
53-
--------------PhysicalUnion
54-
----------------PhysicalProject
55-
------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF5 cs_item_sk->[item_sk]
56-
--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
57-
--------------------PhysicalProject
58-
----------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF4 c_customer_sk->[cs_bill_customer_sk]
59-
------------------------PhysicalProject
60-
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
61-
----------------------------PhysicalProject
62-
------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3 RF4
63-
----------------------------PhysicalProject
64-
------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
65-
--------------------------------PhysicalOlapScan[date_dim]
66-
------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
67-
----------------PhysicalProject
68-
------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF8 ws_item_sk->[item_sk]
69-
--------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8
70-
--------------------PhysicalProject
71-
----------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[ws_bill_customer_sk]
72-
------------------------PhysicalProject
73-
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF6 d_date_sk->[ws_sold_date_sk]
74-
----------------------------PhysicalProject
75-
------------------------------PhysicalOlapScan[web_sales] apply RFs: RF6 RF7
76-
----------------------------PhysicalProject
77-
------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
78-
--------------------------------PhysicalOlapScan[date_dim]
79-
------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
49+
------PhysicalLimit[GLOBAL]
50+
--------PhysicalLimit[LOCAL]
51+
----------hashAgg[GLOBAL]
52+
------------PhysicalDistribute[DistributionSpecGather]
53+
--------------hashAgg[LOCAL]
54+
----------------PhysicalUnion
55+
------------------PhysicalProject
56+
--------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((catalog_sales.cs_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF5 cs_item_sk->[item_sk]
57+
----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF5
58+
----------------------PhysicalProject
59+
------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF4 c_customer_sk->[cs_bill_customer_sk]
60+
--------------------------PhysicalProject
61+
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
62+
------------------------------PhysicalProject
63+
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3 RF4
64+
------------------------------PhysicalProject
65+
--------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
66+
----------------------------------PhysicalOlapScan[date_dim]
67+
--------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
68+
------------------PhysicalProject
69+
--------------------hashJoin[RIGHT_SEMI_JOIN shuffle] hashCondition=((web_sales.ws_item_sk = frequent_ss_items.item_sk)) otherCondition=() build RFs:RF8 ws_item_sk->[item_sk]
70+
----------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF8
71+
----------------------PhysicalProject
72+
------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_bill_customer_sk = best_ss_customer.c_customer_sk)) otherCondition=() build RFs:RF7 c_customer_sk->[ws_bill_customer_sk]
73+
--------------------------PhysicalProject
74+
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF6 d_date_sk->[ws_sold_date_sk]
75+
------------------------------PhysicalProject
76+
--------------------------------PhysicalOlapScan[web_sales] apply RFs: RF6 RF7
77+
------------------------------PhysicalProject
78+
--------------------------------filter((date_dim.d_moy = 7) and (date_dim.d_year = 2000))
79+
----------------------------------PhysicalOlapScan[date_dim]
80+
--------------------------PhysicalCteConsumer ( cteId=CTEId#2 )
8081

regression-test/data/nereids_hint_tpcds_p0/shape/query32.out

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
-- This file is automatically generated. You should know what you did if you want to edit this
22
-- !ds_shape_32 --
33
PhysicalResultSink
4-
--PhysicalTopN[GATHER_SORT]
5-
----hashAgg[GLOBAL]
6-
------PhysicalDistribute[DistributionSpecGather]
7-
--------hashAgg[LOCAL]
8-
----------PhysicalProject
9-
------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cast(cs_ext_discount_amt as DECIMALV3(9, 4))) OVER(PARTITION BY i_item_sk))))
10-
--------------PhysicalWindow
11-
----------------PhysicalQuickSort[LOCAL_SORT]
12-
------------------PhysicalDistribute[DistributionSpecHash]
13-
--------------------PhysicalProject
14-
----------------------hashJoin[INNER_JOIN broadcast] hashCondition=((date_dim.d_date_sk = catalog_sales.cs_sold_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_sold_date_sk]
15-
------------------------PhysicalProject
16-
--------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((item.i_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF0 i_item_sk->[cs_item_sk]
17-
----------------------------PhysicalProject
18-
------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1
19-
----------------------------PhysicalProject
20-
------------------------------filter((item.i_manufact_id = 722))
21-
--------------------------------PhysicalOlapScan[item]
22-
------------------------PhysicalProject
23-
--------------------------filter((date_dim.d_date <= '2001-06-07') and (date_dim.d_date >= '2001-03-09'))
24-
----------------------------PhysicalOlapScan[date_dim]
4+
--PhysicalLimit[GLOBAL]
5+
----PhysicalLimit[LOCAL]
6+
------hashAgg[GLOBAL]
7+
--------PhysicalDistribute[DistributionSpecGather]
8+
----------hashAgg[LOCAL]
9+
------------PhysicalProject
10+
--------------filter((cast(cs_ext_discount_amt as DECIMALV3(38, 5)) > (1.3 * avg(cast(cs_ext_discount_amt as DECIMALV3(9, 4))) OVER(PARTITION BY i_item_sk))))
11+
----------------PhysicalWindow
12+
------------------PhysicalQuickSort[LOCAL_SORT]
13+
--------------------PhysicalDistribute[DistributionSpecHash]
14+
----------------------PhysicalProject
15+
------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((date_dim.d_date_sk = catalog_sales.cs_sold_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_sold_date_sk]
16+
--------------------------PhysicalProject
17+
----------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((item.i_item_sk = catalog_sales.cs_item_sk)) otherCondition=() build RFs:RF0 i_item_sk->[cs_item_sk]
18+
------------------------------PhysicalProject
19+
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1
20+
------------------------------PhysicalProject
21+
--------------------------------filter((item.i_manufact_id = 722))
22+
----------------------------------PhysicalOlapScan[item]
23+
--------------------------PhysicalProject
24+
----------------------------filter((date_dim.d_date <= '2001-06-07') and (date_dim.d_date >= '2001-03-09'))
25+
------------------------------PhysicalOlapScan[date_dim]
2526

2627
Hint log:
2728
Used: leading(catalog_sales item date_dim )

0 commit comments

Comments
 (0)