Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/bindinfo/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ go_library(
"//pkg/sessionctx/sessionstates",
"//pkg/sessionctx/variable",
"//pkg/types",
"//pkg/types/parser_driver",
"//pkg/util",
"//pkg/util/chunk",
"//pkg/util/hack",
Expand Down
70 changes: 68 additions & 2 deletions pkg/bindinfo/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package bindinfo

import (
"context"
"fmt"
"strings"
"sync"
"unsafe"
Expand All @@ -24,9 +26,11 @@ import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
driver "github.com/pingcap/tidb/pkg/types/parser_driver"
"github.com/pingcap/tidb/pkg/util/hint"
utilparser "github.com/pingcap/tidb/pkg/util/parser"
"github.com/pkg/errors"
"go.uber.org/zap"
)

const (
Expand Down Expand Up @@ -264,7 +268,7 @@ func (*tableNameCollector) Leave(in ast.Node) (out ast.Node, ok bool) {

// prepareHints builds ID and Hint for Bindings. If sctx is not nil, we check if
// the BindSQL is still valid.
func prepareHints(_ sessionctx.Context, binding *Binding) (rerr error) {
func prepareHints(sctx sessionctx.Context, binding *Binding) (rerr error) {
defer func() {
if r := recover(); r != nil {
rerr = errors.Errorf("panic when preparing hints for binding %v, panic: %v", binding.BindSQL, r)
Expand All @@ -286,10 +290,16 @@ func prepareHints(_ sessionctx.Context, binding *Binding) (rerr error) {
dbName = "*" // ues '*' for universal bindings
}

hintsSet, _, warns, err := hint.ParseHintsSet(p, binding.BindSQL, binding.Charset, binding.Collation, dbName)
hintsSet, stmt, warns, err := hint.ParseHintsSet(p, binding.BindSQL, binding.Charset, binding.Collation, dbName)
if err != nil {
return err
}
if !isCrossDB && !hasParam(stmt) {
// TODO: how to check cross-db binding and bindings with parameters?
if err = checkBindingValidation(sctx, binding.BindSQL); err != nil {
return err
}
}
hintsStr, err := hintsSet.Restore()
if err != nil {
return err
Expand Down Expand Up @@ -460,3 +470,59 @@ func eraseLastSemicolon(stmt ast.StmtNode) {
stmt.SetText(nil, sql[:len(sql)-1])
}
}

type paramChecker struct {
hasParam bool
}

func (e *paramChecker) Enter(in ast.Node) (ast.Node, bool) {
if _, ok := in.(*driver.ParamMarkerExpr); ok {
e.hasParam = true
return in, true
}
return in, false
}

func (*paramChecker) Leave(in ast.Node) (ast.Node, bool) {
return in, true
}

// hasParam checks whether the statement contains any parameters.
// For example, `create binding using select * from t where a=?` contains a parameter '?'.
func hasParam(stmt ast.Node) bool {
p := new(paramChecker)
stmt.Accept(p)
return p.hasParam
}

// CheckBindingStmt checks whether the statement is valid.
func checkBindingValidation(sctx sessionctx.Context, bindingSQL string) error {
origVals := sctx.GetSessionVars().UsePlanBaselines
sctx.GetSessionVars().UsePlanBaselines = false

// Usually passing a sprintf to ExecuteInternal is not recommended, but in this case
// it is safe because ExecuteInternal does not permit MultiStatement execution. Thus,
// the statement won't be able to "break out" from EXPLAIN.
rs, err := exec(sctx, fmt.Sprintf("EXPLAIN FORMAT='hint' %s", bindingSQL))
sctx.GetSessionVars().UsePlanBaselines = origVals
if rs != nil {
defer func() {
// Audit log is collected in Close(), set InRestrictedSQL to avoid 'create sql binding' been recorded as 'explain'.
origin := sctx.GetSessionVars().InRestrictedSQL
sctx.GetSessionVars().InRestrictedSQL = true
if rerr := rs.Close(); rerr != nil {
bindingLogger().Error("close result set failed", zap.Error(rerr), zap.String("binding_sql", bindingSQL))
}
sctx.GetSessionVars().InRestrictedSQL = origin
}()
}
if err != nil {
return err
}
chk := rs.NewChunk(nil)
err = rs.Next(context.TODO(), chk)
if err != nil {
return err
}
return nil
}
2 changes: 1 addition & 1 deletion pkg/bindinfo/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ go_test(
],
flaky = True,
race = "on",
shard_count = 21,
shard_count = 22,
deps = [
"//pkg/bindinfo",
"//pkg/domain",
Expand Down
34 changes: 33 additions & 1 deletion pkg/bindinfo/tests/bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,10 @@ func TestInvisibleIndex(t *testing.T) {
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int, unique idx_a(a), index idx_b(b) invisible)")
tk.MustContainErrMsg("create global binding for select * from t using select * from t use index(idx_b)",
"[planner:1176]Key 'idx_b' doesn't exist in table 't'")

// Create bind using index
tk.MustExec("create global binding for select * from t using select * from t use index(idx_b)")
tk.MustExec("create global binding for select * from t using select * from t use index(idx_a)")

tk.MustQuery("select * from t")
Expand Down Expand Up @@ -851,3 +852,34 @@ func TestBatchDropBindings(t *testing.T) {
removeAllBindings(tk, true)
removeAllBindings(tk, false)
}

func TestInvalidBindingCheck(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec(`use test`)
tk.MustExec(`create table t (a int, b int)`)

cases := []struct {
SQL string
Err string
}{
{"select * from t where c=1", "[planner:1054]Unknown column 'c' in 'where clause'"},
{"select * from t where a=1 and c=1", "[planner:1054]Unknown column 'c' in 'where clause'"},
{"select * from dbx.t", "[schema:1146]Table 'dbx.t' doesn't exist"},
{"select * from t1", "[schema:1146]Table 'test.t1' doesn't exist"},
{"select * from t1, t", "[schema:1146]Table 'test.t1' doesn't exist"},
{"select * from t use index(c)", "[planner:1176]Key 'c' doesn't exist in table 't'"},
}

for _, c := range cases {
for _, scope := range []string{"session", "global"} {
sql := fmt.Sprintf("create %v binding using %v", scope, c.SQL)
tk.MustGetErrMsg(sql, c.Err)
}
}

// cross-db bindings or bindings with parameters can bypass the check, which is expected.
// We'll optimize this check further in the future.
tk.MustExec("create binding using select * from *.t where c=1")
tk.MustExec("create binding using select * from t where c=?")
}
12 changes: 3 additions & 9 deletions pkg/sessionctx/sessionstates/session_states_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1293,9 +1293,7 @@ func TestSQLBinding(t *testing.T) {
tk.MustExec("drop table test.t1")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
require.Equal(t, 1, len(tk.MustQuery("show session bindings").Rows()))
},
restoreErr: errno.ErrNoSuchTable,
cleanFunc: func(tk *testkit.TestKit) {
tk.MustExec("create table test.t1(id int primary key, name varchar(10), key(name))")
},
Expand All @@ -1310,9 +1308,7 @@ func TestSQLBinding(t *testing.T) {
tk.MustExec("drop database test1")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
require.Equal(t, 1, len(tk.MustQuery("show session bindings").Rows()))
},
restoreErr: errno.ErrNoSuchTable,
},
{
// alter the table
Expand All @@ -1321,9 +1317,7 @@ func TestSQLBinding(t *testing.T) {
tk.MustExec("alter table test.t1 drop index name")
return nil
},
checkFunc: func(tk *testkit.TestKit, param any) {
require.Equal(t, 1, len(tk.MustQuery("show session bindings").Rows()))
},
restoreErr: errno.ErrKeyDoesNotExist,
cleanFunc: func(tk *testkit.TestKit) {
tk.MustExec("alter table test.t1 add index name(name)")
},
Expand Down