Skip to content

Commit 8390fc4

Browse files
authored
planner: check binding validation when creating bindings (pingcap#58760)
ref pingcap#51347
1 parent 2a3542c commit 8390fc4

File tree

5 files changed

+106
-13
lines changed

5 files changed

+106
-13
lines changed

pkg/bindinfo/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ go_library(
2323
"//pkg/sessionctx/sessionstates",
2424
"//pkg/sessionctx/variable",
2525
"//pkg/types",
26+
"//pkg/types/parser_driver",
2627
"//pkg/util",
2728
"//pkg/util/chunk",
2829
"//pkg/util/hack",

pkg/bindinfo/binding.go

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package bindinfo
1616

1717
import (
18+
"context"
19+
"fmt"
1820
"strings"
1921
"sync"
2022
"unsafe"
@@ -24,9 +26,11 @@ import (
2426
"github.com/pingcap/tidb/pkg/parser/ast"
2527
"github.com/pingcap/tidb/pkg/sessionctx"
2628
"github.com/pingcap/tidb/pkg/types"
29+
driver "github.com/pingcap/tidb/pkg/types/parser_driver"
2730
"github.com/pingcap/tidb/pkg/util/hint"
2831
utilparser "github.com/pingcap/tidb/pkg/util/parser"
2932
"github.com/pkg/errors"
33+
"go.uber.org/zap"
3034
)
3135

3236
const (
@@ -264,7 +268,7 @@ func (*tableNameCollector) Leave(in ast.Node) (out ast.Node, ok bool) {
264268

265269
// prepareHints builds ID and Hint for Bindings. If sctx is not nil, we check if
266270
// the BindSQL is still valid.
267-
func prepareHints(_ sessionctx.Context, binding *Binding) (rerr error) {
271+
func prepareHints(sctx sessionctx.Context, binding *Binding) (rerr error) {
268272
defer func() {
269273
if r := recover(); r != nil {
270274
rerr = errors.Errorf("panic when preparing hints for binding %v, panic: %v", binding.BindSQL, r)
@@ -286,10 +290,16 @@ func prepareHints(_ sessionctx.Context, binding *Binding) (rerr error) {
286290
dbName = "*" // ues '*' for universal bindings
287291
}
288292

289-
hintsSet, _, warns, err := hint.ParseHintsSet(p, binding.BindSQL, binding.Charset, binding.Collation, dbName)
293+
hintsSet, stmt, warns, err := hint.ParseHintsSet(p, binding.BindSQL, binding.Charset, binding.Collation, dbName)
290294
if err != nil {
291295
return err
292296
}
297+
if !isCrossDB && !hasParam(stmt) {
298+
// TODO: how to check cross-db binding and bindings with parameters?
299+
if err = checkBindingValidation(sctx, binding.BindSQL); err != nil {
300+
return err
301+
}
302+
}
293303
hintsStr, err := hintsSet.Restore()
294304
if err != nil {
295305
return err
@@ -460,3 +470,59 @@ func eraseLastSemicolon(stmt ast.StmtNode) {
460470
stmt.SetText(nil, sql[:len(sql)-1])
461471
}
462472
}
473+
474+
type paramChecker struct {
475+
hasParam bool
476+
}
477+
478+
func (e *paramChecker) Enter(in ast.Node) (ast.Node, bool) {
479+
if _, ok := in.(*driver.ParamMarkerExpr); ok {
480+
e.hasParam = true
481+
return in, true
482+
}
483+
return in, false
484+
}
485+
486+
func (*paramChecker) Leave(in ast.Node) (ast.Node, bool) {
487+
return in, true
488+
}
489+
490+
// hasParam checks whether the statement contains any parameters.
491+
// For example, `create binding using select * from t where a=?` contains a parameter '?'.
492+
func hasParam(stmt ast.Node) bool {
493+
p := new(paramChecker)
494+
stmt.Accept(p)
495+
return p.hasParam
496+
}
497+
498+
// CheckBindingStmt checks whether the statement is valid.
499+
func checkBindingValidation(sctx sessionctx.Context, bindingSQL string) error {
500+
origVals := sctx.GetSessionVars().UsePlanBaselines
501+
sctx.GetSessionVars().UsePlanBaselines = false
502+
503+
// Usually passing a sprintf to ExecuteInternal is not recommended, but in this case
504+
// it is safe because ExecuteInternal does not permit MultiStatement execution. Thus,
505+
// the statement won't be able to "break out" from EXPLAIN.
506+
rs, err := exec(sctx, fmt.Sprintf("EXPLAIN FORMAT='hint' %s", bindingSQL))
507+
sctx.GetSessionVars().UsePlanBaselines = origVals
508+
if rs != nil {
509+
defer func() {
510+
// Audit log is collected in Close(), set InRestrictedSQL to avoid 'create sql binding' been recorded as 'explain'.
511+
origin := sctx.GetSessionVars().InRestrictedSQL
512+
sctx.GetSessionVars().InRestrictedSQL = true
513+
if rerr := rs.Close(); rerr != nil {
514+
bindingLogger().Error("close result set failed", zap.Error(rerr), zap.String("binding_sql", bindingSQL))
515+
}
516+
sctx.GetSessionVars().InRestrictedSQL = origin
517+
}()
518+
}
519+
if err != nil {
520+
return err
521+
}
522+
chk := rs.NewChunk(nil)
523+
err = rs.Next(context.TODO(), chk)
524+
if err != nil {
525+
return err
526+
}
527+
return nil
528+
}

pkg/bindinfo/tests/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ go_test(
1010
],
1111
flaky = True,
1212
race = "on",
13-
shard_count = 21,
13+
shard_count = 22,
1414
deps = [
1515
"//pkg/bindinfo",
1616
"//pkg/domain",

pkg/bindinfo/tests/bind_test.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,10 @@ func TestInvisibleIndex(t *testing.T) {
348348
tk.MustExec("use test")
349349
tk.MustExec("drop table if exists t")
350350
tk.MustExec("create table t(a int, b int, unique idx_a(a), index idx_b(b) invisible)")
351+
tk.MustContainErrMsg("create global binding for select * from t using select * from t use index(idx_b)",
352+
"[planner:1176]Key 'idx_b' doesn't exist in table 't'")
351353

352354
// Create bind using index
353-
tk.MustExec("create global binding for select * from t using select * from t use index(idx_b)")
354355
tk.MustExec("create global binding for select * from t using select * from t use index(idx_a)")
355356

356357
tk.MustQuery("select * from t")
@@ -851,3 +852,34 @@ func TestBatchDropBindings(t *testing.T) {
851852
removeAllBindings(tk, true)
852853
removeAllBindings(tk, false)
853854
}
855+
856+
func TestInvalidBindingCheck(t *testing.T) {
857+
store := testkit.CreateMockStore(t)
858+
tk := testkit.NewTestKit(t, store)
859+
tk.MustExec(`use test`)
860+
tk.MustExec(`create table t (a int, b int)`)
861+
862+
cases := []struct {
863+
SQL string
864+
Err string
865+
}{
866+
{"select * from t where c=1", "[planner:1054]Unknown column 'c' in 'where clause'"},
867+
{"select * from t where a=1 and c=1", "[planner:1054]Unknown column 'c' in 'where clause'"},
868+
{"select * from dbx.t", "[schema:1146]Table 'dbx.t' doesn't exist"},
869+
{"select * from t1", "[schema:1146]Table 'test.t1' doesn't exist"},
870+
{"select * from t1, t", "[schema:1146]Table 'test.t1' doesn't exist"},
871+
{"select * from t use index(c)", "[planner:1176]Key 'c' doesn't exist in table 't'"},
872+
}
873+
874+
for _, c := range cases {
875+
for _, scope := range []string{"session", "global"} {
876+
sql := fmt.Sprintf("create %v binding using %v", scope, c.SQL)
877+
tk.MustGetErrMsg(sql, c.Err)
878+
}
879+
}
880+
881+
// cross-db bindings or bindings with parameters can bypass the check, which is expected.
882+
// We'll optimize this check further in the future.
883+
tk.MustExec("create binding using select * from *.t where c=1")
884+
tk.MustExec("create binding using select * from t where c=?")
885+
}

pkg/sessionctx/sessionstates/session_states_test.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,9 +1293,7 @@ func TestSQLBinding(t *testing.T) {
12931293
tk.MustExec("drop table test.t1")
12941294
return nil
12951295
},
1296-
checkFunc: func(tk *testkit.TestKit, param any) {
1297-
require.Equal(t, 1, len(tk.MustQuery("show session bindings").Rows()))
1298-
},
1296+
restoreErr: errno.ErrNoSuchTable,
12991297
cleanFunc: func(tk *testkit.TestKit) {
13001298
tk.MustExec("create table test.t1(id int primary key, name varchar(10), key(name))")
13011299
},
@@ -1310,9 +1308,7 @@ func TestSQLBinding(t *testing.T) {
13101308
tk.MustExec("drop database test1")
13111309
return nil
13121310
},
1313-
checkFunc: func(tk *testkit.TestKit, param any) {
1314-
require.Equal(t, 1, len(tk.MustQuery("show session bindings").Rows()))
1315-
},
1311+
restoreErr: errno.ErrNoSuchTable,
13161312
},
13171313
{
13181314
// alter the table
@@ -1321,9 +1317,7 @@ func TestSQLBinding(t *testing.T) {
13211317
tk.MustExec("alter table test.t1 drop index name")
13221318
return nil
13231319
},
1324-
checkFunc: func(tk *testkit.TestKit, param any) {
1325-
require.Equal(t, 1, len(tk.MustQuery("show session bindings").Rows()))
1326-
},
1320+
restoreErr: errno.ErrKeyDoesNotExist,
13271321
cleanFunc: func(tk *testkit.TestKit) {
13281322
tk.MustExec("alter table test.t1 add index name(name)")
13291323
},

0 commit comments

Comments
 (0)