Skip to content

Commit 6b45e27

Browse files
authored
expression: fix the length of casting from INT/REAL/DECIMAL/.... to string (#62330)
close #61350
1 parent 9e1066b commit 6b45e27

File tree

12 files changed

+364
-59
lines changed

12 files changed

+364
-59
lines changed

pkg/expression/builtin_cast.go

Lines changed: 188 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
3939
"github.com/pingcap/tidb/pkg/types"
4040
"github.com/pingcap/tidb/pkg/util/chunk"
41+
"github.com/pingcap/tidb/pkg/util/intest"
4142
"github.com/pingcap/tipb/go-tipb"
4243
)
4344

@@ -115,6 +116,18 @@ var (
115116
_ builtinFunc = &builtinCastVectorFloat32AsUnsupportedSig{}
116117
)
117118

119+
const (
120+
maxTinyBlobSize = 255
121+
maxBlobSize = 65535
122+
maxMediumBlobSize = 16777215
123+
maxLongBlobSize = 4294967295
124+
// These two are magic numbers to be compatible with MySQL.
125+
// They are `MaxBlobSize * 4` and `MaxMediumBlobSize * 4`, but multiply by 4 (mblen) is not necessary here. However
126+
// a bigger number is always safer to avoid truncation, so they are kept as is.
127+
castBlobFlen = maxBlobSize * 4
128+
castMediumBlobFlen = maxMediumBlobSize * 4
129+
)
130+
118131
type castAsIntFunctionClass struct {
119132
baseFunctionClass
120133

@@ -315,26 +328,11 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
315328
tp.AddFlag(mysql.BinaryFlag)
316329
args[0] = BuildCastFunction(ctx, args[0], tp)
317330
}
318-
argTp := args[0].GetType(ctx.GetEvalCtx()).EvalType()
319-
switch argTp {
331+
argFt := args[0].GetType(ctx.GetEvalCtx())
332+
adjustRetFtForCastString(bf.tp, argFt)
333+
334+
switch argFt.EvalType() {
320335
case types.ETInt:
321-
if bf.tp.GetFlen() == types.UnspecifiedLength {
322-
// check https://github.com/pingcap/tidb/issues/44786
323-
// set flen from integers may truncate integers, e.g. char(1) can not display -1[int(1)]
324-
switch args[0].GetType(ctx.GetEvalCtx()).GetType() {
325-
case mysql.TypeTiny:
326-
bf.tp.SetFlen(4)
327-
case mysql.TypeShort:
328-
bf.tp.SetFlen(6)
329-
case mysql.TypeInt24:
330-
bf.tp.SetFlen(9)
331-
case mysql.TypeLong:
332-
// set it to 11 as mysql
333-
bf.tp.SetFlen(11)
334-
default:
335-
bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen())
336-
}
337-
}
338336
sig = &builtinCastIntAsStringSig{bf}
339337
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsString)
340338
case types.ETReal:
@@ -362,11 +360,156 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
362360
sig = &builtinCastStringAsStringSig{bf}
363361
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString)
364362
default:
365-
return nil, errors.Errorf("cannot cast from %s to %s", argTp, "String")
363+
return nil, errors.Errorf("cannot cast from %s to %s", argFt.EvalType(), "String")
366364
}
367365
return sig, nil
368366
}
369367

368+
func adjustRetFtForCastString(retFt, argFt *types.FieldType) {
369+
originalFlen := retFt.GetFlen()
370+
371+
// Only estimate the length for variable length string types, because different length for fixed
372+
// length string types will have different behaviors and may cause compatibility issues.
373+
if retFt.GetType() == mysql.TypeString {
374+
return
375+
}
376+
377+
if argFt.GetType() == mysql.TypeNull {
378+
return
379+
}
380+
381+
argTp := argFt.EvalType()
382+
switch argTp {
383+
case types.ETInt:
384+
if originalFlen == types.UnspecifiedLength {
385+
// check https://github.com/pingcap/tidb/issues/44786
386+
// set flen from integers may truncate integers, e.g. char(1) can not display -1[int(1)]
387+
switch argFt.GetType() {
388+
case mysql.TypeTiny:
389+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
390+
retFt.SetFlen(3)
391+
} else {
392+
retFt.SetFlen(4)
393+
}
394+
case mysql.TypeShort:
395+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
396+
retFt.SetFlen(5)
397+
} else {
398+
retFt.SetFlen(6)
399+
}
400+
case mysql.TypeInt24:
401+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
402+
retFt.SetFlen(8)
403+
} else {
404+
retFt.SetFlen(9)
405+
}
406+
case mysql.TypeLong:
407+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
408+
retFt.SetFlen(10)
409+
} else {
410+
retFt.SetFlen(11)
411+
}
412+
case mysql.TypeLonglong:
413+
// the length of BIGINT is always 20 without considering the unsigned flag, because the
414+
// bigint range from -9223372036854775808 to 9223372036854775807, and unsigned bigint range
415+
// from 0 to 18446744073709551615, they are all 20 characters long.
416+
retFt.SetFlen(20)
417+
case mysql.TypeYear:
418+
retFt.SetFlen(4)
419+
case mysql.TypeBit:
420+
retFt.SetFlen(argFt.GetFlen())
421+
case mysql.TypeEnum:
422+
intest.Assert(false, "cast Enum to String should not set mysql.EnumSetAsIntFlag")
423+
return
424+
case mysql.TypeSet:
425+
intest.Assert(false, "cast Set to String should not set mysql.EnumSetAsIntFlag")
426+
return
427+
default:
428+
intest.Assert(false, "unknown type %d for INT", argFt.GetType())
429+
return
430+
}
431+
}
432+
case types.ETReal:
433+
// MySQL used 12/22 for float/double, it's because MySQL turns float/double into scientific notation
434+
// in some situations. TiDB choose to use 'f' format for all the cases, so TiDB needs much longer length
435+
// for float/double.
436+
//
437+
// The largest float/double value is around `3.40e38`/`1.79e308`, and the smallest positive float/double value
438+
// is around `1.40e-45`/`4.94e-324`. Therefore, we need at least `1 (sign) + 1 (integer) + 1 (dot) + (45 + 39) (decimal) = 87`
439+
// for float and `1 (sign) + 1 (integer) + 1 (dot) + (324 + 43) (decimal) = 370` for double.
440+
//
441+
// Actually, the golang will usually generate a much smaller string. It used ryu algorithm to generate the shortest
442+
// decimal representation. It's not necessary to keep all decimals. Ref:
443+
// - https://github.com/ulfjack/ryu
444+
// - https://dl.acm.org/doi/10.1145/93548.93559
445+
// So maybe 48/327 is enough for float/double, but we still set 87/370 for safety.
446+
if originalFlen == types.UnspecifiedLength {
447+
if argFt.GetType() == mysql.TypeFloat {
448+
retFt.SetFlen(87)
449+
} else if argFt.GetType() == mysql.TypeDouble {
450+
retFt.SetFlen(370)
451+
}
452+
}
453+
case types.ETDecimal:
454+
if originalFlen == types.UnspecifiedLength {
455+
retFt.SetFlen(decimalPrecisionToLength(argFt))
456+
}
457+
case types.ETDatetime, types.ETTimestamp:
458+
if originalFlen == types.UnspecifiedLength {
459+
if argFt.GetType() == mysql.TypeDate {
460+
retFt.SetFlen(mysql.MaxDateWidth)
461+
} else {
462+
retFt.SetFlen(mysql.MaxDatetimeWidthNoFsp)
463+
}
464+
465+
// Theoretically, the decimal of `DATE` will never be greater than 0.
466+
decimal := argFt.GetDecimal()
467+
if decimal > 0 {
468+
// If the type is datetime or timestamp with fractional seconds, we need to set the length to
469+
// accommodate the fractional seconds part.
470+
retFt.SetFlen(retFt.GetFlen() + 1 + decimal)
471+
}
472+
}
473+
case types.ETDuration:
474+
if originalFlen == types.UnspecifiedLength {
475+
retFt.SetFlen(mysql.MaxDurationWidthNoFsp)
476+
decimal := argFt.GetDecimal()
477+
if decimal > 0 {
478+
// If the type is time with fractional seconds, we need to set the length to
479+
// accommodate the fractional seconds part.
480+
retFt.SetFlen(retFt.GetFlen() + 1 + decimal)
481+
}
482+
}
483+
case types.ETJson:
484+
if originalFlen == types.UnspecifiedLength {
485+
retFt.SetFlen(mysql.MaxLongBlobWidth)
486+
retFt.SetType(mysql.TypeLongBlob)
487+
}
488+
case types.ETVectorFloat32:
489+
490+
case types.ETString:
491+
if originalFlen == types.UnspecifiedLength {
492+
switch argFt.GetType() {
493+
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
494+
if argFt.GetFlen() > 0 {
495+
retFt.SetFlen(argFt.GetFlen())
496+
}
497+
case mysql.TypeTinyBlob:
498+
retFt.SetFlen(maxTinyBlobSize)
499+
case mysql.TypeBlob:
500+
retFt.SetFlen(castBlobFlen)
501+
case mysql.TypeMediumBlob:
502+
retFt.SetFlen(castMediumBlobFlen)
503+
case mysql.TypeLongBlob:
504+
retFt.SetFlen(maxLongBlobSize)
505+
default:
506+
intest.Assert(false, "unknown type %d for String", argFt.GetType())
507+
return
508+
}
509+
}
510+
}
511+
}
512+
370513
type castAsTimeFunctionClass struct {
371514
baseFunctionClass
372515

@@ -1248,6 +1391,7 @@ func (b *builtinCastRealAsStringSig) evalString(ctx EvalContext, row chunk.Row)
12481391
// If we strconv.FormatFloat the value with 64bits, the result is incorrect!
12491392
bits = 32
12501393
}
1394+
12511395
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, typeCtx(ctx), false)
12521396
if err != nil {
12531397
return res, false, err
@@ -2669,3 +2813,27 @@ func TryPushCastIntoControlFunctionForHybridType(ctx BuildContext, expr Expressi
26692813
}
26702814
return expr
26712815
}
2816+
2817+
func decimalPrecisionToLength(ft *types.FieldType) int {
2818+
precision := ft.GetFlen()
2819+
scale := ft.GetDecimal()
2820+
unsigned := mysql.HasUnsignedFlag(ft.GetFlag())
2821+
2822+
if precision == types.UnspecifiedLength || scale == types.UnspecifiedLength {
2823+
return types.UnspecifiedLength
2824+
}
2825+
2826+
ret := precision
2827+
if scale > 0 {
2828+
ret++
2829+
}
2830+
2831+
if !unsigned && precision > 0 {
2832+
ret++ // for negative sign
2833+
}
2834+
2835+
if ret == 0 {
2836+
return 1
2837+
}
2838+
return ret
2839+
}

0 commit comments

Comments
 (0)