Skip to content

Commit 38ab23b

Browse files
authored
expression: use the correct type when eval decimal and float session var (#51395)
close #43527
1 parent 42c8d2d commit 38ab23b

File tree

5 files changed

+76
-5
lines changed

5 files changed

+76
-5
lines changed

pkg/expression/builtin_other.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,11 @@ func (b *builtinGetRealVarSig) evalReal(ctx EvalContext, row chunk.Row) (float64
10521052
}
10531053
varName = strings.ToLower(varName)
10541054
if v, ok := sessionVars.GetUserVarVal(varName); ok {
1055-
return v.GetFloat64(), false, nil
1055+
d, err := v.ToFloat64(typeCtx(ctx))
1056+
if err != nil {
1057+
return 0, false, err
1058+
}
1059+
return d, false, nil
10561060
}
10571061
return 0, true, nil
10581062
}
@@ -1092,7 +1096,11 @@ func (b *builtinGetDecimalVarSig) evalDecimal(ctx EvalContext, row chunk.Row) (*
10921096
}
10931097
varName = strings.ToLower(varName)
10941098
if v, ok := sessionVars.GetUserVarVal(varName); ok {
1095-
return v.GetMysqlDecimal(), false, nil
1099+
d, err := v.ToDecimal(typeCtx(ctx))
1100+
if err != nil {
1101+
return nil, false, err
1102+
}
1103+
return d, false, nil
10961104
}
10971105
return nil, true, nil
10981106
}

pkg/expression/builtin_other_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,33 @@ func TestGetVar(t *testing.T) {
169169
}
170170
}
171171

172+
func TestTypeConversion(t *testing.T) {
173+
ctx := createContext(t)
174+
// Set value as int64
175+
key := "a"
176+
val := int64(3)
177+
ctx.GetSessionVars().SetUserVarVal(key, types.NewDatum(val))
178+
tp := types.NewFieldType(mysql.TypeLonglong)
179+
ctx.GetSessionVars().SetUserVarType(key, tp)
180+
181+
args := []any{"a"}
182+
// To Decimal.
183+
tp = types.NewFieldType(mysql.TypeNewDecimal)
184+
fn, err := BuildGetVarFunction(ctx, datumsToConstants(types.MakeDatums(args...))[0], tp)
185+
require.NoError(t, err)
186+
d, err := fn.Eval(ctx, chunk.Row{})
187+
require.NoError(t, err)
188+
des := types.NewDecFromInt(3)
189+
require.Equal(t, des, d.GetValue())
190+
// To Float.
191+
tp = types.NewFieldType(mysql.TypeDouble)
192+
fn, err = BuildGetVarFunction(ctx, datumsToConstants(types.MakeDatums(args...))[0], tp)
193+
require.NoError(t, err)
194+
d, err = fn.Eval(ctx, chunk.Row{})
195+
require.NoError(t, err)
196+
require.Equal(t, float64(3), d.GetValue())
197+
}
198+
172199
func TestValues(t *testing.T) {
173200
ctx := createContext(t)
174201
fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)}

pkg/expression/builtin_other_vec.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,8 @@ func (b *builtinGetRealVarSig) vectorized() bool {
386386
return true
387387
}
388388

389+
// NOTE: get/set variable vectorized eval was disabled. See more in
390+
// https://github.com/pingcap/tidb/pull/8412
389391
func (b *builtinGetRealVarSig) vecEvalReal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
390392
n := input.NumRows()
391393
buf0, err := b.bufAllocator.get()
@@ -406,7 +408,11 @@ func (b *builtinGetRealVarSig) vecEvalReal(ctx EvalContext, input *chunk.Chunk,
406408
}
407409
varName := strings.ToLower(buf0.GetString(i))
408410
if v, ok := sessionVars.GetUserVarVal(varName); ok {
409-
f64s[i] = v.GetFloat64()
411+
d, err := v.ToFloat64(typeCtx(ctx))
412+
if err != nil {
413+
return err
414+
}
415+
f64s[i] = d
410416
continue
411417
}
412418
result.SetNull(i, true)
@@ -418,6 +424,8 @@ func (b *builtinGetDecimalVarSig) vectorized() bool {
418424
return true
419425
}
420426

427+
// NOTE: get/set variable vectorized eval was disabled. See more in
428+
// https://github.com/pingcap/tidb/pull/8412
421429
func (b *builtinGetDecimalVarSig) vecEvalDecimal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
422430
n := input.NumRows()
423431
buf0, err := b.bufAllocator.get()
@@ -438,7 +446,11 @@ func (b *builtinGetDecimalVarSig) vecEvalDecimal(ctx EvalContext, input *chunk.C
438446
}
439447
varName := strings.ToLower(buf0.GetString(i))
440448
if v, ok := sessionVars.GetUserVarVal(varName); ok {
441-
decs[i] = *v.GetMysqlDecimal()
449+
d, err := v.ToDecimal(typeCtx(ctx))
450+
if err != nil {
451+
return err
452+
}
453+
decs[i] = *d
442454
continue
443455
}
444456
result.SetNull(i, true)

pkg/expression/integration_test/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ go_test(
88
"main_test.go",
99
],
1010
flaky = True,
11-
shard_count = 24,
11+
shard_count = 25,
1212
deps = [
1313
"//pkg/config",
1414
"//pkg/domain",

pkg/expression/integration_test/integration_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,3 +2976,27 @@ func TestTiDBRowChecksumBuiltin(t *testing.T) {
29762976
tk.MustGetDBError("select tidb_row_checksum() from t", expression.ErrNotSupportedYet)
29772977
tk.MustGetDBError("select tidb_row_checksum() from t where id > 0", expression.ErrNotSupportedYet)
29782978
}
2979+
2980+
func TestIssue43527(t *testing.T) {
2981+
store := testkit.CreateMockStore(t)
2982+
tk := testkit.NewTestKit(t, store)
2983+
tk.MustExec("use test")
2984+
tk.MustExec("create table test (a datetime, b bigint, c decimal(10, 2), d float)")
2985+
tk.MustExec("insert into test values('2010-10-10 10:10:10', 100, 100.01, 100)")
2986+
// Decimal.
2987+
tk.MustQuery(
2988+
"SELECT @total := @total + c FROM (SELECT c FROM test) AS temp, (SELECT @total := 200) AS T1",
2989+
).Check(testkit.Rows("300.01"))
2990+
// Float.
2991+
tk.MustQuery(
2992+
"SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := 200) AS T1",
2993+
).Check(testkit.Rows("300"))
2994+
tk.MustExec("insert into test values('2010-10-10 10:10:10', 100, 100.01, 100)")
2995+
// Vectorized.
2996+
// NOTE: Because https://github.com/pingcap/tidb/pull/8412 disabled the vectorized execution of get or set variable,
2997+
// the following test case will not be executed in vectorized mode.
2998+
// It will be executed in the normal mode.
2999+
tk.MustQuery(
3000+
"SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := b FROM test) AS T1 where @total >= 100",
3001+
).Check(testkit.Rows("200", "300", "400", "500"))
3002+
}

0 commit comments

Comments
 (0)