Skip to content

Commit 2d6302d

Browse files
authored
extension: disable some optimizations for extension function (#51926) (#51946)
close #51925
1 parent 9fd6e81 commit 2d6302d

File tree

5 files changed

+133
-4
lines changed

5 files changed

+133
-4
lines changed

pkg/expression/constant_fold.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ func foldConstant(expr Expression) (Expression, bool) {
156156
if _, ok := unFoldableFunctions[x.FuncName.L]; ok {
157157
return expr, false
158158
}
159+
if _, ok := x.Function.(*extensionFuncSig); ok {
160+
// we should not fold the extension function, because it may have a side effect.
161+
return expr, false
162+
}
159163
if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(x.GetCtx(), []Expression{expr}) {
160164
return function(x)
161165
}

pkg/expression/extension.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func newExtensionFuncClass(def *extension.FunctionDef) (*extensionFuncClass, err
9797
}
9898

9999
func (c *extensionFuncClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
100-
if err := c.checkPrivileges(ctx); err != nil {
100+
if err := checkPrivileges(ctx, &c.funcDef); err != nil {
101101
return nil, err
102102
}
103103

@@ -108,13 +108,18 @@ func (c *extensionFuncClass) getFunction(ctx sessionctx.Context, args []Expressi
108108
if err != nil {
109109
return nil, err
110110
}
111+
112+
// Though currently, `getFunction` does not require too much information that makes it safe to be cached,
113+
// we still skip the plan cache for extension functions because there are no strong requirements to do it.
114+
// Skipping the plan cache can make the behavior simple.
115+
ctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.NewNoStackError("extension function should not be cached"))
111116
bf.tp.SetFlen(c.flen)
112117
sig := &extensionFuncSig{context.TODO(), bf, c.funcDef}
113118
return sig, nil
114119
}
115120

116-
func (c *extensionFuncClass) checkPrivileges(ctx sessionctx.Context) error {
117-
fn := c.funcDef.RequireDynamicPrivileges
121+
func checkPrivileges(ctx sessionctx.Context, fnDef *extension.FunctionDef) error {
122+
fn := fnDef.RequireDynamicPrivileges
118123
if fn == nil {
119124
return nil
120125
}
@@ -157,13 +162,21 @@ func (b *extensionFuncSig) Clone() builtinFunc {
157162
}
158163

159164
func (b *extensionFuncSig) evalString(row chunk.Row) (string, bool, error) {
165+
if err := checkPrivileges(b.ctx, &b.FunctionDef); err != nil {
166+
return "", true, err
167+
}
168+
160169
if b.EvalTp == types.ETString {
161170
return b.EvalStringFunc(b, row)
162171
}
163172
return b.baseBuiltinFunc.evalString(row)
164173
}
165174

166175
func (b *extensionFuncSig) evalInt(row chunk.Row) (int64, bool, error) {
176+
if err := checkPrivileges(b.ctx, &b.FunctionDef); err != nil {
177+
return 0, true, err
178+
}
179+
167180
if b.EvalTp == types.ETInt {
168181
return b.EvalIntFunc(b, row)
169182
}

pkg/expression/scalar_function.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ func (sf *ScalarFunction) ConstItem(sc *stmtctx.StatementContext) bool {
366366
if _, ok := unFoldableFunctions[sf.FuncName.L]; ok {
367367
return false
368368
}
369+
370+
if _, ok := sf.Function.(*extensionFuncSig); ok {
371+
// we should return false for extension functions for safety, because it may have a side effect.
372+
return false
373+
}
374+
369375
for _, arg := range sf.GetArgs() {
370376
if !arg.ConstItem(sc) {
371377
return false

pkg/extension/BUILD.bazel

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ go_test(
3939
],
4040
embed = [":extension"],
4141
flaky = True,
42-
shard_count = 14,
42+
shard_count = 15,
4343
deps = [
4444
"//pkg/expression",
4545
"//pkg/parser/ast",
@@ -55,6 +55,7 @@ go_test(
5555
"//pkg/testkit/testsetup",
5656
"//pkg/types",
5757
"//pkg/util/chunk",
58+
"//pkg/util/mock",
5859
"//pkg/util/sem",
5960
"@com_github_pingcap_errors//:errors",
6061
"@com_github_stretchr_testify//require",

pkg/extension/function_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"fmt"
1919
"sort"
2020
"strings"
21+
"sync/atomic"
2122
"testing"
2223

2324
"github.com/pingcap/errors"
@@ -28,6 +29,7 @@ import (
2829
"github.com/pingcap/tidb/pkg/testkit"
2930
"github.com/pingcap/tidb/pkg/types"
3031
"github.com/pingcap/tidb/pkg/util/chunk"
32+
"github.com/pingcap/tidb/pkg/util/mock"
3133
"github.com/pingcap/tidb/pkg/util/sem"
3234
"github.com/stretchr/testify/require"
3335
)
@@ -318,6 +320,19 @@ func TestExtensionFuncPrivilege(t *testing.T) {
318320
return "ghi", false, nil
319321
},
320322
},
323+
{
324+
Name: "custom_eval_int_func",
325+
EvalTp: types.ETInt,
326+
RequireDynamicPrivileges: func(sem bool) []string {
327+
if sem {
328+
return []string{"RESTRICTED_CUSTOM_DYN_PRIV_2"}
329+
}
330+
return []string{"CUSTOM_DYN_PRIV_1"}
331+
},
332+
EvalIntFunc: func(ctx extension.FunctionContext, row chunk.Row) (int64, bool, error) {
333+
return 1, false, nil
334+
},
335+
},
321336
}),
322337
extension.WithCustomDynPrivs([]string{
323338
"CUSTOM_DYN_PRIV_1",
@@ -349,34 +364,43 @@ func TestExtensionFuncPrivilege(t *testing.T) {
349364
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
350365
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
351366
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
367+
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))
352368

353369
// u1 in non-sem
354370
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil, nil))
355371
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
356372
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
357373
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
358374
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
375+
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
376+
377+
// prepare should check privilege
378+
require.EqualError(t, tk1.ExecToErr("prepare stmt1 from 'select custom_only_dyn_priv_func()'"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
379+
require.EqualError(t, tk1.ExecToErr("prepare stmt2 from 'select custom_eval_int_func()'"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
359380

360381
// u2 in non-sem
361382
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil, nil))
362383
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
363384
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
364385
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
365386
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
387+
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))
366388

367389
// u3 in non-sem
368390
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil, nil))
369391
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
370392
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
371393
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
372394
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
395+
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the SUPER or CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
373396

374397
// u4 in non-sem
375398
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil, nil))
376399
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
377400
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
378401
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
379402
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
403+
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))
380404

381405
sem.Enable()
382406

@@ -386,32 +410,113 @@ func TestExtensionFuncPrivilege(t *testing.T) {
386410
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
387411
require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
388412
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
413+
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
389414

390415
// u1 in sem
391416
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil, nil))
392417
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
393418
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
394419
require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
395420
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
421+
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
396422

397423
// u2 in sem
398424
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u2", Hostname: "localhost"}, nil, nil, nil))
399425
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
400426
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
401427
require.EqualError(t, tk1.ExecToErr("select custom_only_sem_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
402428
require.EqualError(t, tk1.ExecToErr("select custom_both_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
429+
require.EqualError(t, tk1.ExecToErr("select custom_eval_int_func()"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
403430

404431
// u3 in sem
405432
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u3", Hostname: "localhost"}, nil, nil, nil))
406433
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
407434
require.EqualError(t, tk1.ExecToErr("select custom_only_dyn_priv_func()"), "[expression:1227]Access denied; you need (at least one of) the CUSTOM_DYN_PRIV_1 privilege(s) for this operation")
408435
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
409436
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
437+
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))
410438

411439
// u4 in sem
412440
require.NoError(t, tk1.Session().Auth(&auth.UserIdentity{Username: "u4", Hostname: "localhost"}, nil, nil, nil))
413441
tk1.MustQuery("select custom_no_priv_func()").Check(testkit.Rows("zzz"))
414442
tk1.MustQuery("select custom_only_dyn_priv_func()").Check(testkit.Rows("abc"))
415443
tk1.MustQuery("select custom_only_sem_dyn_priv_func()").Check(testkit.Rows("def"))
416444
tk1.MustQuery("select custom_both_dyn_priv_func()").Check(testkit.Rows("ghi"))
445+
tk1.MustQuery("select custom_eval_int_func()").Check(testkit.Rows("1"))
446+
447+
// Test the privilege should also be checked when evaluating especially for when privilege is revoked.
448+
tk1.MustExec("prepare s1 from 'select custom_both_dyn_priv_func()'")
449+
tk1.MustExec("prepare s2 from 'select custom_eval_int_func()'")
450+
tk1.MustQuery("execute s1").Check(testkit.Rows("ghi"))
451+
tk1.MustQuery("execute s2").Check(testkit.Rows("1"))
452+
tk.MustExec("REVOKE RESTRICTED_CUSTOM_DYN_PRIV_2 on *.* FROM u4@localhost")
453+
require.EqualError(t, tk1.ExecToErr("execute s1"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
454+
require.EqualError(t, tk1.ExecToErr("execute s2"), "[expression:1227]Access denied; you need (at least one of) the RESTRICTED_CUSTOM_DYN_PRIV_2 privilege(s) for this operation")
455+
}
456+
457+
func TestShouldNotOptimizeExtensionFunc(t *testing.T) {
458+
defer func() {
459+
extension.Reset()
460+
sem.Disable()
461+
}()
462+
463+
extension.Reset()
464+
var cnt atomic.Int64
465+
require.NoError(t, extension.Register("test",
466+
extension.WithCustomFunctions([]*extension.FunctionDef{
467+
{
468+
Name: "my_func1",
469+
EvalTp: types.ETInt,
470+
EvalIntFunc: func(ctx extension.FunctionContext, row chunk.Row) (int64, bool, error) {
471+
val := cnt.Add(1)
472+
return val, false, nil
473+
},
474+
},
475+
{
476+
Name: "my_func2",
477+
EvalTp: types.ETString,
478+
EvalStringFunc: func(ctx extension.FunctionContext, row chunk.Row) (string, bool, error) {
479+
val := cnt.Add(1)
480+
if val%2 == 0 {
481+
return "abc", false, nil
482+
}
483+
return "def", false, nil
484+
},
485+
},
486+
}),
487+
))
488+
489+
store := testkit.CreateMockStore(t)
490+
tk := testkit.NewTestKit(t, store)
491+
tk.MustExec("use test")
492+
tk.MustExec("create table t1(a int primary key)")
493+
tk.MustExec("insert into t1 values(1000), (2000)")
494+
495+
// Test extension function should not fold.
496+
// if my_func1 is folded, the result will be "1000 1", "2000 1",
497+
// because after fold the function will be called only once.
498+
tk.MustQuery("select a, my_func1() from t1 order by a").Check(testkit.Rows("1000 1", "2000 2"))
499+
require.Equal(t, int64(2), cnt.Load())
500+
501+
// Test extension function should not be seen as a constant, i.e., its `ConstantLevel()` should return `ConstNone`.
502+
// my_func2 should be called twice to return different regexp string for the below query.
503+
// If it is optimized by mistake, a wrong result "1000 0", "2000 0" will be produced.
504+
cnt.Store(0)
505+
tk.MustQuery("select a, 'abc' regexp my_func2() from t1 order by a").Check(testkit.Rows("1000 0", "2000 1"))
506+
507+
// Test flags after building expression
508+
for _, exprStr := range []string{
509+
"my_func1()",
510+
"my_func2()",
511+
} {
512+
ctx := mock.NewContext()
513+
ctx.GetSessionVars().StmtCtx.UseCache = true
514+
exprs, err := expression.ParseSimpleExprsWithNames(ctx, exprStr, nil, nil)
515+
require.NoError(t, err)
516+
require.Equal(t, 1, len(exprs))
517+
scalar, ok := exprs[0].(*expression.ScalarFunction)
518+
require.True(t, ok)
519+
require.False(t, scalar.ConstItem(ctx.GetSessionVars().StmtCtx))
520+
require.False(t, ctx.GetSessionVars().StmtCtx.UseCache)
521+
}
417522
}

0 commit comments

Comments
 (0)