Skip to content

Commit 0e5c428

Browse files
authored
[fix](function) fix Substring/SubReplace error result with input utf8… (#40954)
… string (#40929) #40929 ``` mysql [(none)]>select sub_replace("你好世界","a",1); +-------------------------------------+ | sub_replace('你好世界', 'a', 1) | +-------------------------------------+ | �a�好世界 | +-------------------------------------+ mysql [(none)]>select SUBSTRING('中文测试',5); +------------------------------------------+ | substring('中文测试', 5, 2147483647) | +------------------------------------------+ | 中文测试 | +------------------------------------------+ 1 row in set (0.04 sec) now mysql [(none)]>select sub_replace("你好世界","a",1); +-------------------------------------+ | sub_replace('你好世界', 'a', 1) | +-------------------------------------+ | 你a世界 | +-------------------------------------+ 1 row in set (0.05 sec) mysql [(none)]>select SUBSTRING('中文测试',5); +------------------------------------------+ | substring('中文测试', 5, 2147483647) | +------------------------------------------+ | | +------------------------------------------+ 1 row in set (0.13 sec) ``` ## Proposed changes Issue Number: close #xxx <!--Describe your changes.-->
1 parent d44ee1c commit 0e5c428

File tree

4 files changed

+194
-38
lines changed

4 files changed

+194
-38
lines changed

be/src/util/simd/vstring_function.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,23 @@ class VStringFunctions {
224224
}
225225
}
226226

227+
// Iterate a UTF-8 string without exceeding a given length n.
228+
// The function returns two values:
229+
// the first represents the byte length traversed, and the second represents the char length traversed.
230+
static inline std::pair<size_t, size_t> iterate_utf8_with_limit_length(const char* begin,
231+
const char* end,
232+
size_t n) {
233+
const char* p = begin;
234+
int char_size = 0;
235+
236+
size_t i = 0;
237+
for (; i < n && p < end; ++i, p += char_size) {
238+
char_size = UTF8_BYTE_LENGTH[static_cast<uint8_t>(*p)];
239+
}
240+
241+
return {p - begin, i};
242+
}
243+
227244
static void hex_encode(const unsigned char* src_str, size_t length, char* dst_str) {
228245
static constexpr auto hex_table = "0123456789ABCDEF";
229246
auto src_str_end = src_str + length;

be/src/vec/functions/function_string.h

Lines changed: 96 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,11 @@ struct SubstringUtil {
216216
const char* str_data = (char*)chars.data() + offsets[i - 1];
217217
int start_value = is_const ? start[0] : start[i];
218218
int len_value = is_const ? len[0] : len[i];
219-
219+
// Unsigned numbers cannot be used here because start_value can be negative.
220+
int char_len = simd::VStringFunctions::get_char_len(str_data, str_size);
220221
// return empty string if start > src.length
221-
if (start_value > str_size || str_size == 0 || start_value == 0 || len_value <= 0) {
222+
// Here, start_value is compared against the length of the character.
223+
if (start_value > char_len || str_size == 0 || start_value == 0 || len_value <= 0) {
222224
StringOP::push_empty_string(i, res_chars, res_offsets);
223225
continue;
224226
}
@@ -3728,8 +3730,6 @@ class FunctionSubReplace : public IFunction {
37283730
return get_variadic_argument_types_impl().size();
37293731
}
37303732

3731-
bool use_default_implementation_for_nulls() const override { return false; }
3732-
37333733
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
37343734
size_t result, size_t input_rows_count) const override {
37353735
return Impl::execute_impl(context, block, arguments, result, input_rows_count);
@@ -3740,59 +3740,116 @@ struct SubReplaceImpl {
37403740
static Status replace_execute(Block& block, const ColumnNumbers& arguments, size_t result,
37413741
size_t input_rows_count) {
37423742
auto res_column = ColumnString::create();
3743-
auto result_column = assert_cast<ColumnString*>(res_column.get());
3743+
auto* result_column = assert_cast<ColumnString*>(res_column.get());
37443744
auto args_null_map = ColumnUInt8::create(input_rows_count, 0);
37453745
ColumnPtr argument_columns[4];
3746+
bool col_const[4];
37463747
for (int i = 0; i < 4; ++i) {
3747-
argument_columns[i] =
3748-
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
3749-
if (auto* nullable = check_and_get_column<ColumnNullable>(*argument_columns[i])) {
3750-
// Danger: Here must dispose the null map data first! Because
3751-
// argument_columns[i]=nullable->get_nested_column_ptr(); will release the mem
3752-
// of column nullable mem of null map
3753-
VectorizedUtils::update_null_map(args_null_map->get_data(),
3754-
nullable->get_null_map_data());
3755-
argument_columns[i] = nullable->get_nested_column_ptr();
3756-
}
3748+
std::tie(argument_columns[i], col_const[i]) =
3749+
unpack_if_const(block.get_by_position(arguments[i]).column);
37573750
}
3758-
3759-
auto data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
3760-
auto mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
3761-
auto start_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
3762-
auto length_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());
3763-
3764-
vector(data_column, mask_column, start_column->get_data(), length_column->get_data(),
3765-
args_null_map->get_data(), result_column, input_rows_count);
3766-
3751+
const auto* data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
3752+
const auto* mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
3753+
const auto* start_column =
3754+
assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
3755+
const auto* length_column =
3756+
assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());
3757+
3758+
std::visit(
3759+
[&](auto origin_str_const, auto new_str_const, auto start_const, auto len_const) {
3760+
if (simd::VStringFunctions::is_ascii(
3761+
StringRef {data_column->get_chars().data(), data_column->size()})) {
3762+
vector_ascii<origin_str_const, new_str_const, start_const, len_const>(
3763+
data_column, mask_column, start_column->get_data(),
3764+
length_column->get_data(), args_null_map->get_data(), result_column,
3765+
input_rows_count);
3766+
} else {
3767+
vector_utf8<origin_str_const, new_str_const, start_const, len_const>(
3768+
data_column, mask_column, start_column->get_data(),
3769+
length_column->get_data(), args_null_map->get_data(), result_column,
3770+
input_rows_count);
3771+
}
3772+
},
3773+
vectorized::make_bool_variant(col_const[0]),
3774+
vectorized::make_bool_variant(col_const[1]),
3775+
vectorized::make_bool_variant(col_const[2]),
3776+
vectorized::make_bool_variant(col_const[3]));
37673777
block.get_by_position(result).column =
37683778
ColumnNullable::create(std::move(res_column), std::move(args_null_map));
37693779
return Status::OK();
37703780
}
37713781

37723782
private:
3773-
static void vector(const ColumnString* data_column, const ColumnString* mask_column,
3774-
const PaddedPODArray<Int32>& start, const PaddedPODArray<Int32>& length,
3775-
NullMap& args_null_map, ColumnString* result_column,
3776-
size_t input_rows_count) {
3783+
template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
3784+
static void vector_ascii(const ColumnString* data_column, const ColumnString* mask_column,
3785+
const PaddedPODArray<Int32>& args_start,
3786+
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
3787+
ColumnString* result_column, size_t input_rows_count) {
37773788
ColumnString::Chars& res_chars = result_column->get_chars();
37783789
ColumnString::Offsets& res_offsets = result_column->get_offsets();
37793790
for (size_t row = 0; row < input_rows_count; ++row) {
3780-
StringRef origin_str = data_column->get_data_at(row);
3781-
StringRef new_str = mask_column->get_data_at(row);
3782-
size_t origin_str_len = origin_str.size;
3791+
StringRef origin_str =
3792+
data_column->get_data_at(index_check_const<origin_str_const>(row));
3793+
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
3794+
const auto start = args_start[index_check_const<start_const>(row)];
3795+
const auto length = args_length[index_check_const<len_const>(row)];
3796+
const size_t origin_str_len = origin_str.size;
37833797
//input is null, start < 0, len < 0, str_size <= start. return NULL
3784-
if (args_null_map[row] || start[row] < 0 || length[row] < 0 ||
3785-
origin_str_len <= start[row]) {
3798+
if (args_null_map[row] || start < 0 || length < 0 || origin_str_len <= start) {
37863799
res_offsets.push_back(res_chars.size());
37873800
args_null_map[row] = 1;
37883801
} else {
37893802
std::string_view replace_str = new_str.to_string_view();
37903803
std::string result = origin_str.to_string();
3791-
result.replace(start[row], length[row], replace_str);
3804+
result.replace(start, length, replace_str);
37923805
result_column->insert_data(result.data(), result.length());
37933806
}
37943807
}
37953808
}
3809+
3810+
template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
3811+
static void vector_utf8(const ColumnString* data_column, const ColumnString* mask_column,
3812+
const PaddedPODArray<Int32>& args_start,
3813+
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
3814+
ColumnString* result_column, size_t input_rows_count) {
3815+
ColumnString::Chars& res_chars = result_column->get_chars();
3816+
ColumnString::Offsets& res_offsets = result_column->get_offsets();
3817+
3818+
for (size_t row = 0; row < input_rows_count; ++row) {
3819+
StringRef origin_str =
3820+
data_column->get_data_at(index_check_const<origin_str_const>(row));
3821+
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
3822+
const auto start = args_start[index_check_const<start_const>(row)];
3823+
const auto length = args_length[index_check_const<len_const>(row)];
3824+
//input is null, start < 0, len < 0 return NULL
3825+
if (args_null_map[row] || start < 0 || length < 0) {
3826+
res_offsets.push_back(res_chars.size());
3827+
args_null_map[row] = 1;
3828+
continue;
3829+
}
3830+
3831+
const auto [start_byte_len, start_char_len] =
3832+
simd::VStringFunctions::iterate_utf8_with_limit_length(origin_str.begin(),
3833+
origin_str.end(), start);
3834+
3835+
// start >= orgin.size
3836+
DCHECK(start_char_len <= start);
3837+
if (start_byte_len == origin_str.size) {
3838+
res_offsets.push_back(res_chars.size());
3839+
args_null_map[row] = 1;
3840+
continue;
3841+
}
3842+
3843+
auto [end_byte_len, end_char_len] =
3844+
simd::VStringFunctions::iterate_utf8_with_limit_length(
3845+
origin_str.begin() + start_byte_len, origin_str.end(), length);
3846+
DCHECK(end_char_len <= length);
3847+
std::string_view replace_str = new_str.to_string_view();
3848+
std::string result = origin_str.to_string();
3849+
result.replace(start_byte_len, end_byte_len, replace_str);
3850+
result_column->insert_data(result.data(), result.length());
3851+
}
3852+
}
37963853
};
37973854

37983855
struct SubReplaceThreeImpl {
@@ -3809,13 +3866,14 @@ struct SubReplaceThreeImpl {
38093866

38103867
auto str_col =
38113868
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
3812-
if (auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
3869+
if (const auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
38133870
str_col = nullable->get_nested_column_ptr();
38143871
}
3815-
auto& str_offset = assert_cast<const ColumnString*>(str_col.get())->get_offsets();
3816-
3872+
const auto* str_column = assert_cast<const ColumnString*>(str_col.get());
3873+
// use utf8 len
38173874
for (int i = 0; i < input_rows_count; ++i) {
3818-
strlen_data[i] = str_offset[i] - str_offset[i - 1];
3875+
StringRef str_ref = str_column->get_data_at(i);
3876+
strlen_data[i] = simd::VStringFunctions::get_char_len(str_ref.data, str_ref.size);
38193877
}
38203878

38213879
block.insert({std::move(params), std::make_shared<DataTypeInt32>(), "strlen"});

regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,63 @@ tNEW-STRorigin str
386386
-- !sql --
387387
d***is
388388

389+
-- !sub_replace_utf8_sql1 --
390+
你a世界
391+
392+
-- !sub_replace_utf8_sql2 --
393+
你ab界
394+
395+
-- !sub_replace_utf8_sql3 --
396+
你ab
397+
398+
-- !sub_replace_utf8_sql4 --
399+
你abcd我界
400+
401+
-- !sub_replace_utf8_sql5 --
402+
\N
403+
404+
-- !sub_replace_utf8_sql6 --
405+
大家世界
406+
407+
-- !sub_replace_utf8_sql7 --
408+
你大家114514
409+
410+
-- !sub_replace_utf8_sql8 --
411+
\N
412+
413+
-- !sub_replace_utf8_sql9 --
414+
\N
415+
416+
-- !sub_replace_utf8_sql10 --
417+
\N
418+
419+
-- !sub_replace_utf8_sql1 --
420+
你a世界
421+
422+
-- !sub_replace_utf8_sql2 --
423+
你ab界
424+
425+
-- !sub_replace_utf8_sql3 --
426+
你ab
427+
428+
-- !sub_replace_utf8_sql4 --
429+
你abcd我界
430+
431+
-- !sub_replace_utf8_sql5 --
432+
\N
433+
434+
-- !sub_replace_utf8_sql6 --
435+
大家世界
436+
437+
-- !sub_replace_utf8_sql7 --
438+
你大家114514
439+
440+
-- !sub_replace_utf8_sql8 --
441+
\N
442+
443+
-- !sub_replace_utf8_sql9 --
444+
\N
445+
446+
-- !sub_replace_utf8_sql10 --
447+
\N
448+

regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,4 +191,25 @@ suite("test_string_function") {
191191

192192
qt_sql "select sub_replace(\"this is origin str\",\"NEW-STR\",1);"
193193
qt_sql "select sub_replace(\"doris\",\"***\",1,2);"
194+
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
195+
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
196+
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
197+
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
198+
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
199+
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
200+
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
201+
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
202+
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
203+
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"
204+
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
205+
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
206+
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
207+
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
208+
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
209+
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
210+
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
211+
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
212+
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
213+
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"
214+
194215
}

0 commit comments

Comments
 (0)