@@ -18,6 +18,7 @@ import (
18
18
"bytes"
19
19
"context"
20
20
goJSON "encoding/json"
21
+ "strconv"
21
22
"strings"
22
23
23
24
"github.com/pingcap/errors"
@@ -105,7 +106,7 @@ func (b *builtinJSONTypeSig) Clone() builtinFunc {
105
106
}
106
107
107
108
func (c * jsonTypeFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
108
- if err := c .verifyArgs (args ); err != nil {
109
+ if err := c .verifyArgs (ctx . GetEvalCtx (), args ); err != nil {
109
110
return nil , err
110
111
}
111
112
bf , err := newBaseBuiltinFuncWithTp (ctx , c .funcName , args , types .ETString , types .ETJson )
@@ -122,6 +123,51 @@ func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression)
122
123
return sig , nil
123
124
}
124
125
126
+ func (c * jsonTypeFunctionClass ) verifyArgs (ctx EvalContext , args []Expression ) error {
127
+ if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
128
+ return err
129
+ }
130
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
131
+ }
132
+
133
+ // verifyJSONArgsType verifies that all args specified in `jsonArgsIndex` are JSON or non-binary string or NULL.
134
+ // the `useJSONErr` specifies to use `ErrIncorrectType` or `ErrInvalidTypeForJSON`. If it's true, the error will be `ErrInvalidTypeForJSON`
135
+ func verifyJSONArgsType (ctx EvalContext , funcName string , useJSONErr bool , args []Expression , jsonArgsIndex ... int ) error {
136
+ if jsonArgsIndex == nil {
137
+ // if no index is specified, verify all args
138
+ jsonArgsIndex = make ([]int , len (args ))
139
+ for i := 0 ; i < len (args ); i ++ {
140
+ jsonArgsIndex [i ] = i
141
+ }
142
+ }
143
+ for _ , argIndex := range jsonArgsIndex {
144
+ arg := args [argIndex ]
145
+
146
+ typ := arg .GetType (ctx )
147
+ if typ .GetType () == mysql .TypeNull {
148
+ continue
149
+ }
150
+
151
+ evalType := typ .EvalType ()
152
+ switch evalType {
153
+ case types .ETString :
154
+ cs := typ .GetCharset ()
155
+ if cs == charset .CharsetBin {
156
+ return types .ErrInvalidJSONCharset .GenWithStackByArgs (cs )
157
+ }
158
+ continue
159
+ case types .ETJson :
160
+ continue
161
+ default :
162
+ if useJSONErr {
163
+ return ErrInvalidTypeForJSON .GenWithStackByArgs (argIndex + 1 , funcName )
164
+ }
165
+ return ErrIncorrectType .GenWithStackByArgs (strconv .Itoa (argIndex + 1 ), funcName )
166
+ }
167
+ }
168
+ return nil
169
+ }
170
+
125
171
func (b * builtinJSONTypeSig ) evalString (ctx EvalContext , row chunk.Row ) (val string , isNull bool , err error ) {
126
172
var j types.BinaryJSON
127
173
j , isNull , err = b .args [0 ].EvalJSON (ctx , row )
@@ -149,10 +195,7 @@ func (c *jsonExtractFunctionClass) verifyArgs(ctx EvalContext, args []Expression
149
195
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
150
196
return err
151
197
}
152
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
153
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_extract" )
154
- }
155
- return nil
198
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
156
199
}
157
200
158
201
func (c * jsonExtractFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -216,10 +259,7 @@ func (c *jsonUnquoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression
216
259
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
217
260
return err
218
261
}
219
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
220
- return ErrIncorrectType .GenWithStackByArgs ("1" , "json_unquote" )
221
- }
222
- return nil
262
+ return verifyJSONArgsType (ctx , c .funcName , false , args , 0 )
223
263
}
224
264
225
265
func (c * jsonUnquoteFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -448,12 +488,7 @@ func (c *jsonMergeFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
448
488
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
449
489
return err
450
490
}
451
- for i , arg := range args {
452
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
453
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge" )
454
- }
455
- }
456
- return nil
491
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
457
492
}
458
493
459
494
type builtinJSONMergeSig struct {
@@ -649,10 +684,7 @@ func (c *jsonContainsPathFunctionClass) verifyArgs(ctx EvalContext, args []Expre
649
684
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
650
685
return err
651
686
}
652
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
653
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_contains_path" )
654
- }
655
- return nil
687
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
656
688
}
657
689
658
690
func (c * jsonContainsPathFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -765,10 +797,7 @@ func (c *jsonMemberOfFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
765
797
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
766
798
return err
767
799
}
768
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
769
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "member of" )
770
- }
771
- return nil
800
+ return verifyJSONArgsType (ctx , "member of" , true , args , 1 )
772
801
}
773
802
774
803
func (c * jsonMemberOfFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -828,13 +857,7 @@ func (c *jsonContainsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
828
857
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
829
858
return err
830
859
}
831
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
832
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_contains" )
833
- }
834
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
835
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_contains" )
836
- }
837
- return nil
860
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 )
838
861
}
839
862
840
863
func (c * jsonContainsFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -908,13 +931,7 @@ func (c *jsonOverlapsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
908
931
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
909
932
return err
910
933
}
911
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
912
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_overlaps" )
913
- }
914
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
915
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_overlaps" )
916
- }
917
- return nil
934
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 )
918
935
}
919
936
920
937
func (c * jsonOverlapsFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1227,12 +1244,7 @@ func (c *jsonMergePatchFunctionClass) verifyArgs(ctx EvalContext, args []Express
1227
1244
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1228
1245
return err
1229
1246
}
1230
- for i , arg := range args {
1231
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1232
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge_patch" )
1233
- }
1234
- }
1235
- return nil
1247
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
1236
1248
}
1237
1249
1238
1250
func (c * jsonMergePatchFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1296,12 +1308,7 @@ func (c *jsonMergePreserveFunctionClass) verifyArgs(ctx EvalContext, args []Expr
1296
1308
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1297
1309
return err
1298
1310
}
1299
- for i , arg := range args {
1300
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1301
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge_preserve" )
1302
- }
1303
- }
1304
- return nil
1311
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
1305
1312
}
1306
1313
1307
1314
func (c * jsonMergePreserveFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1441,10 +1448,7 @@ func (c *jsonSearchFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
1441
1448
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1442
1449
return err
1443
1450
}
1444
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1445
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_search" )
1446
- }
1447
- return nil
1451
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
1448
1452
}
1449
1453
1450
1454
func (c * jsonSearchFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1650,10 +1654,7 @@ func (c *jsonKeysFunctionClass) verifyArgs(ctx EvalContext, args []Expression) e
1650
1654
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1651
1655
return err
1652
1656
}
1653
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1654
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_keys" )
1655
- }
1656
- return nil
1657
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
1657
1658
}
1658
1659
1659
1660
func (c * jsonKeysFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1816,11 +1817,9 @@ func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expres
1816
1817
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1817
1818
return err
1818
1819
}
1819
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1820
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_schema_valid" )
1821
- }
1822
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1823
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_schema_valid" )
1820
+
1821
+ if err := verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 ); err != nil {
1822
+ return err
1824
1823
}
1825
1824
if c , ok := args [0 ].(* Constant ); ok {
1826
1825
// If args[0] is NULL, then don't check the length of *both* arguments.
0 commit comments