Skip to content

Commit d781589

Browse files
committed
fix the length of multiple types
Signed-off-by: Yang Keao <[email protected]>
1 parent c56481e commit d781589

File tree

5 files changed

+270
-36
lines changed

5 files changed

+270
-36
lines changed

pkg/expression/builtin_cast.go

Lines changed: 164 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
3838
"github.com/pingcap/tidb/pkg/types"
3939
"github.com/pingcap/tidb/pkg/util/chunk"
40+
"github.com/pingcap/tidb/pkg/util/intest"
4041
"github.com/pingcap/tipb/go-tipb"
4142
)
4243

@@ -114,6 +115,14 @@ var (
114115
_ builtinFunc = &builtinCastVectorFloat32AsUnsupportedSig{}
115116
)
116117

118+
const (
119+
// These two are magic numbers to be compatible with MySQL.
120+
// TODO: at least understand how these values came out. They appears to be `MaxBlobSize * 4` and `MaxMediumBlobSize * 4`
121+
// but I don't know why it needs to multiply by 4. However, the bigger value is always safer, so we use them here.
122+
castBlobFlen = 262140
123+
castMediumBlobFlen = 67108860
124+
)
125+
117126
type castAsIntFunctionClass struct {
118127
baseFunctionClass
119128

@@ -314,26 +323,13 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
314323
tp.AddFlag(mysql.BinaryFlag)
315324
args[0] = BuildCastFunction(ctx, args[0], tp)
316325
}
317-
argTp := args[0].GetType(ctx.GetEvalCtx()).EvalType()
318-
switch argTp {
326+
argFt := args[0].GetType(ctx.GetEvalCtx())
327+
newFlen, newRetTp := estimateLengthForCastString(bf.tp, argFt)
328+
bf.tp.SetFlen(newFlen)
329+
bf.tp.SetType(newRetTp)
330+
331+
switch argFt.EvalType() {
319332
case types.ETInt:
320-
if bf.tp.GetFlen() == types.UnspecifiedLength {
321-
// check https://github.com/pingcap/tidb/issues/44786
322-
// set flen from integers may truncate integers, e.g. char(1) can not display -1[int(1)]
323-
switch args[0].GetType(ctx.GetEvalCtx()).GetType() {
324-
case mysql.TypeTiny:
325-
bf.tp.SetFlen(4)
326-
case mysql.TypeShort:
327-
bf.tp.SetFlen(6)
328-
case mysql.TypeInt24:
329-
bf.tp.SetFlen(9)
330-
case mysql.TypeLong:
331-
// set it to 11 as mysql
332-
bf.tp.SetFlen(11)
333-
default:
334-
bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen())
335-
}
336-
}
337333
sig = &builtinCastIntAsStringSig{bf}
338334
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsString)
339335
case types.ETReal:
@@ -361,11 +357,125 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
361357
sig = &builtinCastStringAsStringSig{bf}
362358
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsString)
363359
default:
364-
return nil, errors.Errorf("cannot cast from %s to %s", argTp, "String")
360+
return nil, errors.Errorf("cannot cast from %s to %s", argFt.EvalType(), "String")
365361
}
366362
return sig, nil
367363
}
368364

365+
func estimateLengthForCastString(retFt, argFt *types.FieldType) (newFlen int, retNewTp byte) {
366+
newFlen = retFt.GetFlen()
367+
retNewTp = retFt.GetType()
368+
369+
// Only estimate the length for variable length string types, because different length for fixed
370+
// length string types will have different behaviors and may cause compatibility issues.
371+
if retFt.GetType() == mysql.TypeString {
372+
return
373+
}
374+
375+
argTp := argFt.EvalType()
376+
switch argTp {
377+
case types.ETInt:
378+
if newFlen == types.UnspecifiedLength {
379+
// check https://github.com/pingcap/tidb/issues/44786
380+
// set flen from integers may truncate integers, e.g. char(1) can not display -1[int(1)]
381+
switch argFt.GetType() {
382+
case mysql.TypeTiny:
383+
newFlen = 4
384+
case mysql.TypeShort:
385+
newFlen = 6
386+
case mysql.TypeInt24:
387+
newFlen = 9
388+
case mysql.TypeLong:
389+
// set it to 11 as mysql
390+
newFlen = 11
391+
case mysql.TypeLonglong:
392+
// set it to 20 as mysql
393+
newFlen = 20
394+
default:
395+
intest.Assert(false, "unknown type %d for INT", argFt.GetType())
396+
newFlen = argFt.GetFlen()
397+
}
398+
399+
// to be compatible with MySQL, `bigint unsigned` doesn't remove the space for sign.
400+
if mysql.HasUnsignedFlag(argFt.GetFlag()) &&
401+
argFt.GetType() != mysql.TypeLonglong {
402+
// Remove the space for sign
403+
newFlen--
404+
}
405+
}
406+
case types.ETReal:
407+
if newFlen == types.UnspecifiedLength {
408+
switch argFt.GetType() {
409+
case mysql.TypeFloat:
410+
newFlen = 12
411+
case mysql.TypeDouble:
412+
newFlen = 22
413+
default:
414+
intest.Assert(false, "unknown type %d for REAL", argFt.GetType())
415+
newFlen = argFt.GetFlen()
416+
}
417+
}
418+
case types.ETDecimal:
419+
if newFlen == types.UnspecifiedLength {
420+
newFlen = decimalPrecisionToLength(argFt)
421+
}
422+
case types.ETDatetime, types.ETTimestamp:
423+
if newFlen == types.UnspecifiedLength {
424+
newFlen = mysql.MaxDatetimeWidthNoFsp
425+
if argFt.GetType() == mysql.TypeDate {
426+
newFlen = mysql.MaxDateWidth
427+
}
428+
429+
// Theoretically, the decimal of `DATE` will never be greater than 0.
430+
decimal := argFt.GetDecimal()
431+
if decimal > 0 {
432+
// If the type is datetime or timestamp with fractional seconds, we need to set the length to
433+
// accommodate the fractional seconds part.
434+
newFlen += (1 + decimal)
435+
}
436+
}
437+
case types.ETDuration:
438+
if newFlen == types.UnspecifiedLength {
439+
newFlen = mysql.MaxDurationWidthNoFsp
440+
decimal := argFt.GetDecimal()
441+
if decimal > 0 {
442+
// If the type is time with fractional seconds, we need to set the length to
443+
// accommodate the fractional seconds part.
444+
newFlen += 1 + decimal
445+
}
446+
}
447+
case types.ETJson:
448+
if newFlen == types.UnspecifiedLength {
449+
newFlen = mysql.MaxLongBlobWidth
450+
retNewTp = mysql.TypeLongBlob
451+
}
452+
case types.ETVectorFloat32:
453+
454+
case types.ETString:
455+
if newFlen == types.UnspecifiedLength {
456+
switch argFt.GetType() {
457+
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
458+
if argFt.GetFlen() > 0 {
459+
newFlen = argFt.GetFlen()
460+
}
461+
case mysql.TypeTinyBlob:
462+
newFlen = mysql.MaxTinyBlobSize
463+
case mysql.TypeBlob:
464+
newFlen = castBlobFlen
465+
case mysql.TypeMediumBlob:
466+
newFlen = castMediumBlobFlen
467+
case mysql.TypeLongBlob:
468+
newFlen = mysql.MaxLongBlobSize
469+
default:
470+
intest.Assert(false, "unknown type %d for String", argFt.GetType())
471+
newFlen = argFt.GetFlen()
472+
}
473+
}
474+
}
475+
476+
return
477+
}
478+
369479
type castAsTimeFunctionClass struct {
370480
baseFunctionClass
371481

@@ -1307,7 +1417,16 @@ func (b *builtinCastRealAsStringSig) evalString(ctx EvalContext, row chunk.Row)
13071417
// If we strconv.FormatFloat the value with 64bits, the result is incorrect!
13081418
bits = 32
13091419
}
1310-
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, bits), b.tp, typeCtx(ctx), false)
1420+
1421+
formatedStr := strconv.FormatFloat(val, 'f', -1, bits)
1422+
// try to use `e` format to format the value if the length of the string is too long.
1423+
// MySQL has a more complicated rule to determine whether to use `e` format or not. The compatibility is hard to achieve,
1424+
// but with the rule below, we at least have no less precision than MySQL, and is compatible in most of the cases.
1425+
if len(formatedStr) > b.tp.GetFlen() {
1426+
formatedStr = strconv.FormatFloat(val, 'e', -1, bits)
1427+
}
1428+
1429+
res, err = types.ProduceStrWithSpecifiedTp(formatedStr, b.tp, typeCtx(ctx), false)
13111430
if err != nil {
13121431
return res, false, err
13131432
}
@@ -2834,3 +2953,27 @@ func TryPushCastIntoControlFunctionForHybridType(ctx BuildContext, expr Expressi
28342953
}
28352954
return expr
28362955
}
2956+
2957+
func decimalPrecisionToLength(ft *types.FieldType) int {
2958+
precision := ft.GetFlen()
2959+
scale := ft.GetDecimal()
2960+
unsigned := mysql.HasUnsignedFlag(ft.GetFlag())
2961+
2962+
if precision == types.UnspecifiedLength || scale == types.UnspecifiedLength {
2963+
return types.UnspecifiedLength
2964+
}
2965+
2966+
ret := precision
2967+
if scale > 0 {
2968+
ret++
2969+
}
2970+
2971+
if !unsigned && precision > 0 {
2972+
ret++ // for negative sign
2973+
}
2974+
2975+
if ret == 0 {
2976+
return 1
2977+
}
2978+
return ret
2979+
}

pkg/expression/builtin_cast_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,3 +1740,81 @@ func TestCastArrayFunc(t *testing.T) {
17401740
}
17411741
}
17421742
}
1743+
1744+
func TestCastAsCharFieldType(t *testing.T) {
1745+
type testCase struct {
1746+
input *Constant
1747+
resultFlen int
1748+
}
1749+
allTestCase := []testCase{
1750+
// test int
1751+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTiny).BuildP(), Value: types.NewIntDatum(0)}, 4},
1752+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeShort).BuildP(), Value: types.NewIntDatum(0)}, 6},
1753+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeInt24).BuildP(), Value: types.NewIntDatum(0)}, 9},
1754+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLong).BuildP(), Value: types.NewIntDatum(0)}, 11},
1755+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).BuildP(), Value: types.NewIntDatum(0)}, 20},
1756+
// test uint
1757+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTiny).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(0)}, 3},
1758+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeShort).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(1)}, 5},
1759+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeInt24).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(11111)}, 8},
1760+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLong).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(1111111111)}, 10},
1761+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(111111111111111)}, 20},
1762+
// test decimal
1763+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(10).SetDecimal(5).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 12},
1764+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(2).SetDecimal(1).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("1"))}, 4},
1765+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(30).SetDecimal(0).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 31},
1766+
// test unsigned decimal
1767+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(10).SetDecimal(5).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 11},
1768+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(2).SetDecimal(1).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("1"))}, 3},
1769+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(30).SetDecimal(0).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 30},
1770+
// test real
1771+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeFloat).BuildP(), Value: types.NewFloat64Datum(1.234)}, 12},
1772+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).BuildP(), Value: types.NewFloat64Datum(1.23456789)}, 22},
1773+
// test unsigned real
1774+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeFloat).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewFloat64Datum(1.234)}, 12},
1775+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewFloat64Datum(1.23456789)}, 22},
1776+
// test timestamp
1777+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 19},
1778+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 23},
1779+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 26},
1780+
// test datetime
1781+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 19},
1782+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 23},
1783+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 26},
1784+
// test time
1785+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 0))}, 10},
1786+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 3))}, 14},
1787+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 6))}, 17},
1788+
// test date
1789+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDate).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDate, 0))}, 10},
1790+
// test json
1791+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeJSON).BuildP(), Value: types.NewJSONDatum(types.CreateBinaryJSON(int64(1)))}, 4294967295},
1792+
// test string
1793+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetFlen(50).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1794+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetFlen(50).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1795+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeVarString).SetFlen(50).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1796+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeVarString).SetFlen(50).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1797+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTinyBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 255},
1798+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTinyBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 255},
1799+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 262140},
1800+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 262140},
1801+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeMediumBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 67108860},
1802+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeMediumBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 67108860},
1803+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLongBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 4294967295},
1804+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLongBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 4294967295},
1805+
}
1806+
ctx := createContext(t)
1807+
for i, tc := range allTestCase {
1808+
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
1809+
ft := types.NewFieldTypeBuilder().
1810+
SetType(mysql.TypeVarString).
1811+
SetFlen(types.UnspecifiedLength).
1812+
SetCharset(charset.CharsetUTF8MB4).
1813+
SetCollate(charset.CollationUTF8MB4).
1814+
BuildP()
1815+
expr, err := BuildCastFunctionWithCheck(ctx, tc.input, ft, false, false)
1816+
require.NoError(t, err)
1817+
require.Equal(t, tc.resultFlen, expr.GetType(ctx).GetFlen())
1818+
})
1819+
}
1820+
}

pkg/expression/builtin_cast_vec.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ func (b *builtinCastRealAsStringSig) vecEvalString(ctx EvalContext, input *chunk
222222
result.AppendNull()
223223
continue
224224
}
225+
formatedStr := strconv.FormatFloat(v, 'f', -1, bits)
226+
// try to use `e` format to format the value if the length of the string is too long.
227+
// ref the comment in `builtinCastRealAsStringSig.evalString`
228+
if len(formatedStr) > b.tp.GetFlen() {
229+
formatedStr = strconv.FormatFloat(v, 'e', -1, bits)
230+
}
225231
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(v, 'f', -1, bits), b.tp, tc, false)
226232
if err != nil {
227233
return err

0 commit comments

Comments
 (0)