Skip to content

Commit e2eaeb6

Browse files
seawindekeanji-x
andauthored
[fix](mtmv)Fix high level materialized view not hit because group by eliminate fail (#36888)
this depends on #36839 #36886 Such as low level materialized view contains 5 group by dimension, and query also has 5 group by dimension, they are equals.In this scene, would not add aggregate on mv when try to rewrite query by materialized view. But if query only use 4 group by dimension and the remain demension is can be eliminated, then the query will change to 4 group by dimension. this will cause add aggregate on mv and will cause high level materialize rewrite fail later. Solution: in aggregate rewrite by materialized view, we try to eliminate mv group by dimension by query used dimension. if eliminate successfully. then high level will rewrite continue. such as low level mv def sql is as following: def join_mv_1 = """ select l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, cast(sum(IFNULL(o_orderkey, 0) * IFNULL(o_custkey, 0)) as decimal(28, 8)) as agg1, sum(o_totalprice) as sum_total, max(o_totalprice) as max_total, min(o_totalprice) as min_total, count(*) as count_all, bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) cnt_1, bitmap_union(to_bitmap(case when o_shippriority > 2 and o_orderkey IN (2) then o_custkey else null end)) as cnt_2 from lineitem_1 inner join orders_1 on lineitem_1.l_orderkey = orders_1.o_orderkey where lineitem_1.l_shipdate >= "2023-10-17" group by l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey """ def join_mv_2 = """ select l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, ps_partkey, ps_suppkey, t.agg1 as agg1, t.sum_total as agg3, t.max_total as agg4, t.min_total as agg5, t.count_all as agg6, cast(sum(IFNULL(ps_suppkey, 0) * IFNULL(ps_partkey, 0)) as decimal(28, 8)) as agg2 from ${mv_1} as t inner join partsupp_1 on t.l_partkey = partsupp_1.ps_partkey and t.l_suppkey = partsupp_1.ps_suppkey where partsupp_1.ps_suppkey > 1 group by l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, ps_partkey, ps_suppkey, agg1, agg3, agg4, agg5, agg6 """ high level mv def sql is as following: def join_mv_3 = """ select t1.l_orderkey, t2.l_partkey, t1.l_suppkey, t2.o_orderkey, t1.o_custkey, t2.ps_partkey, t1.ps_suppkey, t2.agg1, >t1.agg2, t2.agg3, t1.agg4, t2.agg5, t1.agg6 from ${mv_2} as t1 left join ${mv_2} as t2 on t1.l_orderkey = t2.l_orderkey where t1.l_orderkey > 1 group by t1.l_orderkey, t2.l_partkey, t1.l_suppkey, t2.o_orderkey, t1.o_custkey, t2.ps_partkey, t1.ps_suppkey, >t2.agg1, >t1.agg2, t2.agg3, t1.agg4, t2.agg5, t1.agg6 """ if we run the query as following, it can hit the mv3 select t1.l_orderkey, t2.l_partkey, t1.l_suppkey, t2.o_orderkey, t1.o_custkey, t2.ps_partkey, t1.ps_suppkey, t2.agg1, >t1.agg2, >t2.agg3, t1.agg4, t2.agg5, t1.agg6 from ( select l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, ps_partkey, ps_suppkey, t.agg1 as agg1, t.sum_total as agg3, t.max_total as agg4, t.min_total as agg5, t.count_all as agg6, cast(sum(IFNULL(ps_suppkey, 0) * IFNULL(ps_partkey, 0)) as decimal(28, 8)) as agg2 from ( select l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, cast(sum(IFNULL(o_orderkey, 0) * >IFNULL(o_custkey, 0)) as decimal(28, 8)) as agg1, sum(o_totalprice) as sum_total, max(o_totalprice) as max_total, min(o_totalprice) as min_total, count(*) as count_all, bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) >cnt_1, bitmap_union(to_bitmap(case when o_shippriority > 2 and o_orderkey IN (2) then o_custkey else null end)) as >cnt_2 from lineitem_1 inner join orders_1 on lineitem_1.l_orderkey = orders_1.o_orderkey where lineitem_1.l_shipdate >= "2023-10-17" group by l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey ) as t inner join partsupp_1 on t.l_partkey = partsupp_1.ps_partkey and t.l_suppkey = partsupp_1.ps_suppkey where partsupp_1.ps_suppkey > 1 group by l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, ps_partkey, ps_suppkey, agg1, agg3, agg4, >agg5, >agg6 ) as t1 left join ( select l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, ps_partkey, ps_suppkey, t.agg1 as agg1, t.sum_total as agg3, t.max_total as agg4, t.min_total as agg5, t.count_all as agg6, cast(sum(IFNULL(ps_suppkey, 0) * IFNULL(ps_partkey, 0)) as decimal(28, 8)) as agg2 from ( select l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, cast(sum(IFNULL(o_orderkey, 0) * >IFNULL(o_custkey, 0)) as decimal(28, 8)) as agg1, sum(o_totalprice) as sum_total, max(o_totalprice) as max_total, min(o_totalprice) as min_total, count(*) as count_all, bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) >cnt_1, bitmap_union(to_bitmap(case when o_shippriority > 2 and o_orderkey IN (2) then o_custkey else null end)) as >cnt_2 from lineitem_1 inner join orders_1 on lineitem_1.l_orderkey = orders_1.o_orderkey where lineitem_1.l_shipdate >= "2023-10-17" group by l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey ) as t inner join partsupp_1 on t.l_partkey = partsupp_1.ps_partkey and t.l_suppkey = partsupp_1.ps_suppkey where partsupp_1.ps_suppkey > 1 group by l_orderkey, l_partkey, l_suppkey, o_orderkey, o_custkey, ps_partkey, ps_suppkey, agg1, agg3, agg4, agg5, >agg6 ) as t2 on t1.l_orderkey = t2.l_orderkey where t1.l_orderkey > 1 group by t1.l_orderkey, t2.l_partkey, t1.l_suppkey, t2.o_orderkey, t1.o_custkey, t2.ps_partkey, t1.ps_suppkey, >t2.agg1, >t1.agg2, t2.agg3, t1.agg4, t2.agg5, t1.agg6 --------- Co-authored-by: xiejiann <[email protected]>
1 parent 39b358c commit e2eaeb6

File tree

9 files changed

+94
-25
lines changed

9 files changed

+94
-25
lines changed

fe/fe-core/src/main/java/org/apache/doris/catalog/MTMV.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,10 @@ public MTMVCache getOrGenerateCache(ConnectContext connectionContext) throws Ana
291291
return cache;
292292
}
293293

294+
public MTMVCache getCache() {
295+
return cache;
296+
}
297+
294298
public Map<String, String> getMvProperties() {
295299
readMvLock();
296300
try {

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewScanRule.java renamed to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewScanRule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
/**
3535
* This is responsible for single table rewriting according to different pattern
3636
* */
37-
public abstract class MaterializedViewScanRule extends AbstractMaterializedViewRule {
37+
public abstract class AbstractMaterializedViewScanRule extends AbstractMaterializedViewRule {
3838

3939
@Override
4040
protected Plan rewriteQueryByView(MatchMode matchMode,

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewFilterProjectScanRule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
/**
3131
* MaterializedViewFilterProjectScanRule
3232
*/
33-
public class MaterializedViewFilterProjectScanRule extends MaterializedViewScanRule {
33+
public class MaterializedViewFilterProjectScanRule extends AbstractMaterializedViewScanRule {
3434

3535
public static final MaterializedViewFilterProjectScanRule INSTANCE = new MaterializedViewFilterProjectScanRule();
3636

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewFilterScanRule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
/**
3030
* MaterializedViewFilterScanRule
3131
*/
32-
public class MaterializedViewFilterScanRule extends MaterializedViewScanRule {
32+
public class MaterializedViewFilterScanRule extends AbstractMaterializedViewScanRule {
3333

3434
public static final MaterializedViewFilterScanRule INSTANCE = new MaterializedViewFilterScanRule();
3535

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewOnlyScanRule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
/**
2828
* MaterializedViewOnlyScanRule
2929
*/
30-
public class MaterializedViewOnlyScanRule extends MaterializedViewScanRule {
30+
public class MaterializedViewOnlyScanRule extends AbstractMaterializedViewScanRule {
3131

3232
public static final MaterializedViewOnlyScanRule INSTANCE = new MaterializedViewOnlyScanRule();
3333

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewProjectFilterScanRule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
/**
3131
* MaterializedViewProjectFilterScanRule
3232
*/
33-
public class MaterializedViewProjectFilterScanRule extends MaterializedViewScanRule {
33+
public class MaterializedViewProjectFilterScanRule extends AbstractMaterializedViewScanRule {
3434

3535
public static final MaterializedViewProjectFilterScanRule INSTANCE = new MaterializedViewProjectFilterScanRule();
3636

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewProjectScanRule.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
/**
3030
* MaterializedViewProjectScanRule
3131
*/
32-
public class MaterializedViewProjectScanRule extends MaterializedViewScanRule {
32+
public class MaterializedViewProjectScanRule extends AbstractMaterializedViewScanRule {
3333

3434
public static final MaterializedViewProjectScanRule INSTANCE = new MaterializedViewProjectScanRule();
3535

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
package org.apache.doris.nereids.trees.plans.logical;
1919

2020
import org.apache.doris.catalog.Column;
21+
import org.apache.doris.catalog.MTMV;
2122
import org.apache.doris.catalog.OlapTable;
2223
import org.apache.doris.catalog.Table;
24+
import org.apache.doris.mtmv.MTMVCache;
2325
import org.apache.doris.nereids.memo.GroupExpression;
2426
import org.apache.doris.nereids.properties.DataTrait;
2527
import org.apache.doris.nereids.properties.LogicalProperties;
@@ -47,6 +49,7 @@
4749
import org.json.JSONObject;
4850

4951
import java.util.Arrays;
52+
import java.util.HashMap;
5053
import java.util.List;
5154
import java.util.Map;
5255
import java.util.Objects;
@@ -495,7 +498,16 @@ AGGREGATE KEY (siteid,citycode,username)
495498
return;
496499
}
497500
Set<Slot> outputSet = Utils.fastToImmutableSet(getOutputSet());
498-
if (getTable().getKeysType().isAggregationFamily()) {
501+
if (getTable() instanceof MTMV) {
502+
MTMV mtmv = (MTMV) getTable();
503+
MTMVCache cache = mtmv.getCache();
504+
if (cache == null) {
505+
return;
506+
}
507+
Plan originalPlan = cache.getOriginalPlan();
508+
builder.addUniqueSlot(originalPlan.getLogicalProperties().getTrait());
509+
builder.replaceUniqueBy(constructReplaceMap(mtmv));
510+
} else if (getTable().getKeysType().isAggregationFamily()) {
499511
ImmutableSet.Builder<Slot> uniqSlots = ImmutableSet.builderWithExpectedSize(outputSet.size());
500512
for (Slot slot : outputSet) {
501513
if (!(slot instanceof SlotReference)) {
@@ -509,4 +521,58 @@ AGGREGATE KEY (siteid,citycode,username)
509521
builder.addUniqueSlot(uniqSlots.build());
510522
}
511523
}
524+
525+
@Override
526+
public void computeUniform(DataTrait.Builder builder) {
527+
if (getTable() instanceof MTMV) {
528+
MTMV mtmv = (MTMV) getTable();
529+
MTMVCache cache = mtmv.getCache();
530+
if (cache == null) {
531+
return;
532+
}
533+
Plan originalPlan = cache.getOriginalPlan();
534+
builder.addUniformSlot(originalPlan.getLogicalProperties().getTrait());
535+
builder.replaceUniformBy(constructReplaceMap(mtmv));
536+
}
537+
}
538+
539+
@Override
540+
public void computeEqualSet(DataTrait.Builder builder) {
541+
if (getTable() instanceof MTMV && getTable().getName().equals("mv1")) {
542+
System.out.println();
543+
}
544+
if (getTable() instanceof MTMV) {
545+
MTMV mtmv = (MTMV) getTable();
546+
MTMVCache cache = mtmv.getCache();
547+
if (cache == null) {
548+
return;
549+
}
550+
Plan originalPlan = cache.getOriginalPlan();
551+
builder.addEqualSet(originalPlan.getLogicalProperties().getTrait());
552+
builder.replaceEqualSetBy(constructReplaceMap(mtmv));
553+
}
554+
}
555+
556+
@Override
557+
public void computeFd(DataTrait.Builder builder) {
558+
if (getTable() instanceof MTMV) {
559+
MTMV mtmv = (MTMV) getTable();
560+
MTMVCache cache = mtmv.getCache();
561+
if (cache == null) {
562+
return;
563+
}
564+
Plan originalPlan = cache.getOriginalPlan();
565+
builder.addFuncDepsDG(originalPlan.getLogicalProperties().getTrait());
566+
builder.replaceFuncDepsBy(constructReplaceMap(mtmv));
567+
}
568+
}
569+
570+
Map<Slot, Slot> constructReplaceMap(MTMV mtmv) {
571+
Map<Slot, Slot> replaceMap = new HashMap<>();
572+
List<Slot> originOutputs = mtmv.getCache().getOriginalPlan().getOutput();
573+
for (int i = 0; i < getOutput().size(); i++) {
574+
replaceMap.put(originOutputs.get(i), getOutput().get(i));
575+
}
576+
return replaceMap;
577+
}
512578
}

regression-test/suites/nereids_rules_p0/mv/nested_mtmv/nested_mtmv.groovy

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -839,23 +839,22 @@ suite("nested_mtmv") {
839839
}
840840
compare_res(sql_2 + " order by 1,2,3,4,5,6,7,8,9,10,11,12,13")
841841

842-
// tmp and will fix soon
843-
// explain {
844-
// sql("${sql_3}")
845-
// contains "${mv_3}(${mv_3})"
846-
// }
847-
// compare_res(sql_3 + " order by 1,2,3,4,5,6,7,8,9,10,11,12,13")
848-
//
849-
// explain {
850-
// sql("${sql_4}")
851-
// contains "${mv_4}(${mv_4})"
852-
// }
853-
// compare_res(sql_4 + " order by 1,2,3,4,5,6,7,8,9,10,11,12,13")
854-
//
855-
// explain {
856-
// sql("${sql_5}")
857-
// contains "${mv_5}(${mv_5})"
858-
// }
859-
// compare_res(sql_5 + " order by 1,2,3,4,5,6,7,8,9,10,11,12,13")
842+
explain {
843+
sql("${sql_3}")
844+
contains "${mv_3}(${mv_3})"
845+
}
846+
compare_res(sql_3 + " order by 1,2,3,4,5,6,7,8,9,10,11,12,13")
847+
848+
explain {
849+
sql("${sql_4}")
850+
contains "${mv_4}(${mv_4})"
851+
}
852+
compare_res(sql_4 + " order by 1,2,3,4,5,6,7,8,9,10,11,12,13")
853+
854+
explain {
855+
sql("${sql_5}")
856+
contains "${mv_5}(${mv_5})"
857+
}
858+
compare_res(sql_5 + " order by 1,2,3,4,5,6,7,8,9,10,11,12,13")
860859

861860
}

0 commit comments

Comments
 (0)