@@ -18,6 +18,7 @@ import (
18
18
"bytes"
19
19
"context"
20
20
goJSON "encoding/json"
21
+ "strconv"
21
22
"strings"
22
23
23
24
"github.com/pingcap/errors"
@@ -108,7 +109,7 @@ func (b *builtinJSONTypeSig) Clone() builtinFunc {
108
109
}
109
110
110
111
func (c * jsonTypeFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
111
- if err := c .verifyArgs (args ); err != nil {
112
+ if err := c .verifyArgs (ctx . GetEvalCtx (), args ); err != nil {
112
113
return nil , err
113
114
}
114
115
bf , err := newBaseBuiltinFuncWithTp (ctx , c .funcName , args , types .ETString , types .ETJson )
@@ -125,6 +126,51 @@ func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression)
125
126
return sig , nil
126
127
}
127
128
129
+ func (c * jsonTypeFunctionClass ) verifyArgs (ctx EvalContext , args []Expression ) error {
130
+ if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
131
+ return err
132
+ }
133
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
134
+ }
135
+
136
+ // verifyJSONArgsType verifies that all args specified in `jsonArgsIndex` are JSON or non-binary string or NULL.
137
+ // the `useJSONErr` specifies to use `ErrIncorrectType` or `ErrInvalidTypeForJSON`. If it's true, the error will be `ErrInvalidTypeForJSON`
138
+ func verifyJSONArgsType (ctx EvalContext , funcName string , useJSONErr bool , args []Expression , jsonArgsIndex ... int ) error {
139
+ if jsonArgsIndex == nil {
140
+ // if no index is specified, verify all args
141
+ jsonArgsIndex = make ([]int , len (args ))
142
+ for i := 0 ; i < len (args ); i ++ {
143
+ jsonArgsIndex [i ] = i
144
+ }
145
+ }
146
+ for _ , argIndex := range jsonArgsIndex {
147
+ arg := args [argIndex ]
148
+
149
+ typ := arg .GetType (ctx )
150
+ if typ .GetType () == mysql .TypeNull {
151
+ continue
152
+ }
153
+
154
+ evalType := typ .EvalType ()
155
+ switch evalType {
156
+ case types .ETString :
157
+ cs := typ .GetCharset ()
158
+ if cs == charset .CharsetBin {
159
+ return types .ErrInvalidJSONCharset .GenWithStackByArgs (cs )
160
+ }
161
+ continue
162
+ case types .ETJson :
163
+ continue
164
+ default :
165
+ if useJSONErr {
166
+ return ErrInvalidTypeForJSON .GenWithStackByArgs (argIndex + 1 , funcName )
167
+ }
168
+ return ErrIncorrectType .GenWithStackByArgs (strconv .Itoa (argIndex + 1 ), funcName )
169
+ }
170
+ }
171
+ return nil
172
+ }
173
+
128
174
func (b * builtinJSONTypeSig ) evalString (ctx EvalContext , row chunk.Row ) (val string , isNull bool , err error ) {
129
175
var j types.BinaryJSON
130
176
j , isNull , err = b .args [0 ].EvalJSON (ctx , row )
@@ -155,10 +201,7 @@ func (c *jsonExtractFunctionClass) verifyArgs(ctx EvalContext, args []Expression
155
201
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
156
202
return err
157
203
}
158
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
159
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_extract" )
160
- }
161
- return nil
204
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
162
205
}
163
206
164
207
func (c * jsonExtractFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -225,10 +268,7 @@ func (c *jsonUnquoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression
225
268
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
226
269
return err
227
270
}
228
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
229
- return ErrIncorrectType .GenWithStackByArgs ("1" , "json_unquote" )
230
- }
231
- return nil
271
+ return verifyJSONArgsType (ctx , c .funcName , false , args , 0 )
232
272
}
233
273
234
274
func (c * jsonUnquoteFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -469,12 +509,7 @@ func (c *jsonMergeFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
469
509
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
470
510
return err
471
511
}
472
- for i , arg := range args {
473
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
474
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge" )
475
- }
476
- }
477
- return nil
512
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
478
513
}
479
514
480
515
type builtinJSONMergeSig struct {
@@ -682,10 +717,7 @@ func (c *jsonContainsPathFunctionClass) verifyArgs(ctx EvalContext, args []Expre
682
717
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
683
718
return err
684
719
}
685
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
686
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_contains_path" )
687
- }
688
- return nil
720
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
689
721
}
690
722
691
723
func (c * jsonContainsPathFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -801,10 +833,7 @@ func (c *jsonMemberOfFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
801
833
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
802
834
return err
803
835
}
804
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
805
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "member of" )
806
- }
807
- return nil
836
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 1 )
808
837
}
809
838
810
839
func (c * jsonMemberOfFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -867,13 +896,7 @@ func (c *jsonContainsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
867
896
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
868
897
return err
869
898
}
870
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
871
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_contains" )
872
- }
873
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
874
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_contains" )
875
- }
876
- return nil
899
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 )
877
900
}
878
901
879
902
func (c * jsonContainsFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -950,13 +973,7 @@ func (c *jsonOverlapsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
950
973
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
951
974
return err
952
975
}
953
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
954
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_overlaps" )
955
- }
956
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETJson && evalType != types .ETString {
957
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_overlaps" )
958
- }
959
- return nil
976
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 )
960
977
}
961
978
962
979
func (c * jsonOverlapsFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1283,12 +1300,7 @@ func (c *jsonMergePatchFunctionClass) verifyArgs(ctx EvalContext, args []Express
1283
1300
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1284
1301
return err
1285
1302
}
1286
- for i , arg := range args {
1287
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1288
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge_patch" )
1289
- }
1290
- }
1291
- return nil
1303
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
1292
1304
}
1293
1305
1294
1306
func (c * jsonMergePatchFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1355,12 +1367,7 @@ func (c *jsonMergePreserveFunctionClass) verifyArgs(ctx EvalContext, args []Expr
1355
1367
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1356
1368
return err
1357
1369
}
1358
- for i , arg := range args {
1359
- if evalType := arg .GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1360
- return ErrInvalidTypeForJSON .GenWithStackByArgs (i + 1 , "json_merge_preserve" )
1361
- }
1362
- }
1363
- return nil
1370
+ return verifyJSONArgsType (ctx , c .funcName , true , args )
1364
1371
}
1365
1372
1366
1373
func (c * jsonMergePreserveFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1509,10 +1516,7 @@ func (c *jsonSearchFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
1509
1516
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1510
1517
return err
1511
1518
}
1512
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1513
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_search" )
1514
- }
1515
- return nil
1519
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
1516
1520
}
1517
1521
1518
1522
func (c * jsonSearchFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1727,10 +1731,7 @@ func (c *jsonKeysFunctionClass) verifyArgs(ctx EvalContext, args []Expression) e
1727
1731
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1728
1732
return err
1729
1733
}
1730
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1731
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_keys" )
1732
- }
1733
- return nil
1734
+ return verifyJSONArgsType (ctx , c .funcName , true , args , 0 )
1734
1735
}
1735
1736
1736
1737
func (c * jsonKeysFunctionClass ) getFunction (ctx BuildContext , args []Expression ) (builtinFunc , error ) {
@@ -1902,11 +1903,9 @@ func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expres
1902
1903
if err := c .baseFunctionClass .verifyArgs (args ); err != nil {
1903
1904
return err
1904
1905
}
1905
- if evalType := args [0 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1906
- return ErrInvalidTypeForJSON .GenWithStackByArgs (1 , "json_schema_valid" )
1907
- }
1908
- if evalType := args [1 ].GetType (ctx ).EvalType (); evalType != types .ETString && evalType != types .ETJson {
1909
- return ErrInvalidTypeForJSON .GenWithStackByArgs (2 , "json_schema_valid" )
1906
+
1907
+ if err := verifyJSONArgsType (ctx , c .funcName , true , args , 0 , 1 ); err != nil {
1908
+ return err
1910
1909
}
1911
1910
if c , ok := args [0 ].(* Constant ); ok {
1912
1911
// If args[0] is NULL, then don't check the length of *both* arguments.
0 commit comments