Skip to content

Commit 7774284

Browse files
authored
Merge pull request pingcap#4 from windtalker/hanfei/join-merge
basic cbo support for broadcast join
2 parents 9724824 + f1f5163 commit 7774284

File tree

8 files changed

+220
-22
lines changed

8 files changed

+220
-22
lines changed

planner/core/exhaust_physical_plans.go

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,13 +1479,15 @@ func (p *LogicalJoin) exhaustPhysicalPlans(prop *property.PhysicalProperty) []Ph
14791479
return joins
14801480
}
14811481

1482-
func getAllDataSourceRowCount(plan LogicalPlan) int64 {
1482+
func getAllDataSourceTotalRowSize(plan LogicalPlan) float64 {
14831483
if ds, ok := plan.(*DataSource); ok {
1484-
return ds.statsInfo().Count()
1484+
rowCount := ds.statsInfo().Count()
1485+
rowSize := ds.TblColHists.GetTableAvgRowSize(ds.ctx, ds.schema.Columns, kv.StoreType(ds.preferStoreType), ds.handleCol != nil)
1486+
return float64(rowCount) * rowSize
14851487
}
1486-
ret := int64(0)
1488+
ret := float64(0)
14871489
for _, child := range plan.Children() {
1488-
ret += getAllDataSourceRowCount(child)
1490+
ret += getAllDataSourceTotalRowSize(child)
14891491
}
14901492
return ret
14911493
}
@@ -1511,28 +1513,30 @@ func (p *LogicalJoin) tryToGetBroadCastJoin(prop *property.PhysicalProperty) []P
15111513
LeftJoinKeys: lkeys,
15121514
RightJoinKeys: rkeys,
15131515
}
1514-
// todo: currently, build side is the one has less rowcont and global read side
1515-
// is the one has less datasource row count(which mean less remote read), need
1516-
// to use cbo to decide the build side and global read side
1517-
if p.children[0].statsInfo().Count() < p.children[1].statsInfo().Count() {
1518-
baseJoin.InnerChildIdx = 0
1519-
} else {
1520-
baseJoin.InnerChildIdx = 1
1516+
1517+
preferredBuildIndex := 0
1518+
if p.children[0].statsInfo().Count() > p.children[1].statsInfo().Count() {
1519+
preferredBuildIndex = 1
15211520
}
1522-
globalIndex := baseJoin.InnerChildIdx
1523-
if prop.TaskTp != property.CopTiFlashGlobalReadTaskType && getAllDataSourceRowCount(p.children[globalIndex]) > getAllDataSourceRowCount(p.children[1-globalIndex]) {
1524-
globalIndex = 1 - globalIndex
1521+
preferredGlobalIndex := preferredBuildIndex
1522+
if prop.TaskTp != property.CopTiFlashGlobalReadTaskType && getAllDataSourceTotalRowSize(p.children[preferredGlobalIndex]) > getAllDataSourceTotalRowSize(p.children[1 - preferredGlobalIndex]) {
1523+
preferredGlobalIndex = 1 - preferredGlobalIndex
15251524
}
1525+
// todo: currently, build side is the one has less rowcount and global read side
1526+
// is the one has less datasource row size(which mean less remote read), need
1527+
// to use cbo to decide the build side and global read side if preferred build index
1528+
// is not equal to preferred global index
1529+
baseJoin.InnerChildIdx = preferredBuildIndex
15261530
childrenReqProps := make([]*property.PhysicalProperty, 2)
1527-
childrenReqProps[globalIndex] = &property.PhysicalProperty{TaskTp: property.CopTiFlashGlobalReadTaskType}
1531+
childrenReqProps[preferredGlobalIndex] = &property.PhysicalProperty{TaskTp: property.CopTiFlashGlobalReadTaskType}
15281532
if prop.TaskTp == property.CopTiFlashGlobalReadTaskType {
1529-
childrenReqProps[1-globalIndex] = &property.PhysicalProperty{TaskTp: property.CopTiFlashGlobalReadTaskType}
1533+
childrenReqProps[1-preferredGlobalIndex] = &property.PhysicalProperty{TaskTp: property.CopTiFlashGlobalReadTaskType}
15301534
} else {
1531-
childrenReqProps[1-globalIndex] = &property.PhysicalProperty{TaskTp: property.CopTiFlashLocalReadTaskType}
1535+
childrenReqProps[1-preferredGlobalIndex] = &property.PhysicalProperty{TaskTp: property.CopTiFlashLocalReadTaskType}
15321536
}
15331537
join := PhysicalBroadCastJoin{
15341538
basePhysicalJoin: baseJoin,
1535-
globalChildIndex: globalIndex,
1539+
globalChildIndex: preferredGlobalIndex,
15361540
}.Init(p.ctx, p.stats, p.blockOffset, childrenReqProps...)
15371541
results := make([]PhysicalPlan, 0, 1)
15381542
results = append(results, join)

planner/core/find_best_task.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,9 @@ func (ds *DataSource) getOriginalPhysicalTableScan(prop *property.PhysicalProper
13451345
}
13461346
sessVars := ds.ctx.GetSessionVars()
13471347
cost := rowCount * rowSize * sessVars.ScanFactor
1348+
if ts.IsGlobalRead {
1349+
cost += rowCount * sessVars.NetworkFactor * rowSize
1350+
}
13481351
if isMatchProp {
13491352
if prop.Items[0].Desc {
13501353
ts.Desc = true

planner/core/integration_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,68 @@ func (s *testIntegrationSerialSuite) TestSelPushDownTiFlash(c *C) {
341341
}
342342
}
343343

344+
func (s *testIntegrationSerialSuite) TestBroadcastJoin(c *C) {
345+
tk := testkit.NewTestKit(c, s.store)
346+
tk.MustExec("use test")
347+
tk.MustExec("drop table if exists d1_t")
348+
tk.MustExec("create table d1_t(d1_k int, value int)")
349+
tk.MustExec("insert into d1_t values(1,2),(2,3)")
350+
tk.MustExec("analyze table d1_t")
351+
tk.MustExec("drop table if exists d2_t")
352+
tk.MustExec("create table d2_t(d2_k decimal(10,2), value int)")
353+
tk.MustExec("insert into d2_t values(10.11,2),(10.12,3)")
354+
tk.MustExec("analyze table d2_t")
355+
tk.MustExec("drop table if exists d3_t")
356+
tk.MustExec("create table d3_t(d3_k date, value int)")
357+
tk.MustExec("insert into d3_t values(date'2010-01-01',2),(date'2010-01-02',3)")
358+
tk.MustExec("analyze table d3_t")
359+
tk.MustExec("drop table if exists fact_t")
360+
tk.MustExec("create table fact_t(d1_k int, d2_k decimal(10,2), d3_k date, col1 int, col2 int, col3 int)")
361+
tk.MustExec("insert into fact_t values(1,10.11,date'2010-01-01',1,2,3),(1,10.11,date'2010-01-02',1,2,3),(1,10.12,date'2010-01-01',1,2,3),(1,10.12,date'2010-01-02',1,2,3)")
362+
tk.MustExec("insert into fact_t values(2,10.11,date'2010-01-01',1,2,3),(2,10.11,date'2010-01-02',1,2,3),(2,10.12,date'2010-01-01',1,2,3),(2,10.12,date'2010-01-02',1,2,3)")
363+
tk.MustExec("analyze table fact_t")
364+
365+
// Create virtual tiflash replica info.
366+
dom := domain.GetDomain(tk.Se)
367+
is := dom.InfoSchema()
368+
db, exists := is.SchemaByName(model.NewCIStr("test"))
369+
c.Assert(exists, IsTrue)
370+
for _, tblInfo := range db.Tables {
371+
if tblInfo.Name.L == "fact_t" || tblInfo.Name.L == "d1_t" || tblInfo.Name.L == "d2_t" || tblInfo.Name.L == "d3_t" {
372+
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
373+
Count: 1,
374+
Available: true,
375+
}
376+
}
377+
}
378+
379+
tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'")
380+
tk.MustExec("set @@session.tidb_opt_broadcast_join = 1")
381+
var input []string
382+
var output []struct {
383+
SQL string
384+
Plan []string
385+
}
386+
s.testData.GetTestCases(c, &input, &output)
387+
for i, tt := range input {
388+
s.testData.OnRecord(func() {
389+
output[i].SQL = tt
390+
output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
391+
})
392+
res := tk.MustQuery(tt)
393+
res.Check(testkit.Rows(output[i].Plan...))
394+
}
395+
396+
// out join not supported
397+
_, err := tk.Exec("explain select /*+ tidb_bcj(fact_t, d1_t) */ count(*) from fact_t left join d1_t on fact_t.d1_k = d1_t.d1_k")
398+
c.Assert(err, NotNil)
399+
c.Assert(err.Error(), Equals, "[planner:1815]Internal : Can't find a proper physical plan for this query")
400+
// join with non-equal condition not supported
401+
_, err = tk.Exec("explain select /*+ tidb_bcj(fact_t, d1_t) */ count(*) from fact_t join d1_t on fact_t.d1_k = d1_t.d1_k and fact_t.col1 > d1_t.value")
402+
c.Assert(err, NotNil)
403+
c.Assert(err.Error(), Equals, "[planner:1815]Internal : Can't find a proper physical plan for this query")
404+
}
405+
344406
func (s *testIntegrationSerialSuite) TestIssue15110(c *C) {
345407
tk := testkit.NewTestKit(c, s.store)
346408
tk.MustExec("use test")

planner/core/task.go

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/pingcap/tidb/planner/property"
2929
"github.com/pingcap/tidb/sessionctx"
3030
"github.com/pingcap/tidb/statistics"
31+
"github.com/pingcap/tidb/store/tikv"
3132
"github.com/pingcap/tidb/types"
3233
"github.com/pingcap/tidb/util/chunk"
3334
"github.com/pingcap/tidb/util/plancodec"
@@ -533,9 +534,45 @@ func (p *PhysicalHashJoin) attach2Task(tasks ...task) task {
533534
return task
534535
}
535536

537+
// GetCost computes cost of broadcast join operator itself.
538+
func (p *PhysicalBroadCastJoin) GetCost(lCnt, rCnt float64) float64 {
539+
buildCnt := lCnt
540+
if p.InnerChildIdx == 1 {
541+
buildCnt = rCnt
542+
}
543+
sessVars := p.ctx.GetSessionVars()
544+
// Cost of building hash table.
545+
cpuCost := buildCnt * sessVars.CopCPUFactor
546+
memoryCost := buildCnt * sessVars.MemoryFactor
547+
// Number of matched row pairs regarding the equal join conditions.
548+
helper := &fullJoinRowCountHelper{
549+
cartesian: false,
550+
leftProfile: p.children[0].statsInfo(),
551+
rightProfile: p.children[1].statsInfo(),
552+
leftJoinKeys: p.LeftJoinKeys,
553+
rightJoinKeys: p.RightJoinKeys,
554+
leftSchema: p.children[0].Schema(),
555+
rightSchema: p.children[1].Schema(),
556+
}
557+
numPairs := helper.estimate()
558+
probeCost := numPairs * sessVars.CopCPUFactor
559+
// should divided by the cop concurrency, which is decide by TiFlash, but TiDB
560+
// can not get the information from TiFlash, so just use `sessVars.HashJoinConcurrency`
561+
// as a workaround
562+
probeCost /= float64(sessVars.HashJoinConcurrency)
563+
cpuCost += probeCost
564+
565+
// todo since TiFlash join is significant faster than TiDB join, maybe
566+
// need to add a variable like 'tiflash_accelerate_factor', and divide
567+
// the final cost by that factor
568+
return cpuCost + memoryCost
569+
}
570+
536571
func (p *PhysicalBroadCastJoin) attach2Task(tasks ...task) task {
537572
lTask, lok := tasks[0].(*copTask)
538573
rTask, rok := tasks[1].(*copTask)
574+
lGlobalRead := p.childrenReqProps[0].TaskTp == property.CopTiFlashGlobalReadTaskType
575+
rGlobalRead := p.childrenReqProps[1].TaskTp == property.CopTiFlashGlobalReadTaskType
539576
if !lok || !rok || (lTask.getStoreType() != kv.TiFlash && rTask.getStoreType() != kv.TiFlash) {
540577
return invalidTask
541578
}
@@ -547,11 +584,30 @@ func (p *PhysicalBroadCastJoin) attach2Task(tasks ...task) task {
547584
if !rTask.indexPlanFinished {
548585
rTask.finishIndexPlan()
549586
}
550-
task := &copTask{
551-
tblColHists: rTask.tblColHists,
587+
588+
lCost := lTask.cost()
589+
rCost := rTask.cost()
590+
if !(lGlobalRead && rGlobalRead) {
591+
// the cost model for top level broadcast join is
592+
// globalReadSideCost * copTaskNumber + localReadSideCost + broadcast operator cost
593+
// because for broadcast join, the global side is executed in every copTask.
594+
copTaskNumber := int32(1)
595+
copClient, ok := p.ctx.GetClient().(*tikv.CopClient)
596+
if ok {
597+
copTaskNumber = copClient.GetBatchCopTaskNumber()
598+
}
599+
if lGlobalRead {
600+
lCost = lCost * float64(copTaskNumber)
601+
} else {
602+
rCost = rCost * float64(copTaskNumber)
603+
}
604+
}
605+
606+
task := & copTask {
607+
tblColHists: rTask.tblColHists,
552608
indexPlanFinished: true,
553-
tablePlan: p,
554-
cst: lTask.cost() + rTask.cost(),
609+
tablePlan: p,
610+
cst: lCost + rCost + p.GetCost(lTask.count(), rTask.count()),
555611
}
556612
logutil.BgLogger().Info("bc join cost", zap.Float64("bc cost", task.cst))
557613
return task
@@ -708,6 +764,11 @@ func finishCopTask(ctx sessionctx.Context, task task) task {
708764
// is Min(DistSQLScanConcurrency, numRegionsInvolvedInScan), since we cannot infer
709765
// the number of regions involved, we simply use DistSQLScanConcurrency.
710766
copIterWorkers := float64(t.plan().SCtx().GetSessionVars().DistSQLScanConcurrency)
767+
if t.tablePlan != nil && t.tablePlan.TP() == plancodec.TypeBroadcastJoin {
768+
if copClient, ok := ctx.GetClient().(*tikv.CopClient); ok {
769+
copIterWorkers = math.Min(float64(copClient.GetBatchCopTaskNumber()), copIterWorkers)
770+
}
771+
}
711772
t.finishIndexPlan()
712773
// Network cost of transferring rows of table scan to TiDB.
713774
if t.tablePlan != nil {

planner/core/testdata/integration_serial_suite_in.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
"explain select * from t where cast(t.a as float) + 3 = 5.1"
77
]
88
},
9+
{
10+
"name": "TestBroadcastJoin",
11+
"cases": [
12+
"explain select /*+ tidb_bcj(fact_t,d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
13+
"explain select /*+ tidb_bcj(fact_t,d1_t,d2_t,d3_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k"
14+
]
15+
},
916
{
1017
"name": "TestReadFromStorageHint",
1118
"cases": [

planner/core/testdata/integration_serial_suite_out.json

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,43 @@
2020
}
2121
]
2222
},
23+
{
24+
"Name": "TestBroadcastJoin",
25+
"Cases": [
26+
{
27+
"SQL": "explain select /*+ tidb_bcj(fact_t,d1_t) */ count(*) from fact_t, d1_t where fact_t.d1_k = d1_t.d1_k",
28+
"Plan": [
29+
"StreamAgg_25 1.00 root funcs:count(Column#14)->Column#11",
30+
"└─TableReader_26 1.00 root data:StreamAgg_13",
31+
" └─StreamAgg_13 1.00 cop[tiflash] funcs:count(1)->Column#14",
32+
" └─TypeBroadcastJoin_24 8.00 cop[tiflash] ",
33+
" ├─Selection_18(Build) 0.00 cop[tiflash] not(isnull(test.d1_t.d1_k))",
34+
" │ └─TableFullScan_17 0.00 cop[tiflash] table:d1_t keep order:false, global read",
35+
" └─Selection_16(Probe) 0.00 cop[tiflash] not(isnull(test.fact_t.d1_k))",
36+
" └─TableFullScan_15 0.00 cop[tiflash] table:fact_t keep order:false"
37+
]
38+
},
39+
{
40+
"SQL": "explain select /*+ tidb_bcj(fact_t,d1_t,d2_t,d3_t) */ count(*) from fact_t, d1_t, d2_t, d3_t where fact_t.d1_k = d1_t.d1_k and fact_t.d2_k = d2_t.d2_k and fact_t.d3_k = d3_t.d3_k",
41+
"Plan": [
42+
"StreamAgg_35 1.00 root funcs:count(Column#20)->Column#17",
43+
"└─TableReader_36 1.00 root data:StreamAgg_17",
44+
" └─StreamAgg_17 1.00 cop[tiflash] funcs:count(1)->Column#20",
45+
" └─TypeBroadcastJoin_34 8.00 cop[tiflash] ",
46+
" ├─Selection_28(Build) 0.00 cop[tiflash] not(isnull(test.d3_t.d3_k))",
47+
" │ └─TableFullScan_27 0.00 cop[tiflash] table:d3_t keep order:false, global read",
48+
" └─TypeBroadcastJoin_19(Probe) 8.00 cop[tiflash] ",
49+
" ├─Selection_26(Build) 0.00 cop[tiflash] not(isnull(test.d2_t.d2_k))",
50+
" │ └─TableFullScan_25 0.00 cop[tiflash] table:d2_t keep order:false, global read",
51+
" └─TypeBroadcastJoin_20(Probe) 8.00 cop[tiflash] ",
52+
" ├─Selection_24(Build) 0.00 cop[tiflash] not(isnull(test.d1_t.d1_k))",
53+
" │ └─TableFullScan_23 0.00 cop[tiflash] table:d1_t keep order:false, global read",
54+
" └─Selection_22(Probe) 0.00 cop[tiflash] not(isnull(test.fact_t.d1_k)), not(isnull(test.fact_t.d2_k)), not(isnull(test.fact_t.d3_k))",
55+
" └─TableFullScan_21 8.00 cop[tiflash] table:fact_t keep order:false"
56+
]
57+
}
58+
]
59+
},
2360
{
2461
"Name": "TestReadFromStorageHint",
2562
"Cases": [

store/tikv/coprocessor.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ type CopClient struct {
5555
replicaReadSeed uint32
5656
}
5757

58+
func (c *CopClient) GetBatchCopTaskNumber() (ret int32) {
59+
ret = c.store.regionCache.storeMu.flashStoreNumber
60+
if ret <= 0 {
61+
ret = 1
62+
}
63+
return ret
64+
}
65+
5866
// Send builds the request and gets the coprocessor iterator response.
5967
func (c *CopClient) Send(ctx context.Context, req *kv.Request, vars *kv.Variables) kv.Response {
6068
if req.StoreType == kv.TiFlash && req.BatchCop {

store/tikv/region_cache.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ type RegionCache struct {
217217
storeMu struct {
218218
sync.RWMutex
219219
stores map[uint64]*Store
220+
flashStoreNumber int32
220221
}
221222
notifyDieCh chan []string
222223
notifyCheckCh chan struct{}
@@ -231,6 +232,7 @@ func NewRegionCache(pdClient pd.Client) *RegionCache {
231232
c.mu.regions = make(map[RegionVerID]*Region)
232233
c.mu.sorted = btree.New(btreeDegree)
233234
c.storeMu.stores = make(map[uint64]*Store)
235+
c.storeMu.flashStoreNumber = 0
234236
c.notifyCheckCh = make(chan struct{}, 1)
235237
c.notifyDieCh = make(chan []string, 1)
236238
c.closeCh = make(chan struct{})
@@ -951,6 +953,13 @@ func (c *RegionCache) getStoreAddr(bo *Backoffer, region *Region, store *Store,
951953
return
952954
case unresolved:
953955
addr, err = store.initResolve(bo, c)
956+
if store.storeType == kv.TiFlash {
957+
c.storeMu.Lock()
958+
if _,exists := c.storeMu.stores[store.storeID]; exists {
959+
c.storeMu.flashStoreNumber++
960+
}
961+
c.storeMu.Unlock()
962+
}
954963
return
955964
case deleted:
956965
addr = c.changeToActiveStore(region, store, storeIdx)
@@ -1371,7 +1380,14 @@ func (s *Store) reResolve(c *RegionCache) {
13711380
newStore := &Store{storeID: s.storeID, addr: addr, storeType: storeType}
13721381
newStore.state = *(*uint64)(unsafe.Pointer(&state))
13731382
c.storeMu.Lock()
1383+
orgStore,exists := c.storeMu.stores[newStore.storeID]
1384+
if exists && orgStore.storeType == kv.TiFlash {
1385+
c.storeMu.flashStoreNumber--
1386+
}
13741387
c.storeMu.stores[newStore.storeID] = newStore
1388+
if newStore.storeType == kv.TiFlash {
1389+
c.storeMu.flashStoreNumber++
1390+
}
13751391
c.storeMu.Unlock()
13761392
retryMarkDel:
13771393
// all region used those

0 commit comments

Comments
 (0)