Skip to content

Commit cee07d6

Browse files
authored
[fix](function) fix Substring/SubReplace error result with input utf8 string (#40929)
``` mysql [(none)]>select sub_replace("你好世界","a",1); +-------------------------------------+ | sub_replace('你好世界', 'a', 1) | +-------------------------------------+ | �a�好世界 | +-------------------------------------+ mysql [(none)]>select SUBSTRING('中文测试',5); +------------------------------------------+ | substring('中文测试', 5, 2147483647) | +------------------------------------------+ | 中文测试 | +------------------------------------------+ 1 row in set (0.04 sec) now mysql [(none)]>select sub_replace("你好世界","a",1); +-------------------------------------+ | sub_replace('你好世界', 'a', 1) | +-------------------------------------+ | 你a世界 | +-------------------------------------+ 1 row in set (0.05 sec) mysql [(none)]>select SUBSTRING('中文测试',5); +------------------------------------------+ | substring('中文测试', 5, 2147483647) | +------------------------------------------+ | | +------------------------------------------+ 1 row in set (0.13 sec) ```
1 parent 538817a commit cee07d6

File tree

5 files changed

+188
-37
lines changed

5 files changed

+188
-37
lines changed

be/src/vec/functions/function_string.h

Lines changed: 95 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,11 @@ struct SubstringUtil {
242242
const char* str_data = (char*)chars.data() + offsets[i - 1];
243243
int start_value = is_const ? start[0] : start[i];
244244
int len_value = is_const ? len[0] : len[i];
245-
245+
// Unsigned numbers cannot be used here because start_value can be negative.
246+
int char_len = simd::VStringFunctions::get_char_len(str_data, str_size);
246247
// return empty string if start > src.length
247-
if (start_value > str_size || str_size == 0 || start_value == 0 || len_value <= 0) {
248+
// Here, start_value is compared against the length of the character.
249+
if (start_value > char_len || str_size == 0 || start_value == 0 || len_value <= 0) {
248250
StringOP::push_empty_string(i, res_chars, res_offsets);
249251
continue;
250252
}
@@ -3386,8 +3388,6 @@ class FunctionSubReplace : public IFunction {
33863388
return get_variadic_argument_types_impl().size();
33873389
}
33883390

3389-
bool use_default_implementation_for_nulls() const override { return false; }
3390-
33913391
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
33923392
size_t result, size_t input_rows_count) const override {
33933393
return Impl::execute_impl(context, block, arguments, result, input_rows_count);
@@ -3398,59 +3398,116 @@ struct SubReplaceImpl {
33983398
static Status replace_execute(Block& block, const ColumnNumbers& arguments, size_t result,
33993399
size_t input_rows_count) {
34003400
auto res_column = ColumnString::create();
3401-
auto result_column = assert_cast<ColumnString*>(res_column.get());
3401+
auto* result_column = assert_cast<ColumnString*>(res_column.get());
34023402
auto args_null_map = ColumnUInt8::create(input_rows_count, 0);
34033403
ColumnPtr argument_columns[4];
3404+
bool col_const[4];
34043405
for (int i = 0; i < 4; ++i) {
3405-
argument_columns[i] =
3406-
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
3407-
if (auto* nullable = check_and_get_column<ColumnNullable>(*argument_columns[i])) {
3408-
// Danger: Here must dispose the null map data first! Because
3409-
// argument_columns[i]=nullable->get_nested_column_ptr(); will release the mem
3410-
// of column nullable mem of null map
3411-
VectorizedUtils::update_null_map(args_null_map->get_data(),
3412-
nullable->get_null_map_data());
3413-
argument_columns[i] = nullable->get_nested_column_ptr();
3414-
}
3406+
std::tie(argument_columns[i], col_const[i]) =
3407+
unpack_if_const(block.get_by_position(arguments[i]).column);
34153408
}
3409+
const auto* data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
3410+
const auto* mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
3411+
const auto* start_column =
3412+
assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
3413+
const auto* length_column =
3414+
assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());
34163415

3417-
auto data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
3418-
auto mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
3419-
auto start_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
3420-
auto length_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());
3421-
3422-
vector(data_column, mask_column, start_column->get_data(), length_column->get_data(),
3423-
args_null_map->get_data(), result_column, input_rows_count);
3424-
3416+
std::visit(
3417+
[&](auto origin_str_const, auto new_str_const, auto start_const, auto len_const) {
3418+
if (simd::VStringFunctions::is_ascii(
3419+
StringRef {data_column->get_chars().data(), data_column->size()})) {
3420+
vector_ascii<origin_str_const, new_str_const, start_const, len_const>(
3421+
data_column, mask_column, start_column->get_data(),
3422+
length_column->get_data(), args_null_map->get_data(), result_column,
3423+
input_rows_count);
3424+
} else {
3425+
vector_utf8<origin_str_const, new_str_const, start_const, len_const>(
3426+
data_column, mask_column, start_column->get_data(),
3427+
length_column->get_data(), args_null_map->get_data(), result_column,
3428+
input_rows_count);
3429+
}
3430+
},
3431+
vectorized::make_bool_variant(col_const[0]),
3432+
vectorized::make_bool_variant(col_const[1]),
3433+
vectorized::make_bool_variant(col_const[2]),
3434+
vectorized::make_bool_variant(col_const[3]));
34253435
block.get_by_position(result).column =
34263436
ColumnNullable::create(std::move(res_column), std::move(args_null_map));
34273437
return Status::OK();
34283438
}
34293439

34303440
private:
3431-
static void vector(const ColumnString* data_column, const ColumnString* mask_column,
3432-
const PaddedPODArray<Int32>& start, const PaddedPODArray<Int32>& length,
3433-
NullMap& args_null_map, ColumnString* result_column,
3434-
size_t input_rows_count) {
3441+
template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
3442+
static void vector_ascii(const ColumnString* data_column, const ColumnString* mask_column,
3443+
const PaddedPODArray<Int32>& args_start,
3444+
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
3445+
ColumnString* result_column, size_t input_rows_count) {
34353446
ColumnString::Chars& res_chars = result_column->get_chars();
34363447
ColumnString::Offsets& res_offsets = result_column->get_offsets();
34373448
for (size_t row = 0; row < input_rows_count; ++row) {
3438-
StringRef origin_str = data_column->get_data_at(row);
3439-
StringRef new_str = mask_column->get_data_at(row);
3440-
size_t origin_str_len = origin_str.size;
3449+
StringRef origin_str =
3450+
data_column->get_data_at(index_check_const<origin_str_const>(row));
3451+
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
3452+
const auto start = args_start[index_check_const<start_const>(row)];
3453+
const auto length = args_length[index_check_const<len_const>(row)];
3454+
const size_t origin_str_len = origin_str.size;
34413455
//input is null, start < 0, len < 0, str_size <= start. return NULL
3442-
if (args_null_map[row] || start[row] < 0 || length[row] < 0 ||
3443-
origin_str_len <= start[row]) {
3456+
if (args_null_map[row] || start < 0 || length < 0 || origin_str_len <= start) {
34443457
res_offsets.push_back(res_chars.size());
34453458
args_null_map[row] = 1;
34463459
} else {
34473460
std::string_view replace_str = new_str.to_string_view();
34483461
std::string result = origin_str.to_string();
3449-
result.replace(start[row], length[row], replace_str);
3462+
result.replace(start, length, replace_str);
34503463
result_column->insert_data(result.data(), result.length());
34513464
}
34523465
}
34533466
}
3467+
3468+
template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
3469+
static void vector_utf8(const ColumnString* data_column, const ColumnString* mask_column,
3470+
const PaddedPODArray<Int32>& args_start,
3471+
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
3472+
ColumnString* result_column, size_t input_rows_count) {
3473+
ColumnString::Chars& res_chars = result_column->get_chars();
3474+
ColumnString::Offsets& res_offsets = result_column->get_offsets();
3475+
3476+
for (size_t row = 0; row < input_rows_count; ++row) {
3477+
StringRef origin_str =
3478+
data_column->get_data_at(index_check_const<origin_str_const>(row));
3479+
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
3480+
const auto start = args_start[index_check_const<start_const>(row)];
3481+
const auto length = args_length[index_check_const<len_const>(row)];
3482+
//input is null, start < 0, len < 0 return NULL
3483+
if (args_null_map[row] || start < 0 || length < 0) {
3484+
res_offsets.push_back(res_chars.size());
3485+
args_null_map[row] = 1;
3486+
continue;
3487+
}
3488+
3489+
const auto [start_byte_len, start_char_len] =
3490+
simd::VStringFunctions::iterate_utf8_with_limit_length(origin_str.begin(),
3491+
origin_str.end(), start);
3492+
3493+
// start >= orgin.size
3494+
DCHECK(start_char_len <= start);
3495+
if (start_byte_len == origin_str.size) {
3496+
res_offsets.push_back(res_chars.size());
3497+
args_null_map[row] = 1;
3498+
continue;
3499+
}
3500+
3501+
auto [end_byte_len, end_char_len] =
3502+
simd::VStringFunctions::iterate_utf8_with_limit_length(
3503+
origin_str.begin() + start_byte_len, origin_str.end(), length);
3504+
DCHECK(end_char_len <= length);
3505+
std::string_view replace_str = new_str.to_string_view();
3506+
std::string result = origin_str.to_string();
3507+
result.replace(start_byte_len, end_byte_len, replace_str);
3508+
result_column->insert_data(result.data(), result.length());
3509+
}
3510+
}
34543511
};
34553512

34563513
struct SubReplaceThreeImpl {
@@ -3467,13 +3524,14 @@ struct SubReplaceThreeImpl {
34673524

34683525
auto str_col =
34693526
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
3470-
if (auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
3527+
if (const auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
34713528
str_col = nullable->get_nested_column_ptr();
34723529
}
3473-
auto& str_offset = assert_cast<const ColumnString*>(str_col.get())->get_offsets();
3474-
3530+
const auto* str_column = assert_cast<const ColumnString*>(str_col.get());
3531+
// use utf8 len
34753532
for (int i = 0; i < input_rows_count; ++i) {
3476-
strlen_data[i] = str_offset[i] - str_offset[i - 1];
3533+
StringRef str_ref = str_column->get_data_at(i);
3534+
strlen_data[i] = simd::VStringFunctions::get_char_len(str_ref.data, str_ref.size);
34773535
}
34783536

34793537
block.insert({std::move(params), std::make_shared<DataTypeInt32>(), "strlen"});

regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,63 @@ tNEW-STRorigin str
386386
-- !sql --
387387
d***is
388388

389+
-- !sub_replace_utf8_sql1 --
390+
你a世界
391+
392+
-- !sub_replace_utf8_sql2 --
393+
你ab界
394+
395+
-- !sub_replace_utf8_sql3 --
396+
你ab
397+
398+
-- !sub_replace_utf8_sql4 --
399+
你abcd我界
400+
401+
-- !sub_replace_utf8_sql5 --
402+
\N
403+
404+
-- !sub_replace_utf8_sql6 --
405+
大家世界
406+
407+
-- !sub_replace_utf8_sql7 --
408+
你大家114514
409+
410+
-- !sub_replace_utf8_sql8 --
411+
\N
412+
413+
-- !sub_replace_utf8_sql9 --
414+
\N
415+
416+
-- !sub_replace_utf8_sql10 --
417+
\N
418+
419+
-- !sub_replace_utf8_sql1 --
420+
你a世界
421+
422+
-- !sub_replace_utf8_sql2 --
423+
你ab界
424+
425+
-- !sub_replace_utf8_sql3 --
426+
你ab
427+
428+
-- !sub_replace_utf8_sql4 --
429+
你abcd我界
430+
431+
-- !sub_replace_utf8_sql5 --
432+
\N
433+
434+
-- !sub_replace_utf8_sql6 --
435+
大家世界
436+
437+
-- !sub_replace_utf8_sql7 --
438+
你大家114514
439+
440+
-- !sub_replace_utf8_sql8 --
441+
\N
442+
443+
-- !sub_replace_utf8_sql9 --
444+
\N
445+
446+
-- !sub_replace_utf8_sql10 --
447+
\N
448+
248 Bytes
Binary file not shown.

regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,4 +191,27 @@ suite("test_string_function") {
191191

192192
qt_sql "select sub_replace(\"this is origin str\",\"NEW-STR\",1);"
193193
qt_sql "select sub_replace(\"doris\",\"***\",1,2);"
194+
sql """ set debug_skip_fold_constant = true;"""
195+
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
196+
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
197+
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
198+
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
199+
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
200+
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
201+
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
202+
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
203+
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
204+
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"
205+
sql """ set debug_skip_fold_constant = false;"""
206+
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
207+
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
208+
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
209+
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
210+
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
211+
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
212+
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
213+
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
214+
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
215+
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"
216+
194217
}

regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,16 @@ suite("test_string_function", "arrow_flight_sql") {
228228
qt_sql "select substring('abcdef',3,-1);"
229229
qt_sql "select substring('abcdef',-3,-1);"
230230
qt_sql "select substring('abcdef',10,1);"
231+
sql """ set debug_skip_fold_constant = true;"""
232+
qt_substring_utf8_sql "select substring('中文测试',5);"
233+
qt_substring_utf8_sql "select substring('中文测试',4);"
234+
qt_substring_utf8_sql "select substring('中文测试',2,2);"
235+
qt_substring_utf8_sql "select substring('中文测试',-1,2);"
236+
sql """ set debug_skip_fold_constant = false;"""
237+
qt_substring_utf8_sql "select substring('中文测试',5);"
238+
qt_substring_utf8_sql "select substring('中文测试',4);"
239+
qt_substring_utf8_sql "select substring('中文测试',2,2);"
240+
qt_substring_utf8_sql "select substring('中文测试',-1,2);"
231241

232242
sql """ drop table if exists test_string_function; """
233243
sql """ create table test_string_function (

0 commit comments

Comments
 (0)