Skip to content

Commit 2ab330d

Browse files
committed
fix the length of multiple types
Signed-off-by: Yang Keao <[email protected]>
1 parent fa3c79b commit 2ab330d

File tree

3 files changed

+195
-10
lines changed

3 files changed

+195
-10
lines changed

pkg/expression/builtin_cast.go

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
3838
"github.com/pingcap/tidb/pkg/types"
3939
"github.com/pingcap/tidb/pkg/util/chunk"
40+
"github.com/pingcap/tidb/pkg/util/intest"
4041
"github.com/pingcap/tipb/go-tipb"
4142
)
4243

@@ -114,6 +115,14 @@ var (
114115
_ builtinFunc = &builtinCastVectorFloat32AsUnsupportedSig{}
115116
)
116117

118+
const (
119+
// These two are magic numbers to be compatible with MySQL.
120+
// TODO: at least understand how these values came out. They appears to be `MaxBlobSize * 4` and `MaxMediumBlobSize * 4`
121+
// but I don't know why it needs to multiply by 4. However, the bigger value is always safer, so we use them here.
122+
castBlobFlen = 262140
123+
castMediumBlobFlen = 67108860
124+
)
125+
117126
type castAsIntFunctionClass struct {
118127
baseFunctionClass
119128

@@ -314,13 +323,14 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
314323
tp.AddFlag(mysql.BinaryFlag)
315324
args[0] = BuildCastFunction(ctx, args[0], tp)
316325
}
317-
argTp := args[0].GetType(ctx.GetEvalCtx()).EvalType()
326+
argFt := args[0].GetType(ctx.GetEvalCtx())
327+
argTp := argFt.EvalType()
318328
switch argTp {
319329
case types.ETInt:
320330
if bf.tp.GetFlen() == types.UnspecifiedLength {
321331
// check https://github.com/pingcap/tidb/issues/44786
322332
// set flen from integers may truncate integers, e.g. char(1) can not display -1[int(1)]
323-
switch args[0].GetType(ctx.GetEvalCtx()).GetType() {
333+
switch argFt.GetType() {
324334
case mysql.TypeTiny:
325335
bf.tp.SetFlen(4)
326336
case mysql.TypeShort:
@@ -330,31 +340,99 @@ func (c *castAsStringFunctionClass) getFunction(ctx BuildContext, args []Express
330340
case mysql.TypeLong:
331341
// set it to 11 as mysql
332342
bf.tp.SetFlen(11)
343+
case mysql.TypeLonglong:
344+
// set it to 20 as mysql
345+
bf.tp.SetFlen(20)
333346
default:
334-
bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen())
347+
intest.Assert(false, "unknown type %d for INT", argFt.GetType())
348+
bf.tp.SetFlen(argFt.GetFlen())
349+
}
350+
351+
// to be compatible with MySQL, `bigint unsigned` doesn't remove the space for sign.
352+
if mysql.HasUnsignedFlag(argFt.GetFlag()) &&
353+
argFt.GetType() != mysql.TypeLonglong {
354+
// Remove the space for sign
355+
bf.tp.SetFlen(bf.tp.GetFlen() - 1)
335356
}
336357
}
337358
sig = &builtinCastIntAsStringSig{bf}
338359
sig.setPbCode(tipb.ScalarFuncSig_CastIntAsString)
339360
case types.ETReal:
361+
if bf.tp.GetFlen() == types.UnspecifiedLength {
362+
switch argFt.GetType() {
363+
case mysql.TypeFloat:
364+
bf.tp.SetFlen(12)
365+
case mysql.TypeDouble:
366+
bf.tp.SetFlen(22)
367+
}
368+
}
340369
sig = &builtinCastRealAsStringSig{bf}
341370
sig.setPbCode(tipb.ScalarFuncSig_CastRealAsString)
342371
case types.ETDecimal:
372+
if bf.tp.GetFlen() == types.UnspecifiedLength {
373+
bf.tp.SetFlen(decimalPrecisionToLength(argFt))
374+
}
343375
sig = &builtinCastDecimalAsStringSig{bf}
344376
sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsString)
345377
case types.ETDatetime, types.ETTimestamp:
378+
if bf.tp.GetFlen() == types.UnspecifiedLength {
379+
bf.tp.SetFlen(mysql.MaxDatetimeWidthNoFsp)
380+
if argFt.GetType() == mysql.TypeDate {
381+
bf.tp.SetFlen(mysql.MaxDateWidth)
382+
}
383+
384+
// Theoretically, the decimal of `DATE` will never be greater than 0.
385+
decimal := argFt.GetDecimal()
386+
if decimal > 0 {
387+
// If the type is datetime or timestamp with fractional seconds, we need to set the length to
388+
// accommodate the fractional seconds part.
389+
bf.tp.SetFlen(bf.tp.GetFlen() + (1 + decimal))
390+
}
391+
}
346392
sig = &builtinCastTimeAsStringSig{bf}
347393
sig.setPbCode(tipb.ScalarFuncSig_CastTimeAsString)
348394
case types.ETDuration:
395+
if bf.tp.GetFlen() == types.UnspecifiedLength {
396+
bf.tp.SetFlen(mysql.MaxDurationWidthNoFsp)
397+
decimal := argFt.GetDecimal()
398+
if decimal > 0 {
399+
// If the type is time with fractional seconds, we need to set the length to
400+
// accommodate the fractional seconds part.
401+
bf.tp.SetFlen(bf.tp.GetFlen() + 1 + decimal)
402+
}
403+
}
349404
sig = &builtinCastDurationAsStringSig{bf}
350405
sig.setPbCode(tipb.ScalarFuncSig_CastDurationAsString)
351406
case types.ETJson:
407+
if bf.tp.GetFlen() == types.UnspecifiedLength {
408+
bf.tp.SetFlen(mysql.MaxLongBlobWidth)
409+
bf.tp.SetType(mysql.TypeLongBlob)
410+
}
352411
sig = &builtinCastJSONAsStringSig{bf}
353412
sig.setPbCode(tipb.ScalarFuncSig_CastJsonAsString)
354413
case types.ETVectorFloat32:
355414
sig = &builtinCastVectorFloat32AsStringSig{bf}
356415
sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsString)
357416
case types.ETString:
417+
if bf.tp.GetFlen() == types.UnspecifiedLength {
418+
flen := types.UnspecifiedLength
419+
switch argFt.GetType() {
420+
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
421+
if argFt.GetFlen() > 0 {
422+
flen = argFt.GetFlen()
423+
}
424+
case mysql.TypeTinyBlob:
425+
flen = mysql.MaxTinyBlobSize
426+
case mysql.TypeBlob:
427+
flen = castBlobFlen
428+
case mysql.TypeMediumBlob:
429+
flen = castMediumBlobFlen
430+
case mysql.TypeLongBlob:
431+
flen = mysql.MaxLongBlobSize
432+
}
433+
bf.tp.SetFlen(flen)
434+
}
435+
358436
// When cast from binary to some other charsets, we should check if the binary is valid or not.
359437
// so we build a from_binary function to do this check.
360438
bf.args[0] = HandleBinaryLiteral(ctx, args[0], &ExprCollation{Charset: c.tp.GetCharset(), Collation: c.tp.GetCollate()}, c.funcName, true)
@@ -2847,3 +2925,27 @@ func TryPushCastIntoControlFunctionForHybridType(ctx BuildContext, expr Expressi
28472925
}
28482926
return expr
28492927
}
2928+
2929+
func decimalPrecisionToLength(ft *types.FieldType) int {
2930+
precision := ft.GetFlen()
2931+
scale := ft.GetDecimal()
2932+
unsigned := mysql.HasUnsignedFlag(ft.GetFlag())
2933+
2934+
if precision == types.UnspecifiedLength || scale == types.UnspecifiedLength {
2935+
return types.UnspecifiedLength
2936+
}
2937+
2938+
ret := precision
2939+
if scale > 0 {
2940+
ret++
2941+
}
2942+
2943+
if !unsigned && precision > 0 {
2944+
ret++ // for negative sign
2945+
}
2946+
2947+
if ret == 0 {
2948+
return 1
2949+
}
2950+
return ret
2951+
}

pkg/expression/builtin_cast_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,3 +1740,81 @@ func TestCastArrayFunc(t *testing.T) {
17401740
}
17411741
}
17421742
}
1743+
1744+
func TestCastAsCharFieldType(t *testing.T) {
1745+
type testCase struct {
1746+
input *Constant
1747+
resultFlen int
1748+
}
1749+
allTestCase := []testCase{
1750+
// test int
1751+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTiny).BuildP(), Value: types.NewIntDatum(0)}, 4},
1752+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeShort).BuildP(), Value: types.NewIntDatum(0)}, 6},
1753+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeInt24).BuildP(), Value: types.NewIntDatum(0)}, 9},
1754+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLong).BuildP(), Value: types.NewIntDatum(0)}, 11},
1755+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).BuildP(), Value: types.NewIntDatum(0)}, 20},
1756+
// test uint
1757+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTiny).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(0)}, 3},
1758+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeShort).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(1)}, 5},
1759+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeInt24).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(11111)}, 8},
1760+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLong).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(1111111111)}, 10},
1761+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLonglong).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewUintDatum(111111111111111)}, 20},
1762+
// test decimal
1763+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(10).SetDecimal(5).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 12},
1764+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(2).SetDecimal(1).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("1"))}, 4},
1765+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(30).SetDecimal(0).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 31},
1766+
// test unsigned decimal
1767+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(10).SetDecimal(5).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 11},
1768+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(2).SetDecimal(1).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("1"))}, 3},
1769+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeNewDecimal).SetFlen(30).SetDecimal(0).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewDecimalDatum(types.NewDecFromStringForTest("12345"))}, 30},
1770+
// test real
1771+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeFloat).BuildP(), Value: types.NewFloat64Datum(1.234)}, 12},
1772+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).BuildP(), Value: types.NewFloat64Datum(1.23456789)}, 22},
1773+
// test unsigned real
1774+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeFloat).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewFloat64Datum(1.234)}, 12},
1775+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDouble).SetFlag(mysql.UnsignedFlag).BuildP(), Value: types.NewFloat64Datum(1.23456789)}, 22},
1776+
// test timestamp
1777+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 19},
1778+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 23},
1779+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTimestamp).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeTimestamp, 0))}, 26},
1780+
// test datetime
1781+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 19},
1782+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 23},
1783+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDatetime).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDatetime, 0))}, 26},
1784+
// test time
1785+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 0))}, 10},
1786+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(3).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 3))}, 14},
1787+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDuration).SetFlen(types.UnspecifiedLength).SetDecimal(6).BuildP(), Value: types.NewDurationDatum(types.NewDuration(10, 10, 10, 110, 6))}, 17},
1788+
// test date
1789+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeDate).SetFlen(types.UnspecifiedLength).SetDecimal(0).BuildP(), Value: types.NewTimeDatum(types.NewTime(types.FromDate(2020, 10, 10, 10, 10, 10, 110), mysql.TypeDate, 0))}, 10},
1790+
// test json
1791+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeJSON).BuildP(), Value: types.NewJSONDatum(types.CreateBinaryJSON(int64(1)))}, 4294967295},
1792+
// test string
1793+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetFlen(50).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1794+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeString).SetFlen(50).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1795+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeVarString).SetFlen(50).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1796+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeVarString).SetFlen(50).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 50},
1797+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTinyBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 255},
1798+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeTinyBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 255},
1799+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 262140},
1800+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 262140},
1801+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeMediumBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 67108860},
1802+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeMediumBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 67108860},
1803+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLongBlob).SetFlen(types.UnspecifiedLength).SetCollate("binary").BuildP(), Value: types.NewStringDatum("abcde")}, 4294967295},
1804+
{&Constant{RetType: types.NewFieldTypeBuilder().SetType(mysql.TypeLongBlob).SetFlen(types.UnspecifiedLength).SetCollate("utf8mb4_bin").BuildP(), Value: types.NewStringDatum("abcde")}, 4294967295},
1805+
}
1806+
ctx := createContext(t)
1807+
for i, tc := range allTestCase {
1808+
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
1809+
ft := types.NewFieldTypeBuilder().
1810+
SetType(mysql.TypeString).
1811+
SetFlen(types.UnspecifiedLength).
1812+
SetCharset(charset.CharsetUTF8MB4).
1813+
SetCollate(charset.CollationUTF8MB4).
1814+
BuildP()
1815+
expr, err := BuildCastFunctionWithCheck(ctx, tc.input, ft, false, false)
1816+
require.NoError(t, err)
1817+
require.Equal(t, tc.resultFlen, expr.GetType(ctx).GetFlen())
1818+
})
1819+
}
1820+
}

pkg/parser/mysql/const.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,17 +241,22 @@ const (
241241
MaxFloatingTypeWidth = 255
242242
MaxDecimalScale = 30
243243
MaxDecimalWidth = 65
244-
MaxDateWidth = 10 // YYYY-MM-DD.
245-
MaxDatetimeWidthNoFsp = 19 // YYYY-MM-DD HH:MM:SS
246-
MaxDatetimeWidthWithFsp = 26 // YYYY-MM-DD HH:MM:SS[.fraction]
247-
MaxDatetimeFullWidth = 29 // YYYY-MM-DD HH:MM:SS.###### AM
248-
MaxDurationWidthNoFsp = 10 // HH:MM:SS
249-
MaxDurationWidthWithFsp = 17 // HH:MM:SS[.fraction] -838:59:59.000000 to 838:59:59.000000
250-
MaxBlobWidth = 16777216
244+
MaxDateWidth = 10 // YYYY-MM-DD.
245+
MaxDatetimeWidthNoFsp = 19 // YYYY-MM-DD HH:MM:SS
246+
MaxDatetimeWidthWithFsp = 26 // YYYY-MM-DD HH:MM:SS[.fraction]
247+
MaxDatetimeFullWidth = 29 // YYYY-MM-DD HH:MM:SS.###### AM
248+
MaxDurationWidthNoFsp = 10 // HH:MM:SS
249+
MaxDurationWidthWithFsp = 17 // HH:MM:SS[.fraction] -838:59:59.000000 to 838:59:59.000000
250+
MaxBlobWidth = 16777216 // `MaxBlobWidth` is greater than `MaxBlobSize`. It's compatible with MySQL, but doesn't have a good reason.
251251
MaxLongBlobWidth = 4294967295
252252
MaxBitDisplayWidth = 64
253253
MaxFloatPrecisionLength = 24
254254
MaxDoublePrecisionLength = 53
255+
256+
MaxTinyBlobSize = 255
257+
MaxBlobSize = 65535
258+
MaxMediumBlobSize = 16777215
259+
MaxLongBlobSize = 4294967295
255260
)
256261

257262
// MySQL max type field length.

0 commit comments

Comments
 (0)