Skip to content

Commit 5aa6511

Browse files
committed
HIVE-29176: Wrong result when HiveAntiJoin is replacing an IS NULL filter on a nullable column
1 parent 6f53c7f commit 5aa6511

File tree

6 files changed

+710
-90
lines changed

6 files changed

+710
-90
lines changed

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,43 +1233,17 @@ public FixNullabilityShuttle(RexBuilder rexBuilder,
12331233
}
12341234

12351235
/**
1236-
* Checks if any of the expression given as list expressions are from right side of the join.
1237-
* This is used during anti join conversion.
1238-
*
1239-
* @param joinRel Join node whose right side has to be searched.
1240-
* @param expressions The list of expression to search.
1241-
* @return true if any of the expressions is from right side of join.
1236+
* Given a join, creates a bitset of the joined columns originating from the right-hand side.
1237+
* @param joinRel a join that concatenates all columns from its inputs (so no semi-join)
1238+
* @return a bitset of the join's output field indices that come from the right-hand side input
12421239
*/
1243-
public static boolean hasAnyExpressionFromRightSide(RelNode joinRel, List<RexNode> expressions) {
1244-
List<RelDataTypeField> joinFields = joinRel.getRowType().getFieldList();
1245-
int nTotalFields = joinFields.size();
1246-
List<RelDataTypeField> leftFields = (joinRel.getInputs().get(0)).getRowType().getFieldList();
1247-
int nFieldsLeft = leftFields.size();
1248-
ImmutableBitSet rightBitmap = ImmutableBitSet.range(nFieldsLeft, nTotalFields);
1249-
1250-
for (RexNode node : expressions) {
1251-
ImmutableBitSet inputBits = RelOptUtil.InputFinder.bits(node);
1252-
if (rightBitmap.contains(inputBits)) {
1253-
return true;
1254-
}
1255-
}
1256-
return false;
1257-
}
1258-
1259-
public static boolean hasAllExpressionsFromRightSide(RelNode joinRel, List<RexNode> expressions) {
1260-
List<RelDataTypeField> joinFields = joinRel.getRowType().getFieldList();
1261-
int nTotalFields = joinFields.size();
1262-
List<RelDataTypeField> leftFields = (joinRel.getInputs().get(0)).getRowType().getFieldList();
1263-
int nFieldsLeft = leftFields.size();
1264-
ImmutableBitSet rightBitmap = ImmutableBitSet.range(nFieldsLeft, nTotalFields);
1265-
1266-
for (RexNode node : expressions) {
1267-
ImmutableBitSet inputBits = RelOptUtil.InputFinder.bits(node);
1268-
if (!rightBitmap.contains(inputBits)) {
1269-
return false;
1270-
}
1240+
public static ImmutableBitSet getRightSideBitset(RelNode joinRel) {
1241+
if(joinRel.getInputs().size() != 2) {
1242+
throw new IllegalArgumentException("The relation must have exactly two children:\n" + RelOptUtil.toString(joinRel));
12711243
}
1272-
return true;
1244+
int nTotalFields = joinRel.getRowType().getFieldCount();
1245+
int nFieldsLeft = (joinRel.getInputs().get(0)).getRowType().getFieldCount();
1246+
return ImmutableBitSet.range(nFieldsLeft, nTotalFields);
12731247
}
12741248

12751249
/**

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAntiSemiJoinRule.java

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@
2626
import org.apache.calcite.rel.core.Join;
2727
import org.apache.calcite.rel.core.JoinRelType;
2828
import org.apache.calcite.rel.core.Project;
29+
import org.apache.calcite.rel.type.RelDataTypeField;
2930
import org.apache.calcite.rex.RexCall;
3031
import org.apache.calcite.rex.RexNode;
3132
import org.apache.calcite.rex.RexVisitorImpl;
3233
import org.apache.calcite.sql.SqlKind;
3334
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
35+
import org.apache.calcite.util.ImmutableBitSet;
36+
import org.apache.calcite.util.Util;
37+
import org.apache.commons.lang3.mutable.MutableBoolean;
3438
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
3539
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAntiJoin;
3640
import org.slf4j.Logger;
@@ -39,7 +43,7 @@
3943
import java.util.ArrayList;
4044
import java.util.Collections;
4145
import java.util.List;
42-
import java.util.concurrent.atomic.AtomicBoolean;
46+
import java.util.Optional;
4347

4448
/**
4549
* Planner rule that converts a join plus filter to anti join.
@@ -86,14 +90,17 @@ protected void perform(RelOptRuleCall call, Project project, Filter filter, Join
8690

8791
assert (filter != null);
8892

89-
List<RexNode> filterList = getResidualFilterNodes(filter, join);
90-
if (filterList == null) {
93+
ImmutableBitSet rhsFields = HiveCalciteUtil.getRightSideBitset(join);
94+
Optional<List<RexNode>> optFilterList = getResidualFilterNodes(filter, join, rhsFields);
95+
if (optFilterList.isEmpty()) {
9196
return;
9297
}
98+
List<RexNode> filterList = optFilterList.get();
9399

94100
// If the projection references any field from the right side, we cannot convert to anti join.
95-
boolean hasProjection = HiveCalciteUtil.hasAnyExpressionFromRightSide(join, project.getProjects());
96-
if (hasProjection) {
101+
ImmutableBitSet projectedFields = RelOptUtil.InputFinder.bits(project.getProjects(), null);
102+
boolean projectionUsesRHS = projectedFields.intersects(rhsFields);
103+
if (projectionUsesRHS) {
97104
return;
98105
}
99106

@@ -119,13 +126,14 @@ protected void perform(RelOptRuleCall call, Project project, Filter filter, Join
119126
/**
120127
* Extracts the non-null filter conditions from given filter node.
121128
*
122-
* @param filter The filter condition to be checked.
123-
* @param join Join node whose right side has to be searched.
129+
* @param filter The filter condition to be checked.
130+
* @param join Join node whose right side has to be searched.
131+
* @param rhsFields Bitset of the join output fields that originate from the right side of the join.
124132
* @return Empty Optional : Anti join condition is not matched for filter.
125-
* Empty list : No residual filter conditions present.
126-
* Valid list containing the filter to be applied after join.
133+
* Empty list : No residual filter conditions present.
134+
* Valid list containing the filter to be applied after join.
127135
*/
128-
private List<RexNode> getResidualFilterNodes(Filter filter, Join join) {
136+
private Optional<List<RexNode>> getResidualFilterNodes(Filter filter, Join join, ImmutableBitSet rhsFields) {
129137
// 1. If null filter is not present from right side then we can not convert to anti join.
130138
// 2. If any non-null filter is present from right side, we can not convert it to anti join.
131139
// 3. Keep the other filters, which need to be executed after the join.
@@ -135,43 +143,76 @@ private List<RexNode> getResidualFilterNodes(Filter filter, Join join) {
135143
List<RexNode> aboveFilters = RelOptUtil.conjunctions(filter.getCondition());
136144
boolean hasNullFilterOnRightSide = false;
137145
List<RexNode> filterList = new ArrayList<>();
146+
final ImmutableBitSet notNullColumnsFromRightSide = getNotNullColumnsFromRightSide(join);
147+
138148
for (RexNode filterNode : aboveFilters) {
139-
if (filterNode.getKind() == SqlKind.IS_NULL) {
140-
// Null filter from right side table can be removed and its a pre-condition for anti join conversion.
141-
if (HiveCalciteUtil.hasAllExpressionsFromRightSide(join, Collections.singletonList(filterNode))
142-
&& isStrong(((RexCall) filterNode).getOperands().get(0))) {
143-
hasNullFilterOnRightSide = true;
144-
} else {
145-
filterList.add(filterNode);
146-
}
147-
} else {
148-
if (HiveCalciteUtil.hasAnyExpressionFromRightSide(join, Collections.singletonList(filterNode))) {
149-
// If some non null condition is present from right side, we can not convert the join to anti join as
150-
// anti join does not project the fields from right side.
151-
return null;
152-
} else {
153-
filterList.add(filterNode);
154-
}
149+
final ImmutableBitSet usedFields = RelOptUtil.InputFinder.bits(filterNode);
150+
boolean usesFieldFromRHS = usedFields.intersects(rhsFields);
151+
152+
if(!usesFieldFromRHS) {
153+
// Only LHS fields or constants, so the filterNode is part of the residual filter
154+
filterList.add(filterNode);
155+
continue;
156+
}
157+
158+
// In the following we check for filter nodes that let us deduce that
159+
// "an (originally) not-null column of RHS IS NULL because the LHS row will not be matched"
160+
161+
if(filterNode.getKind() != SqlKind.IS_NULL) {
162+
return Optional.empty();
163+
}
164+
165+
boolean usesRHSFieldsOnly = rhsFields.contains(usedFields);
166+
if (!usesRHSFieldsOnly) {
167+
// If there is a mix between LHS and RHS fields, don't convert to anti-join
168+
return Optional.empty();
169+
}
170+
171+
// Null filter from right side table can be removed and it is a pre-condition for anti join conversion.
172+
RexNode arg = ((RexCall) filterNode).getOperands().get(0);
173+
if (isStrong(arg, notNullColumnsFromRightSide)) {
174+
hasNullFilterOnRightSide = true;
175+
} else if(!isStrong(arg, rhsFields)) {
176+
// if all RHS fields are null and the IS NULL is still not fulfilled, bail out
177+
return Optional.empty();
155178
}
156179
}
157180

158181
if (!hasNullFilterOnRightSide) {
159-
return null;
182+
return Optional.empty();
160183
}
161-
return filterList;
184+
return Optional.of(filterList);
162185
}
163186

164-
private boolean isStrong(RexNode rexNode) {
165-
AtomicBoolean hasCast = new AtomicBoolean(false);
166-
rexNode.accept(new RexVisitorImpl<Void>(true) {
167-
@Override
168-
public Void visitCall(RexCall call) {
169-
if (call.getKind() == SqlKind.CAST) {
170-
hasCast.set(true);
171-
}
172-
return super.visitCall(call);
187+
private ImmutableBitSet getNotNullColumnsFromRightSide(RelNode joinRel) {
188+
// we need to shift the indices of the second child to the right
189+
int shift = (joinRel.getInput(0)).getRowType().getFieldCount();
190+
191+
ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
192+
List<RelDataTypeField> rhsFields = joinRel.getInput(1).getRowType().getFieldList();
193+
for(RelDataTypeField field : rhsFields) {
194+
if(!field.getType().isNullable()) {
195+
builder.set(shift+field.getIndex());
173196
}
174-
});
175-
return !hasCast.get() && Strong.isStrong(rexNode);
197+
}
198+
return builder.build();
199+
}
200+
201+
private boolean isStrong(RexNode rexNode, ImmutableBitSet rightSideBitset) {
202+
try {
203+
rexNode.accept(new RexVisitorImpl<Void>(true) {
204+
@Override
205+
public Void visitCall(RexCall call) {
206+
if (call.getKind() == SqlKind.CAST) {
207+
throw Util.FoundOne.NULL;
208+
}
209+
return super.visitCall(call);
210+
}
211+
});
212+
} catch (Util.FoundOne e) {
213+
// Hive's CAST might introduce NULL for NOT NULL fields
214+
return false;
215+
}
216+
return Strong.isNull(rexNode, rightSideBitset);
176217
}
177218
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
19+
20+
import org.apache.calcite.plan.RelOptPlanner;
21+
import org.apache.calcite.rel.RelNode;
22+
import org.apache.calcite.rel.core.JoinRelType;
23+
import org.apache.calcite.runtime.Hook;
24+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
25+
import org.apache.calcite.tools.RelBuilder;
26+
import org.junit.Test;
27+
import org.junit.runner.RunWith;
28+
import org.mockito.junit.MockitoJUnitRunner;
29+
30+
import java.util.Collections;
31+
32+
import static org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.*;
33+
34+
@RunWith(MockitoJUnitRunner.class)
35+
public class TestHiveAntiSemiJoinRule {
36+
37+
PlanFixture fixture() {
38+
RelOptPlanner planner = buildPlanner(Collections.singletonList(HiveAntiSemiJoinRule.INSTANCE));
39+
return new PlanFixture(planner).registerTable("t1", T1Record.class).registerTable("t2", T2Record.class);
40+
}
41+
42+
@Test
43+
public void testFilterOnNullableColumn() {
44+
PlanFixture fixture = fixture();
45+
RelBuilder b = fixture.createRelBuilder();
46+
47+
// @formatter:off
48+
RelNode plan = b
49+
.scan("t1")
50+
.scan("t2")
51+
.join(JoinRelType.LEFT, b.equals(
52+
b.field(2, 0, "t1nullable"),
53+
b.field(2, 1, "t2id")))
54+
.filter(b.isNull(b.field("t2nullable")))
55+
.project(b.field("t1id"))
56+
.build();
57+
58+
String expectedPlan = "HiveProject(t1id=[$0])\n"
59+
+ " HiveFilter(condition=[IS NULL($5)])\n"
60+
+ " HiveJoin(condition=[=($2, $3)], joinType=[left], algorithm=[none], cost=[not available])\n"
61+
+ " LogicalTableScan(table=[[t1]])\n"
62+
+ " LogicalTableScan(table=[[t2]])\n";
63+
// @formatter:on
64+
65+
assertPlans(fixture.getPlanner(), plan, expectedPlan, expectedPlan);
66+
}
67+
68+
@Test
69+
public void testFilterIsNullFromBothSides() {
70+
PlanFixture fixture = fixture();
71+
72+
RelNode plan;
73+
try (Hook.Closeable ignore = Hook.REL_BUILDER_SIMPLIFY.addThread(Hook.propertyJ(false))) {
74+
RelBuilder b = fixture.createRelBuilder();
75+
// @formatter:off
76+
plan = b.scan("t1")
77+
.scan("t2")
78+
.join(JoinRelType.LEFT, b.equals(b.field(2, 0, "t1nullable"), b.field(2, 1, "t2id")))
79+
.filter(b.isNull(b.call(SqlStdOperatorTable.PLUS, b.field("t2nullable"), b.field("t1nullable"))))
80+
.project(b.field("t1id")).build();
81+
// @formatter:on
82+
}
83+
84+
// @formatter:off
85+
String expectedPlan = "HiveProject(t1id=[$0])\n"
86+
+ " HiveFilter(condition=[IS NULL(+($5, $2))])\n"
87+
+ " HiveJoin(condition=[=($2, $3)], joinType=[left], algorithm=[none], cost=[not available])\n"
88+
+ " LogicalTableScan(table=[[t1]])\n"
89+
+ " LogicalTableScan(table=[[t2]])\n";
90+
// @formatter:on
91+
92+
assertPlans(fixture.getPlanner(), plan, expectedPlan, expectedPlan);
93+
}
94+
95+
@Test
96+
public void testFilterOnNotNullColumn() {
97+
PlanFixture fixture = fixture();
98+
RelBuilder b = fixture.createRelBuilder();
99+
100+
// @formatter:off
101+
RelNode plan = b
102+
.scan("t1")
103+
.scan("t2")
104+
.join(JoinRelType.LEFT, b.equals(
105+
b.field(2, 0, "t1nullable"),
106+
b.field(2, 1, "t2id")))
107+
.filter(b.isNull(b.field("t2notnull")))
108+
.project(b.field("t1id"))
109+
.build();
110+
111+
String prePlan = "HiveProject(t1id=[$0])\n"
112+
+ " HiveFilter(condition=[IS NULL($4)])\n"
113+
+ " HiveJoin(condition=[=($2, $3)], joinType=[left], algorithm=[none], cost=[not available])\n"
114+
+ " LogicalTableScan(table=[[t1]])\n"
115+
+ " LogicalTableScan(table=[[t2]])\n";
116+
117+
String postPlan = "HiveProject(t1id=[$0])\n"
118+
+ " HiveAntiJoin(condition=[=($2, $3)], joinType=[anti])\n"
119+
+ " LogicalTableScan(table=[[t1]])\n"
120+
+ " LogicalTableScan(table=[[t2]])\n";
121+
// @formatter:on
122+
123+
assertPlans(fixture.getPlanner(), plan, prePlan, postPlan);
124+
}
125+
126+
@Test
127+
public void testFilterOnNullAndNotNullColumn() {
128+
PlanFixture fixture = fixture();
129+
RelBuilder b = fixture.createRelBuilder();
130+
131+
// @formatter:off
132+
RelNode plan = b
133+
.scan("t1")
134+
.scan("t2")
135+
.join(JoinRelType.LEFT, b.equals(
136+
b.field(2, 0, "t1nullable"),
137+
b.field(2, 1, "t2id")))
138+
.filter(b.and(b.isNull(b.field("t2notnull")), b.isNull((b.field("t2nullable")))))
139+
.project(b.field("t1id"))
140+
.build();
141+
142+
String prePlan = "HiveProject(t1id=[$0])\n"
143+
+ " HiveFilter(condition=[AND(IS NULL($4), IS NULL($5))])\n"
144+
+ " HiveJoin(condition=[=($2, $3)], joinType=[left], algorithm=[none], cost=[not available])\n"
145+
+ " LogicalTableScan(table=[[t1]])\n"
146+
+ " LogicalTableScan(table=[[t2]])\n";
147+
148+
String postPlan = "HiveProject(t1id=[$0])\n"
149+
+ " HiveAntiJoin(condition=[=($2, $3)], joinType=[anti])\n"
150+
+ " LogicalTableScan(table=[[t1]])\n"
151+
+ " LogicalTableScan(table=[[t2]])\n";
152+
// @formatter:on
153+
154+
assertPlans(fixture.getPlanner(), plan, prePlan, postPlan);
155+
}
156+
157+
static class T1Record {
158+
public int t1id;
159+
public int t1notnull;
160+
public Integer t1nullable;
161+
}
162+
163+
static class T2Record {
164+
public int t2id;
165+
public int t2notnull;
166+
public Integer t2nullable;
167+
}
168+
}

0 commit comments

Comments
 (0)