Skip to content

Commit 311eef9

Browse files
authored
expression: introduce SessionEvalContext to implement EvalContext (#52091)
close #52089
1 parent b6dd179 commit 311eef9

File tree

9 files changed

+66
-87
lines changed

9 files changed

+66
-87
lines changed

pkg/executor/BUILD.bazel

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ go_library(
133133
"//pkg/expression",
134134
"//pkg/expression/aggregation",
135135
"//pkg/expression/context",
136-
"//pkg/expression/contextimpl",
137136
"//pkg/infoschema",
138137
"//pkg/keyspace",
139138
"//pkg/kv",
@@ -397,7 +396,6 @@ go_test(
397396
"//pkg/executor/sortexec",
398397
"//pkg/expression",
399398
"//pkg/expression/aggregation",
400-
"//pkg/expression/contextimpl",
401399
"//pkg/infoschema",
402400
"//pkg/kv",
403401
"//pkg/meta",

pkg/executor/cluster_table_test.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import (
2727
"github.com/pingcap/tidb/pkg/config"
2828
"github.com/pingcap/tidb/pkg/domain"
2929
"github.com/pingcap/tidb/pkg/expression"
30-
"github.com/pingcap/tidb/pkg/expression/contextimpl"
3130
"github.com/pingcap/tidb/pkg/parser"
3231
"github.com/pingcap/tidb/pkg/parser/auth"
3332
"github.com/pingcap/tidb/pkg/parser/mysql"
@@ -312,9 +311,7 @@ func TestSQLDigestTextRetriever(t *testing.T) {
312311
},
313312
}
314313

315-
sqlExec, err := contextimpl.NewSQLExecutor(tk.Session())
316-
require.NoError(t, err)
317-
err = r.RetrieveLocal(context.Background(), sqlExec)
314+
err := r.RetrieveLocal(context.Background(), tk.Session().GetRestrictedSQLExecutor())
318315
require.NoError(t, err)
319316
require.Equal(t, insertNormalized, r.SQLDigestsMap[insertDigest.String()])
320317
require.Equal(t, "", r.SQLDigestsMap[updateDigest.String()])

pkg/executor/infoschema_reader.go

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ import (
4040
"github.com/pingcap/tidb/pkg/executor/internal/exec"
4141
"github.com/pingcap/tidb/pkg/executor/internal/pdhelper"
4242
"github.com/pingcap/tidb/pkg/expression"
43-
"github.com/pingcap/tidb/pkg/expression/contextimpl"
4443
"github.com/pingcap/tidb/pkg/infoschema"
4544
"github.com/pingcap/tidb/pkg/kv"
4645
"github.com/pingcap/tidb/pkg/meta/autoid"
@@ -2718,11 +2717,9 @@ func (e *tidbTrxTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Co
27182717
e.batchRetrieverHelper.batchSize = 1024
27192718
}
27202719

2721-
sqlExec, err := contextimpl.NewSQLExecutor(sctx)
2722-
if err != nil {
2723-
return nil, err
2724-
}
2720+
sqlExec := sctx.GetRestrictedSQLExecutor()
27252721

2722+
var err error
27262723
// The current TiDB node's address is needed by the CLUSTER_TIDB_TRX table.
27272724
var instanceAddr string
27282725
if e.table.Name.O == infoschema.ClusterTableTiDBTrx {
@@ -2871,12 +2868,7 @@ func (r *dataLockWaitsTableRetriever) retrieve(ctx context.Context, sctx session
28712868
}
28722869
}
28732870

2874-
sqlExec, err := contextimpl.NewSQLExecutor(sctx)
2875-
if err != nil {
2876-
return errors.Trace(err)
2877-
}
2878-
2879-
err = sqlRetriever.RetrieveGlobal(ctx, sqlExec)
2871+
err := sqlRetriever.RetrieveGlobal(ctx, sctx.GetRestrictedSQLExecutor())
28802872
if err != nil {
28812873
return errors.Trace(err)
28822874
}
@@ -3074,11 +3066,7 @@ func (r *deadlocksTableRetriever) retrieve(ctx context.Context, sctx sessionctx.
30743066
}
30753067
// Retrieve the SQL texts if necessary.
30763068
if sqlRetriever != nil {
3077-
sqlExec, err := contextimpl.NewSQLExecutor(sctx)
3078-
if err != nil {
3079-
return errors.Trace(err)
3080-
}
3081-
err1 := sqlRetriever.RetrieveGlobal(ctx, sqlExec)
3069+
err1 := sqlRetriever.RetrieveGlobal(ctx, sctx.GetRestrictedSQLExecutor())
30823070
if err1 != nil {
30833071
return errors.Trace(err1)
30843072
}

pkg/expression/contextimpl/BUILD.bazel

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ go_library(
1010
"//pkg/expression/context",
1111
"//pkg/expression/contextopt",
1212
"//pkg/infoschema/context",
13-
"//pkg/parser/ast",
1413
"//pkg/parser/auth",
1514
"//pkg/parser/model",
1615
"//pkg/parser/mysql",
@@ -20,11 +19,8 @@ go_library(
2019
"//pkg/sessionctx/variable",
2120
"//pkg/types",
2221
"//pkg/util",
23-
"//pkg/util/chunk",
2422
"//pkg/util/intest",
2523
"//pkg/util/logutil",
26-
"//pkg/util/sqlexec",
27-
"@com_github_pingcap_errors//:errors",
2824
"@com_github_tikv_client_go_v2//oracle",
2925
"@org_uber_go_zap//:zap",
3026
],

pkg/expression/contextimpl/sessionctx.go

Lines changed: 44 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@ import (
1919
"math"
2020
"time"
2121

22-
"github.com/pingcap/errors"
2322
"github.com/pingcap/tidb/pkg/errctx"
2423
exprctx "github.com/pingcap/tidb/pkg/expression/context"
2524
"github.com/pingcap/tidb/pkg/expression/contextopt"
2625
infoschema "github.com/pingcap/tidb/pkg/infoschema/context"
27-
"github.com/pingcap/tidb/pkg/parser/ast"
2826
"github.com/pingcap/tidb/pkg/parser/auth"
2927
"github.com/pingcap/tidb/pkg/parser/model"
3028
"github.com/pingcap/tidb/pkg/parser/mysql"
@@ -34,10 +32,8 @@ import (
3432
"github.com/pingcap/tidb/pkg/sessionctx/variable"
3533
"github.com/pingcap/tidb/pkg/types"
3634
"github.com/pingcap/tidb/pkg/util"
37-
"github.com/pingcap/tidb/pkg/util/chunk"
3835
"github.com/pingcap/tidb/pkg/util/intest"
3936
"github.com/pingcap/tidb/pkg/util/logutil"
40-
"github.com/pingcap/tidb/pkg/util/sqlexec"
4137
"github.com/tikv/client-go/v2/oracle"
4238
"go.uber.org/zap"
4339
)
@@ -51,92 +47,104 @@ var _ exprctx.BuildContext = struct {
5147

5248
// ExprCtxExtendedImpl extends the sessionctx.Context to implement `expression.BuildContext`
5349
type ExprCtxExtendedImpl struct {
54-
sctx sessionctx.Context
55-
props contextopt.OptionalEvalPropProviders
50+
*SessionEvalContext
5651
}
5752

5853
// NewExprExtendedImpl creates a new ExprCtxExtendedImpl.
5954
func NewExprExtendedImpl(sctx sessionctx.Context) *ExprCtxExtendedImpl {
60-
impl := &ExprCtxExtendedImpl{sctx: sctx}
55+
return &ExprCtxExtendedImpl{
56+
SessionEvalContext: NewSessionEvalContext(sctx),
57+
}
58+
}
59+
60+
// SessionEvalContext implements the `expression.EvalContext` interface to provide evaluation context in session.
61+
type SessionEvalContext struct {
62+
sctx sessionctx.Context
63+
props contextopt.OptionalEvalPropProviders
64+
}
65+
66+
// NewSessionEvalContext creates a new SessionEvalContext.
67+
func NewSessionEvalContext(sctx sessionctx.Context) *SessionEvalContext {
68+
ctx := &SessionEvalContext{sctx: sctx}
6169
// set all optional properties
62-
impl.setOptionalProp(currentUserProp(sctx))
63-
impl.setOptionalProp(contextopt.NewSessionVarsProvider(sctx))
64-
impl.setOptionalProp(infoSchemaProp(sctx))
65-
impl.setOptionalProp(contextopt.KVStorePropProvider(sctx.GetStore))
66-
impl.setOptionalProp(sqlExecutorProp(sctx))
67-
impl.setOptionalProp(sequenceOperatorProp(sctx))
68-
impl.setOptionalProp(contextopt.NewAdvisoryLockPropProvider(sctx))
69-
impl.setOptionalProp(contextopt.DDLOwnerInfoProvider(sctx.IsDDLOwner))
70+
ctx.setOptionalProp(currentUserProp(sctx))
71+
ctx.setOptionalProp(contextopt.NewSessionVarsProvider(sctx))
72+
ctx.setOptionalProp(infoSchemaProp(sctx))
73+
ctx.setOptionalProp(contextopt.KVStorePropProvider(sctx.GetStore))
74+
ctx.setOptionalProp(sqlExecutorProp(sctx))
75+
ctx.setOptionalProp(sequenceOperatorProp(sctx))
76+
ctx.setOptionalProp(contextopt.NewAdvisoryLockPropProvider(sctx))
77+
ctx.setOptionalProp(contextopt.DDLOwnerInfoProvider(sctx.IsDDLOwner))
7078
// When EvalContext is created from a session, it should contain all the optional properties.
71-
intest.Assert(impl.props.PropKeySet().IsFull())
72-
return impl
79+
intest.Assert(ctx.props.PropKeySet().IsFull())
80+
return ctx
7381
}
7482

75-
func (ctx *ExprCtxExtendedImpl) setOptionalProp(prop exprctx.OptionalEvalPropProvider) {
83+
func (ctx *SessionEvalContext) setOptionalProp(prop exprctx.OptionalEvalPropProvider) {
7684
intest.AssertFunc(func() bool {
7785
return !ctx.props.Contains(prop.Desc().Key())
7886
})
7987
ctx.props.Add(prop)
8088
}
8189

8290
// CtxID returns the context id.
83-
func (ctx *ExprCtxExtendedImpl) CtxID() uint64 {
91+
func (ctx *SessionEvalContext) CtxID() uint64 {
8492
return ctx.sctx.GetSessionVars().StmtCtx.CtxID()
8593
}
8694

8795
// SQLMode returns the sql mode
88-
func (ctx *ExprCtxExtendedImpl) SQLMode() mysql.SQLMode {
96+
func (ctx *SessionEvalContext) SQLMode() mysql.SQLMode {
8997
return ctx.sctx.GetSessionVars().SQLMode
9098
}
9199

92100
// TypeCtx returns the types.Context
93-
func (ctx *ExprCtxExtendedImpl) TypeCtx() types.Context {
101+
func (ctx *SessionEvalContext) TypeCtx() types.Context {
94102
return ctx.sctx.GetSessionVars().StmtCtx.TypeCtx()
95103
}
96104

97105
// ErrCtx returns the errctx.Context
98-
func (ctx *ExprCtxExtendedImpl) ErrCtx() errctx.Context {
106+
func (ctx *SessionEvalContext) ErrCtx() errctx.Context {
99107
return ctx.sctx.GetSessionVars().StmtCtx.ErrCtx()
100108
}
101109

102110
// Location returns the timezone info
103-
func (ctx *ExprCtxExtendedImpl) Location() *time.Location {
111+
func (ctx *SessionEvalContext) Location() *time.Location {
104112
tc := ctx.TypeCtx()
105113
return tc.Location()
106114
}
107115

108116
// AppendWarning append warnings to the context.
109-
func (ctx *ExprCtxExtendedImpl) AppendWarning(err error) {
117+
func (ctx *SessionEvalContext) AppendWarning(err error) {
110118
ctx.sctx.GetSessionVars().StmtCtx.AppendWarning(err)
111119
}
112120

113121
// WarningCount gets warning count.
114-
func (ctx *ExprCtxExtendedImpl) WarningCount() int {
122+
func (ctx *SessionEvalContext) WarningCount() int {
115123
return int(ctx.sctx.GetSessionVars().StmtCtx.WarningCount())
116124
}
117125

118126
// TruncateWarnings truncates warnings begin from start and returns the truncated warnings.
119-
func (ctx *ExprCtxExtendedImpl) TruncateWarnings(start int) []stmtctx.SQLWarn {
127+
func (ctx *SessionEvalContext) TruncateWarnings(start int) []stmtctx.SQLWarn {
120128
return ctx.sctx.GetSessionVars().StmtCtx.TruncateWarnings(start)
121129
}
122130

123131
// CurrentDB returns the current database name
124-
func (ctx *ExprCtxExtendedImpl) CurrentDB() string {
132+
func (ctx *SessionEvalContext) CurrentDB() string {
125133
return ctx.sctx.GetSessionVars().CurrentDB
126134
}
127135

128136
// CurrentTime returns the current time
129-
func (ctx *ExprCtxExtendedImpl) CurrentTime() (time.Time, error) {
137+
func (ctx *SessionEvalContext) CurrentTime() (time.Time, error) {
130138
return getStmtTimestamp(ctx.sctx)
131139
}
132140

133141
// GetMaxAllowedPacket returns the value of the 'max_allowed_packet' system variable.
134-
func (ctx *ExprCtxExtendedImpl) GetMaxAllowedPacket() uint64 {
142+
func (ctx *SessionEvalContext) GetMaxAllowedPacket() uint64 {
135143
return ctx.sctx.GetSessionVars().MaxAllowedPacket
136144
}
137145

138146
// GetDefaultWeekFormatMode returns the value of the 'default_week_format' system variable.
139-
func (ctx *ExprCtxExtendedImpl) GetDefaultWeekFormatMode() string {
147+
func (ctx *SessionEvalContext) GetDefaultWeekFormatMode() string {
140148
mode, ok := ctx.sctx.GetSessionVars().GetSystemVar(variable.DefaultWeekFormat)
141149
if !ok || mode == "" {
142150
return "0"
@@ -145,22 +153,22 @@ func (ctx *ExprCtxExtendedImpl) GetDefaultWeekFormatMode() string {
145153
}
146154

147155
// GetDivPrecisionIncrement returns the specified value of DivPrecisionIncrement.
148-
func (ctx *ExprCtxExtendedImpl) GetDivPrecisionIncrement() int {
156+
func (ctx *SessionEvalContext) GetDivPrecisionIncrement() int {
149157
return ctx.sctx.GetSessionVars().GetDivPrecisionIncrement()
150158
}
151159

152160
// GetOptionalPropSet gets the optional property set from context
153-
func (ctx *ExprCtxExtendedImpl) GetOptionalPropSet() exprctx.OptionalEvalPropKeySet {
161+
func (ctx *SessionEvalContext) GetOptionalPropSet() exprctx.OptionalEvalPropKeySet {
154162
return ctx.props.PropKeySet()
155163
}
156164

157165
// GetOptionalPropProvider gets the optional property provider by key
158-
func (ctx *ExprCtxExtendedImpl) GetOptionalPropProvider(key exprctx.OptionalEvalPropKey) (exprctx.OptionalEvalPropProvider, bool) {
166+
func (ctx *SessionEvalContext) GetOptionalPropProvider(key exprctx.OptionalEvalPropKey) (exprctx.OptionalEvalPropProvider, bool) {
159167
return ctx.props.Get(key)
160168
}
161169

162170
// RequestVerification verifies user privilege
163-
func (ctx *ExprCtxExtendedImpl) RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool {
171+
func (ctx *SessionEvalContext) RequestVerification(db, table, column string, priv mysql.PrivilegeType) bool {
164172
checker := privilege.GetPrivilegeManager(ctx.sctx)
165173
if checker == nil {
166174
return true
@@ -169,7 +177,7 @@ func (ctx *ExprCtxExtendedImpl) RequestVerification(db, table, column string, pr
169177
}
170178

171179
// RequestDynamicVerification verifies user privilege for a DYNAMIC privilege.
172-
func (ctx *ExprCtxExtendedImpl) RequestDynamicVerification(privName string, grantable bool) bool {
180+
func (ctx *SessionEvalContext) RequestDynamicVerification(privName string, grantable bool) bool {
173181
checker := privilege.GetPrivilegeManager(ctx.sctx)
174182
if checker == nil {
175183
return true
@@ -223,27 +231,9 @@ func infoSchemaProp(sctx sessionctx.Context) contextopt.InfoSchemaPropProvider {
223231
}
224232
}
225233

226-
type sqlExecutor struct {
227-
exec sqlexec.RestrictedSQLExecutor
228-
}
229-
230-
// NewSQLExecutor creates a new SQLExecutor.
231-
func NewSQLExecutor(sctx sessionctx.Context) (contextopt.SQLExecutor, error) {
232-
if e, ok := sctx.(sqlexec.RestrictedSQLExecutor); ok {
233-
return &sqlExecutor{exec: e}, nil
234-
}
235-
return nil, errors.Errorf("'%T' cannot be casted to sqlexec.RestrictedSQLExecutor", sctx)
236-
}
237-
238-
func (e *sqlExecutor) ExecRestrictedSQL(
239-
ctx context.Context, sql string, args ...any,
240-
) ([]chunk.Row, []*ast.ResultField, error) {
241-
return e.exec.ExecRestrictedSQL(ctx, nil, sql, args...)
242-
}
243-
244234
func sqlExecutorProp(sctx sessionctx.Context) contextopt.SQLExecutorPropProvider {
245235
return func() (contextopt.SQLExecutor, error) {
246-
return NewSQLExecutor(sctx)
236+
return sctx.GetRestrictedSQLExecutor(), nil
247237
}
248238
}
249239

pkg/expression/contextimpl/sessionctx_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535
"github.com/tikv/client-go/v2/oracle"
3636
)
3737

38-
func TestEvalContextImplWithSessionCtx(t *testing.T) {
38+
func TestSessionEvalContextBasic(t *testing.T) {
3939
ctx := mock.NewContext()
4040
vars := ctx.GetSessionVars()
4141
sc := vars.StmtCtx
@@ -93,7 +93,7 @@ func TestEvalContextImplWithSessionCtx(t *testing.T) {
9393
require.Equal(t, "err1", warnings[0].Err.Error())
9494
}
9595

96-
func TestEvalContextImplCurrentTime(t *testing.T) {
96+
func TestSessionEvalContextCurrentTime(t *testing.T) {
9797
ctx := mock.NewContext()
9898
vars := ctx.GetSessionVars()
9999
sc := vars.StmtCtx
@@ -162,7 +162,7 @@ func (m *mockPrivManager) RequestDynamicVerification(
162162
return m.Called(activeRoles, privName, grantable).Bool(0)
163163
}
164164

165-
func TestEvalContextImplPrivilegeCheck(t *testing.T) {
165+
func TestSessionEvalContextPrivilegeCheck(t *testing.T) {
166166
ctx := mock.NewContext()
167167
impl := contextimpl.NewExprExtendedImpl(ctx)
168168
activeRoles := []*auth.RoleIdentity{
@@ -212,7 +212,7 @@ func getProvider[T context.OptionalEvalPropProvider](
212212
return p
213213
}
214214

215-
func TestEvalContextImplWithSessionCtxForOptProps(t *testing.T) {
215+
func TestSessionEvalContextOptProps(t *testing.T) {
216216
ctx := mock.NewContext()
217217
impl := contextimpl.NewExprExtendedImpl(ctx)
218218

pkg/expression/contextopt/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ go_library(
2424
"//pkg/sessionctx/variable",
2525
"//pkg/util/chunk",
2626
"//pkg/util/intest",
27+
"//pkg/util/sqlexec",
2728
"@com_github_pingcap_errors//:errors",
2829
],
2930
)

pkg/expression/contextopt/sqlexec.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,21 @@ import (
2020
exprctx "github.com/pingcap/tidb/pkg/expression/context"
2121
"github.com/pingcap/tidb/pkg/parser/ast"
2222
"github.com/pingcap/tidb/pkg/util/chunk"
23+
"github.com/pingcap/tidb/pkg/util/sqlexec"
2324
)
2425

26+
// SQLExecutor provides a subset of methods in RestrictedSQLExecutor.
27+
var _ SQLExecutor = sqlexec.RestrictedSQLExecutor(nil)
28+
2529
// SQLExecutor is the interface for SQL executing in expression.
26-
// We do not `sqlexec.SQLExecutor` here to avoid to introduce too many dependencies in `sessionctx.Context`
30+
// We do not `sqlexec.SQLExecutor` to limit expression to use specified methods only.
2731
type SQLExecutor interface {
28-
ExecRestrictedSQL(ctx context.Context, sql string, args ...any) ([]chunk.Row, []*ast.ResultField, error)
32+
ExecRestrictedSQL(
33+
ctx context.Context,
34+
opts []sqlexec.OptionFuncAlias,
35+
sql string,
36+
args ...any,
37+
) ([]chunk.Row, []*ast.ResultField, error)
2938
}
3039

3140
// SQLExecutorPropProvider provides the SQLExecutor

0 commit comments

Comments
 (0)