Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 11 additions & 62 deletions pkg/expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,19 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

fieldTps := make([]*types.FieldType, 0, len(args))
flag := uint(0)
for _, arg := range args {
fieldTps = append(fieldTps, arg.GetType())
flag |= arg.GetType().GetFlag() & mysql.NotNullFlag
}

// Use the aggregated field type as retType.
resultFieldType := types.AggFieldType(fieldTps)
var tempType uint
resultEvalType := types.AggregateEvalType(fieldTps, &tempType)
resultFieldType.SetFlag(tempType)
retEvalTp := resultFieldType.EvalType()
resultFieldType, err := InferType4ControlFuncs(ctx, c.funcName, args...)
if err != nil {
return nil, err
}

resultFieldType.AddFlag(flag)

retEvalTp := resultFieldType.EvalType()
fieldEvalTps := make([]types.EvalType, 0, len(args))
for range args {
fieldEvalTps = append(fieldEvalTps, retEvalTp)
Expand All @@ -141,60 +142,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

bf.tp.AddFlag(resultFieldType.GetFlag())
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(types.UnspecifiedLength)

// Set retType to BINARY(0) if all arguments are of type NULL.
if resultFieldType.GetType() == mysql.TypeNull {
types.SetBinChsClnFlag(bf.tp)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
} else {
maxIntLen := 0
maxFlen := 0

// Find the max length of field in `maxFlen`,
// and max integer-part length in `maxIntLen`.
for _, argTp := range fieldTps {
if argTp.GetDecimal() > resultFieldType.GetDecimal() {
resultFieldType.SetDecimalUnderLimit(argTp.GetDecimal())
}
argIntLen := argTp.GetFlen()
if argTp.GetDecimal() > 0 {
argIntLen -= argTp.GetDecimal() + 1
}

// Reduce the sign bit if it is a signed integer/decimal
if !mysql.HasUnsignedFlag(argTp.GetFlag()) {
argIntLen--
}
if argIntLen > maxIntLen {
maxIntLen = argIntLen
}
if argTp.GetFlen() > maxFlen || argTp.GetFlen() == types.UnspecifiedLength {
maxFlen = argTp.GetFlen()
}
}
// For integer, field length = maxIntLen + (1/0 for sign bit)
// For decimal, field length = maxIntLen + maxDecimal + (1/0 for sign bit)
if resultEvalType == types.ETInt || resultEvalType == types.ETDecimal {
resultFieldType.SetFlenUnderLimit(maxIntLen + resultFieldType.GetDecimal())
if resultFieldType.GetDecimal() > 0 {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
if !mysql.HasUnsignedFlag(resultFieldType.GetFlag()) {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
bf.tp = resultFieldType
} else {
bf.tp.SetFlen(maxFlen)
}
// Set the field length to maxFlen for other types.
if bf.tp.GetFlen() > mysql.MaxDecimalWidth {
bf.tp.SetFlen(mysql.MaxDecimalWidth)
}
}
bf.tp = resultFieldType

switch retEvalTp {
case types.ETInt:
Expand Down Expand Up @@ -1252,6 +1200,7 @@ func (b *builtinIntervalRealSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull {
return -1, false, nil
}

var idx int
if b.hasNullable {
idx, err = b.linearSearch(arg0, b.args[1:], row)
Expand Down
9 changes: 9 additions & 0 deletions pkg/expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,12 @@ func TestRefineArgsWithCastEnum(t *testing.T) {
require.Equal(t, zeroUintConst, args[0])
require.Equal(t, enumCol, args[1])
}

func TestIssue46475(t *testing.T) {
ctx := createContext(t)
args := []interface{}{nil, dt, nil}

f, err := newFunctionForTest(ctx, ast.Coalesce, primitiveValsToConstants(ctx, args)...)
require.NoError(t, err)
require.Equal(t, f.GetType().GetType(), mysql.TypeDate)
}
61 changes: 58 additions & 3 deletions pkg/expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ func setDecimalFromArgs(evalType types.EvalType, resultFieldType *types.FieldTyp
}
}

// NonBinaryStr means the arg is a string but not binary string
func hasNonBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsNonBinaryStr(arg) {
return true
}
}
return false
}

func hasBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsBinaryStr(arg) {
return true
}
}
return false
}

func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string, evalType types.EvalType, resultFieldType *types.FieldType, args ...Expression) error {
switch funcName {
case ast.If, ast.Ifnull, ast.WindowFuncLead, ast.WindowFuncLag:
Expand Down Expand Up @@ -170,13 +189,49 @@ func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string
break
}
}
case ast.Coalesce: // TODO ast.Case and ast.Coalesce should be merged into the same branch
argTypes := make([]*types.FieldType, 0)
for _, arg := range args {
argTypes = append(argTypes, arg.GetType())
}

nonBinaryStrExist := hasNonBinaryStr(argTypes)
binaryStrExist := hasBinaryStr(argTypes)
if !binaryStrExist && nonBinaryStrExist {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...)
if err != nil {
return err
}
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
resultFieldType.SetFlag(0)

// hasNonStringType means that there is a type that is not string
hasNonStringType := false
for _, argType := range argTypes {
if !types.IsString(argType.GetType()) {
hasNonStringType = true
break
}
}

if hasNonStringType {
resultFieldType.AddFlag(mysql.BinaryFlag)
}
} else if binaryStrExist || !evalType.IsStringKind() {
types.SetBinChsClnFlag(resultFieldType)
} else {
resultFieldType.SetCharset(mysql.DefaultCharset)
resultFieldType.SetCollate(mysql.DefaultCollationName)
resultFieldType.SetFlag(0)
}
default:
panic("unexpected function: " + funcName)
}
return nil
}

// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, LEAD and LAG.
// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, COALESCE, LEAD and LAG.
func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Expression) (*types.FieldType, error) {
argsNum := len(args)
if argsNum == 0 {
Expand All @@ -198,8 +253,8 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Exp
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)
// If both arguments are NULL, make resulting type BINARY(0).
resultFieldType.SetType(mysql.TypeString)

resultFieldType.SetType(mysql.TypeNull)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
types.SetBinChsClnFlag(resultFieldType)
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ func TestOtherFunc2Pb(t *testing.T) {
pbExprs, err := ExpressionsToPBList(sc, otherFuncs, client)
require.NoError(t, err)
jsons := map[string]string{
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":128,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
ast.IsNull: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":3116,\"field_type\":{\"tp\":8,\"flag\":524417,\"flen\":1,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
}
for i, pbExpr := range pbExprs {
Expand Down
14 changes: 10 additions & 4 deletions pkg/expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,10 +1042,16 @@ func (s *InferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCase {

func (s *InferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"coalesce(c_int_d, 1)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(NULL, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 15, 3},
{"coalesce(c_int_d, c_datetime)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_int_d, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"coalesce(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_int_d, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_char, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(null, null)", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, 0},
{"coalesce(c_double_d, c_timestamp_d)", mysql.TypeVarchar, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_json, c_decimal)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, math.MaxUint32, types.UnspecifiedLength},
{"coalesce(c_time, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp + 3 + 1, 3},
{"coalesce(c_time_d, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp, 0},

{"isnull(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
{"isnull(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
Expand Down