@@ -18,6 +18,7 @@ import (
18
18
"bytes"
19
19
"context"
20
20
goJSON "encoding/json"
21
+ "strconv"
21
22
"strings"
22
23
23
24
"github.com/pingcap/errors"
@@ -108,7 +109,7 @@ func (b *builtinJSONTypeSig) Clone() builtinFunc {
108
109
}
109
110
110
111
func (c * jsonTypeFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
111
- if err := c .verifyArgs (args ); err != nil {
112
+ if err := c .verifyArgs (ctx . GetEvalCtx (), args ); err != nil {
112
113
return nil , err
113
114
}
114
115
bf , err := newBaseBuiltinFuncWithTp (ctx , c .funcName , args , types .ETString , types .ETJson )
@@ -125,6 +126,51 @@ func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression)
125
126
return sig , nil
126
127
}
127
128
129
+ func (c * jsonTypeFunctionClass ) verifyArgs (ctx EvalContext , args []Expression ) error {
130
+ if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
131
+ return err
132
+ }
133
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
134
+ }
135
+
136
+ // verifyJSONArgsType verifies that all args specified in `jsonArgsIndex` are JSON or non-binary string or NULL.
137
+ // the `useJSONErr` specifies to use `ErrIncorrectType` or `ErrInvalidTypeForJSON`. If it's true, the error will be `ErrInvalidTypeForJSON`
138
+ func verifyJSONArgsType (ctx EvalContext , funcName string , useJSONErr bool , args []Expression , jsonArgsIndex ... int ) error {
139
+ if jsonArgsIndex == nil {
140
+ // if no index is specified, verify all args
141
+ jsonArgsIndex = make ([]int , len (args ))
142
+ for i := 0 ; i < len (args ); i ++ {
143
+ jsonArgsIndex [i ] = i
144
+ }
145
+ }
146
+ for _ , argIndex := range jsonArgsIndex {
147
+ arg := args [argIndex ]
148
+
149
+ typ := arg .GetType (ctx )
150
+ if typ .GetType () == mysql .TypeNull {
151
+ continue
152
+ }
153
+
154
+ evalType := typ .EvalType ()
155
+ switch evalType {
156
+ case types .ETString :
157
+ cs := typ .GetCharset ()
158
+ if cs == charset .CharsetBin {
159
+ return types .ErrInvalidJSONCharset .GenWithStackByArgs (cs )
160
+ }
161
+ continue
162
+ case types .ETJson :
163
+ continue
164
+ default :
165
+ if useJSONErr {
166
+ return ErrInvalidTypeForJSON .GenWithStackByArgs (argIndex + 1 , funcName )
167
+ }
168
+ return ErrIncorrectType .GenWithStackByArgs (strconv .Itoa (argIndex + 1 ), funcName )
169
+ }
170
+ }
171
+ return nil
172
+ }
173
+
128
174
func (b * builtinJSONTypeSig ) evalString (ctx EvalContext , row chunk.Row ) (val string , isNull bool , err error ) {
129
175
var j types.BinaryJSON
130
176
j , isNull , err = b .args [0 ].EvalJSON (ctx , row )
@@ -155,10 +201,7 @@ func (c *jsonExtractFunctionClass) verifyArgs(ctx EvalContext, args []Expression
155
201
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
156
202
return err
157
203
}
158
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
159
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_extract" )
160
- }
161
- return nil
204
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
162
205
}
163
206
164
207
func (c * jsonExtractFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -225,10 +268,7 @@ func (c *jsonUnquoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression
225
268
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
226
269
return err
227
270
}
228
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
229
- return ErrIncorrectType .GenWithStackByArgs ("1" , "json_unquote" )
230
- }
231
- return nil
271
+ return verifyJSONArgsType (ctx , c .funcName , false , args , 0 )
232
272
}
233
273
234
274
func (c * jsonUnquoteFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -469,12 +509,7 @@ func (c *jsonMergeFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
469
509
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
470
510
return err
471
511
}
472
- for i , arg := range args {
473
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
474
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge" )
475
- }
476
- }
477
- return nil
512
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
478
513
}
479
514
480
515
type builtinJSONMergeSig struct {
@@ -682,10 +717,7 @@ func (c *jsonContainsPathFunctionClass) verifyArgs(ctx EvalContext, args []Expre
682
717
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
683
718
return err
684
719
}
685
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
686
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_contains_path" )
687
- }
688
- return nil
720
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
689
721
}
690
722
691
723
func (c * jsonContainsPathFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -801,10 +833,7 @@ func (c *jsonMemberOfFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
801
833
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
802
834
return err
803
835
}
804
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
805
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "member of" )
806
- }
807
- return nil
836
+ return verifyJSONArgsType (ctx , "member of" , true , args , 1 )
808
837
}
809
838
810
839
func (c * jsonMemberOfFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -867,13 +896,7 @@ func (c *jsonContainsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
867
896
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
868
897
return err
869
898
}
870
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
871
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_contains" )
872
- }
873
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
874
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_contains" )
875
- }
876
- return nil
899
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 )
877
900
}
878
901
879
902
func (c * jsonContainsFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -950,13 +973,7 @@ func (c *jsonOverlapsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
950
973
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
951
974
return err
952
975
}
953
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
954
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_overlaps" )
955
- }
956
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
957
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_overlaps" )
958
- }
959
- return nil
976
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 )
960
977
}
961
978
962
979
func (c * jsonOverlapsFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1284,12 +1301,7 @@ func (c *jsonMergePatchFunctionClass) verifyArgs(ctx EvalContext, args []Express
1284
1301
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1285
1302
return err
1286
1303
}
1287
- for i , arg := range args {
1288
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1289
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge_patch" )
1290
- }
1291
- }
1292
- return nil
1304
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
1293
1305
}
1294
1306
1295
1307
func (c * jsonMergePatchFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1356,12 +1368,7 @@ func (c *jsonMergePreserveFunctionClass) verifyArgs(ctx EvalContext, args []Expr
1356
1368
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1357
1369
return err
1358
1370
}
1359
- for i , arg := range args {
1360
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1361
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge_preserve" )
1362
- }
1363
- }
1364
- return nil
1371
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
1365
1372
}
1366
1373
1367
1374
func (c * jsonMergePreserveFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1510,10 +1517,7 @@ func (c *jsonSearchFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
1510
1517
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1511
1518
return err
1512
1519
}
1513
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1514
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_search" )
1515
- }
1516
- return nil
1520
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
1517
1521
}
1518
1522
1519
1523
func (c * jsonSearchFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1728,10 +1732,7 @@ func (c *jsonKeysFunctionClass) verifyArgs(ctx EvalContext, args []Expression) e
1728
1732
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1729
1733
return err
1730
1734
}
1731
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1732
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_keys" )
1733
- }
1734
- return nil
1735
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
1735
1736
}
1736
1737
1737
1738
func (c * jsonKeysFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1903,11 +1904,9 @@ func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expres
1903
1904
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1904
1905
return err
1905
1906
}
1906
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1907
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_schema_valid" )
1908
- }
1909
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1910
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_schema_valid" )
1907
+
1908
+ if err := verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 ); err != nil {
1909
+ return err
1911
1910
}
1912
1911
if c , ok := args [0 ].(* Constant ); ok {
1913
1912
// If args[0] is NULL, then don't check the length of *both* arguments.
0 commit comments