Skip to content

Commit 2dd6b4e

Browse files
committed
fix the length of multiple types
Signed-off-by: Yang Keao <[email protected]>
1 parent c56481e commit 2dd6b4e

File tree

17 files changed

+576
-287
lines changed

17 files changed

+576
-287
lines changed

pkg/expression/builtin_cast.go

Lines changed: 180 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
3838
"github.com/pingcap/tidb/pkg/types"
3939
"github.com/pingcap/tidb/pkg/util/chunk"
40+
"github.com/pingcap/tidb/pkg/util/intest"
4041
"github.com/pingcap/tipb/go-tipb"
4142
)
4243

@@ -114,6 +115,14 @@ var (
114115
_ builtinFunc = &builtinCastVectorFloat32AsUnsupportedSig{}
115116
)
116117

118+
const (
119+
// These two are magic numbers to be compatible with MySQL.
120+
// TODO: at least understand how these values came out. They appears to be `MaxBlobSize * 4` and `MaxMediumBlobSize * 4`
121+
// but I don't know why it needs to multiply by 4. However, the bigger value is always safer, so we use them here.
122+
castBlobFlen = 262140
123+
castMediumBlobFlen = 67108860
124+
)
125+
117126
type castAsIntFunctionClass struct {
118127
baseFunctionClass
119128

@@ -314,26 +323,13 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
314323
tp.AddFlag(mysql.BinaryFlag)
315324
args[0] = BuildCastFunction(ctx, args[0], tp)
316325
}
317-
argTp := args[0].GetType(ctx.GetEvalCtx()).EvalType()
318-
switch argTp {
326+
argFt := args[0].GetType(ctx.GetEvalCtx())
327+
newFlen, newRetTp := estimateLengthForCastString(bf.tp, argFt)
328+
bf.tp.SetFlen(newFlen)
329+
bf.tp.SetType(newRetTp)
330+
331+
switch argFt.EvalType() {
319332
case types.ETInt:
320-
if bf.tp.GetFlen() == types.UnspecifiedLength {
321-
// check https://github.com/pingcap/tidb/issues/44786
322-
// set flen from integers may truncate integers, e.g. char(1) can not display -1[int(1)]
323-
switch args[0].GetType(ctx.GetEvalCtx()).GetType() {
324-
case mysql.TypeTiny:
325-
bf.tp.SetFlen(4)
326-
case mysql.TypeShort:
327-
bf.tp.SetFlen(6)
328-
case mysql.TypeInt24:
329-
bf.tp.SetFlen(9)
330-
case mysql.TypeLong:
331-
// set it to 11 as mysql
332-
bf.tp.SetFlen(11)
333-
default:
334-
bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen())
335-
}
336-
}
337333
sig = &builtinCastIntAsStringSig{bf}
338334
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsString)
339335
case types.ETReal:
@@ -361,11 +357,145 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
361357
sig = &builtinCastStringAsStringSig{bf}
362358
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString)
363359
default:
364-
return nil, errors.Errorf("cannot cast from %s to %s", argTp, "String")
360+
return nil, errors.Errorf("cannot cast from %s to %s", argFt.EvalType(), "String")
365361
}
366362
return sig, nil
367363
}
368364

365+
func estimateLengthForCastString(retFt, argFt *types.FieldType) (newFlen int, retNewTp byte) {
366+
newFlen = retFt.GetFlen()
367+
retNewTp = retFt.GetType()
368+
369+
// Only estimate the length for variable length string types, because different length for fixed
370+
// length string types will have different behaviors and may cause compatibility issues.
371+
if retFt.GetType() == mysql.TypeString {
372+
return
373+
}
374+
375+
if argFt.GetType() == mysql.TypeNull {
376+
return
377+
}
378+
379+
argTp := argFt.EvalType()
380+
switch argTp {
381+
case types.ETInt:
382+
if newFlen == types.UnspecifiedLength {
383+
// check https://github.com/pingcap/tidb/issues/44786
384+
// set flen from integers may truncate integers, e.g. char(1) can not display -1[int(1)]
385+
switch argFt.GetType() {
386+
case mysql.TypeTiny:
387+
newFlen = 4
388+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
389+
newFlen = 3
390+
}
391+
case mysql.TypeShort:
392+
newFlen = 6
393+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
394+
newFlen = 5
395+
}
396+
case mysql.TypeInt24:
397+
newFlen = 9
398+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
399+
newFlen = 8
400+
}
401+
case mysql.TypeLong:
402+
newFlen = 11
403+
if mysql.HasUnsignedFlag(argFt.GetFlag()) {
404+
newFlen = 10
405+
}
406+
case mysql.TypeLonglong:
407+
// the length of BIGINT is always 20 without considering the unsigned flag, because the
408+
// bigint range from -9223372036854775808 to 9223372036854775807, and unsigned bigint range
409+
// from 0 to 18446744073709551615, they are all 20 characters long.
410+
newFlen = 20
411+
case mysql.TypeYear:
412+
newFlen = 4
413+
case mysql.TypeBit:
414+
newFlen = argFt.GetFlen()
415+
case mysql.TypeEnum:
416+
intest.Assert(false, "cast Enum to String should not set mysql.EnumSetAsIntFlag")
417+
return
418+
case mysql.TypeSet:
419+
intest.Assert(false, "cast Set to String should not set mysql.EnumSetAsIntFlag")
420+
return
421+
default:
422+
intest.Assert(false, "unknown type %d for INT", argFt.GetType())
423+
return
424+
}
425+
}
426+
case types.ETReal:
427+
if newFlen == types.UnspecifiedLength {
428+
switch argFt.GetType() {
429+
case mysql.TypeFloat:
430+
newFlen = maxFloatStrLen
431+
case mysql.TypeDouble:
432+
newFlen = maxDoubleStrLen
433+
default:
434+
intest.Assert(false, "unknown type %d for REAL", argFt.GetType())
435+
return
436+
}
437+
}
438+
case types.ETDecimal:
439+
if newFlen == types.UnspecifiedLength {
440+
newFlen = decimalPrecisionToLength(argFt)
441+
}
442+
case types.ETDatetime, types.ETTimestamp:
443+
if newFlen == types.UnspecifiedLength {
444+
newFlen = mysql.MaxDatetimeWidthNoFsp
445+
if argFt.GetType() == mysql.TypeDate {
446+
newFlen = mysql.MaxDateWidth
447+
}
448+
449+
// Theoretically, the decimal of `DATE` will never be greater than 0.
450+
decimal := argFt.GetDecimal()
451+
if decimal > 0 {
452+
// If the type is datetime or timestamp with fractional seconds, we need to set the length to
453+
// accommodate the fractional seconds part.
454+
newFlen += (1 + decimal)
455+
}
456+
}
457+
case types.ETDuration:
458+
if newFlen == types.UnspecifiedLength {
459+
newFlen = mysql.MaxDurationWidthNoFsp
460+
decimal := argFt.GetDecimal()
461+
if decimal > 0 {
462+
// If the type is time with fractional seconds, we need to set the length to
463+
// accommodate the fractional seconds part.
464+
newFlen += 1 + decimal
465+
}
466+
}
467+
case types.ETJson:
468+
if newFlen == types.UnspecifiedLength {
469+
newFlen = mysql.MaxLongBlobWidth
470+
retNewTp = mysql.TypeLongBlob
471+
}
472+
case types.ETVectorFloat32:
473+
474+
case types.ETString:
475+
if newFlen == types.UnspecifiedLength {
476+
switch argFt.GetType() {
477+
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
478+
if argFt.GetFlen() > 0 {
479+
newFlen = argFt.GetFlen()
480+
}
481+
case mysql.TypeTinyBlob:
482+
newFlen = mysql.MaxTinyBlobSize
483+
case mysql.TypeBlob:
484+
newFlen = castBlobFlen
485+
case mysql.TypeMediumBlob:
486+
newFlen = castMediumBlobFlen
487+
case mysql.TypeLongBlob:
488+
newFlen = mysql.MaxLongBlobSize
489+
default:
490+
intest.Assert(false, "unknown type %d for String", argFt.GetType())
491+
return
492+
}
493+
}
494+
}
495+
496+
return
497+
}
498+
369499
type castAsTimeFunctionClass struct {
370500
baseFunctionClass
371501

@@ -1294,6 +1424,9 @@ func (b *builtinCastRealAsStringSig) Clone() builtinFunc {
12941424
return newSig
12951425
}
12961426

1427+
const maxFloatStrLen = 12 // MySQL's max length for float string is 22
1428+
const maxDoubleStrLen = 22 // MySQL's max length for float string is 22
1429+
12971430
func (b *builtinCastRealAsStringSig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) {
12981431
val, isNull, err := b.args[0].EvalReal(ctx, row)
12991432
if isNull || err != nil {
@@ -1307,7 +1440,9 @@ func (b *builtinCastRealAsStringSig) evalString(ctx EvalContext, row chunk.Row)
13071440
// If we strconv.FormatFloat the value with 64bits, the result is incorrect!
13081441
bits = 32
13091442
}
1310-
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, typeCtx(ctx), false)
1443+
1444+
formatedStr := string(types.AppendFormatFloat(nil, val, -1, bits))
1445+
res, err = types.ProduceStrWithSpecifiedTp(formatedStr, b.tp, typeCtx(ctx), false)
13111446
if err != nil {
13121447
return res, false, err
13131448
}
@@ -2834,3 +2969,27 @@ func TryPushCastIntoControlFunctionForHybridType(ctx BuildContext, expr Expressi
28342969
}
28352970
return expr
28362971
}
2972+
2973+
func decimalPrecisionToLength(ft *types.FieldType) int {
2974+
precision := ft.GetFlen()
2975+
scale := ft.GetDecimal()
2976+
unsigned := mysql.HasUnsignedFlag(ft.GetFlag())
2977+
2978+
if precision == types.UnspecifiedLength || scale == types.UnspecifiedLength {
2979+
return types.UnspecifiedLength
2980+
}
2981+
2982+
ret := precision
2983+
if scale > 0 {
2984+
ret++
2985+
}
2986+
2987+
if !unsigned && precision > 0 {
2988+
ret++ // for negative sign
2989+
}
2990+
2991+
if ret == 0 {
2992+
return 1
2993+
}
2994+
return ret
2995+
}

pkg/expression/builtin_cast_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,3 +1740,81 @@ func TestCastArrayFunc(t *testing.T) {
17401740
}
17411741
}
17421742
}
1743+
1744+
func TestCastAsCharFieldType(t *testing.T) {
1745+
type testCase struct {
1746+
input *Constant
1747+
resultFlen int
1748+
}
1749+
allTestCase := []testCase{
1750+
// test int
1751+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTiny).BuildP(), Value: types.NewIntDatum(0)}, 4},
1752+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeShort).BuildP(), Value: types.NewIntDatum(0)}, 6},
1753+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeInt24).BuildP(), Value: types.NewIntDatum(0)}, 9},
1754+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLong).BuildP(), Value: types.NewIntDatum(0)}, 11},
1755+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).BuildP(), Value: types.NewIntDatum(0)}, 20},
1756+
// test uint
1757+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTiny).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(0)}, 3},
1758+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeShort).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(1)}, 5},
1759+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeInt24).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(11111)}, 8},
1760+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLong).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(1111111111)}, 10},
1761+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(111111111111111)}, 20},
1762+
// test decimal
1763+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(10).SetDecimal(5).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 12},
1764+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(2).SetDecimal(1).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("1"))}, 4},
1765+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(30).SetDecimal(0).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 31},
1766+
// test unsigned decimal
1767+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(10).SetDecimal(5).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 11},
1768+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(2).SetDecimal(1).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("1"))}, 3},
1769+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(30).SetDecimal(0).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 30},
1770+
// test real
1771+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeFloat).BuildP(), Value: types.NewFloat64Datum(1.234)}, 12},
1772+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).BuildP(), Value: types.NewFloat64Datum(1.23456789)}, 22},
1773+
// test unsigned real
1774+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeFloat).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewFloat64Datum(1.234)}, 12},
1775+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewFloat64Datum(1.23456789)}, 22},
1776+
// test timestamp
1777+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 19},
1778+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 23},
1779+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 26},
1780+
// test datetime
1781+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 19},
1782+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 23},
1783+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 26},
1784+
// test time
1785+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 0))}, 10},
1786+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 3))}, 14},
1787+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 6))}, 17},
1788+
// test date
1789+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDate).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDate, 0))}, 10},
1790+
// test json
1791+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeJSON).BuildP(), Value: types.NewJSONDatum(types.CreateBinaryJSON(int64(1)))}, 4294967295},
1792+
// test string
1793+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetFlen(50).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1794+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetFlen(50).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1795+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeVarString).SetFlen(50).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1796+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeVarString).SetFlen(50).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1797+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTinyBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 255},
1798+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTinyBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 255},
1799+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 262140},
1800+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 262140},
1801+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeMediumBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 67108860},
1802+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeMediumBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 67108860},
1803+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLongBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 4294967295},
1804+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLongBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 4294967295},
1805+
}
1806+
ctx := createContext(t)
1807+
for i, tc := range allTestCase {
1808+
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
1809+
ft := types.NewFieldTypeBuilder().
1810+
SetType(mysql.TypeVarString).
1811+
SetFlen(types.UnspecifiedLength).
1812+
SetCharset(charset.CharsetUTF8MB4).
1813+
SetCollate(charset.CollationUTF8MB4).
1814+
BuildP()
1815+
expr, err := BuildCastFunctionWithCheck(ctx, tc.input, ft, false, false)
1816+
require.NoError(t, err)
1817+
require.Equal(t, tc.resultFlen, expr.GetType(ctx).GetFlen())
1818+
})
1819+
}
1820+
}

pkg/expression/builtin_cast_vec.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ func (b *builtinCastRealAsStringSig) vecEvalString(ctx EvalContext, input *chunk
222222
result.AppendNull()
223223
continue
224224
}
225-
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(v, 'f', -1, bits), b.tp, tc, false)
225+
formatedStr := string(types.AppendFormatFloat(nil, v, -1, bits))
226+
res, err = types.ProduceStrWithSpecifiedTp(formatedStr, b.tp, tc, false)
226227
if err != nil {
227228
return err
228229
}

0 commit comments

Comments
 (0)