Skip to content

Commit 648a642

Browse files
YangKeaozeminzhou
authored andcommitted
planner: fix the issue that the type of BatchGet with multiple columns is incorrect (pingcap#60524)
close pingcap#60523
1 parent 0c417d3 commit 648a642

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

pkg/planner/core/point_get_plan.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,15 +1178,13 @@ func newBatchPointGetPlan(
11781178
}
11791179
var values []types.Datum
11801180
var valuesParams []*expression.Constant
1181-
var pairs []nameValuePair
11821181
switch x := item.(type) {
11831182
case *ast.RowExpr:
11841183
// The `len(values) == len(valuesParams)` should be satisfied in this mode
11851184
if len(x.Values) != len(whereColNames) {
11861185
return nil
11871186
}
11881187
values = make([]types.Datum, len(x.Values))
1189-
pairs = make([]nameValuePair, 0, len(x.Values))
11901188
valuesParams = make([]*expression.Constant, len(x.Values))
11911189
initTypes := false
11921190
if indexTypes == nil { // only init once
@@ -1202,8 +1200,7 @@ func newBatchPointGetPlan(
12021200
if dval == nil {
12031201
return nil
12041202
}
1205-
values[permIndex] = innerX.Datum
1206-
pairs = append(pairs, nameValuePair{colName: whereColNames[index], value: innerX.Datum})
1203+
values[permIndex] = *dval
12071204
case *driver.ParamMarkerExpr:
12081205
con, err := expression.ParamMarkerExpression(ctx.GetExprCtx(), innerX, true)
12091206
if err != nil {
@@ -1217,12 +1214,11 @@ func newBatchPointGetPlan(
12171214
if dval == nil {
12181215
return nil
12191216
}
1220-
values[permIndex] = innerX.Datum
1217+
values[permIndex] = *dval
12211218
valuesParams[permIndex] = con
12221219
if initTypes {
12231220
indexTypes[permIndex] = &colInfos[index].FieldType
12241221
}
1225-
pairs = append(pairs, nameValuePair{colName: whereColNames[index], value: innerX.Datum})
12261222
default:
12271223
return nil
12281224
}
@@ -1239,7 +1235,6 @@ func newBatchPointGetPlan(
12391235
}
12401236
values = []types.Datum{*dval}
12411237
valuesParams = []*expression.Constant{nil}
1242-
pairs = append(pairs, nameValuePair{colName: whereColNames[0], value: *dval})
12431238
case *driver.ParamMarkerExpr:
12441239
if len(whereColNames) != 1 {
12451240
return nil
@@ -1261,7 +1256,6 @@ func newBatchPointGetPlan(
12611256
if indexTypes == nil { // only init once
12621257
indexTypes = []*types.FieldType{&colInfos[0].FieldType}
12631258
}
1264-
pairs = append(pairs, nameValuePair{colName: whereColNames[0], value: *dval})
12651259

12661260
default:
12671261
return nil

pkg/server/tests/commontest/tidb_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3353,3 +3353,46 @@ func TestAuthSocket(t *testing.T) {
33533353
ts.CheckRows(t, rows, "u2@%")
33543354
})
33553355
}
3356+
3357+
func TestBatchGetTypeForRowExpr(t *testing.T) {
3358+
ts := servertestkit.CreateTidbTestSuite(t)
3359+
3360+
// single columns
3361+
ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
3362+
dbt.MustExec("use test;")
3363+
dbt.MustExec("create table t1 (id varchar(255) collate utf8mb4_general_ci, primary key (id));")
3364+
dbt.MustExec("insert into t1 values ('a'), ('c');")
3365+
3366+
conn, err := dbt.GetDB().Conn(context.Background())
3367+
require.NoError(t, err)
3368+
defer func() {
3369+
require.NoError(t, conn.Close())
3370+
}()
3371+
_, err = conn.ExecContext(context.Background(), "set @@session.collation_connection = 'utf8mb4_general_ci'")
3372+
require.NoError(t, err)
3373+
stmt, err := conn.PrepareContext(context.Background(), "select * from t1 where id in (?, ?)")
3374+
require.NoError(t, err)
3375+
rows, err := stmt.Query("A", "C")
3376+
require.NoError(t, err)
3377+
ts.CheckRows(t, rows, "a\nc")
3378+
})
3379+
3380+
// multiple columns
3381+
ts.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
3382+
dbt.MustExec("use test;")
3383+
dbt.MustExec("create table t2 (id1 varchar(255) collate utf8mb4_general_ci, id2 varchar(255) collate utf8mb4_general_ci, primary key (id1, id2));")
3384+
dbt.MustExec("insert into t2 values ('a', 'b'), ('c', 'd');")
3385+
3386+
conn, err := dbt.GetDB().Conn(context.Background())
3387+
require.NoError(t, err)
3388+
defer func() {
3389+
require.NoError(t, conn.Close())
3390+
}()
3391+
conn.ExecContext(context.Background(), "set @@session.collation_connection = 'utf8mb4_general_ci'")
3392+
stmt, err := conn.PrepareContext(context.Background(), "select * from t2 where (id1, id2) in ((?, ?), (?, ?))")
3393+
require.NoError(t, err)
3394+
rows, err := stmt.Query("A", "B", "C", "D")
3395+
require.NoError(t, err)
3396+
ts.CheckRows(t, rows, "a b\nc d")
3397+
})
3398+
}

0 commit comments

Comments
 (0)