Skip to content

Commit 13ced20

Browse files
authored
[feature](function) add approx_top_sum aggregation function (#43643)
Problem Summary: This commit implements the `approx_top_sum` aggregate function. Example usage: `select approx_top_sum(c1, c2, c3, 10, 300) from tbl`.

### Release note
Add new function `approx_top_sum`.
1 parent db03e33 commit 13ced20

17 files changed

+798
-162
lines changed

be/src/vec/aggregate_functions/aggregate_function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class IDataType;
4343

4444
struct AggregateFunctionAttr {
4545
bool enable_decimal256 {false};
46-
std::vector<std::pair<std::string, bool>> column_infos;
46+
std::vector<std::string> column_names;
4747
};
4848

4949
template <bool nullable, typename ColVecType>

be/src/vec/aggregate_functions/aggregate_function_approx_top.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,92 @@
1818
#pragma once
1919

2020
#include "vec/core/types.h"
21+
#include "vec/data_types/data_type.h"
22+
#include "vec/data_types/data_type_nullable.h"
2123

2224
namespace doris::vectorized {
2325

2426
class AggregateFunctionApproxTop {
2527
public:
28+
AggregateFunctionApproxTop(const std::vector<std::string>& column_names)
29+
: _column_names(column_names) {}
30+
31+
static int32_t is_valid_const_columns(const std::vector<bool>& is_const_columns) {
32+
int32_t true_count = 0;
33+
bool found_false_after_true = false;
34+
for (int32_t i = is_const_columns.size() - 1; i >= 0; --i) {
35+
if (is_const_columns[i]) {
36+
true_count++;
37+
if (found_false_after_true) {
38+
return false;
39+
}
40+
} else {
41+
if (true_count > 2) {
42+
return false;
43+
}
44+
found_false_after_true = true;
45+
}
46+
}
47+
if (true_count > 2) {
48+
throw Exception(ErrorCode::INVALID_ARGUMENT, "Invalid is_const_columns configuration");
49+
}
50+
return true_count;
51+
}
52+
53+
protected:
54+
void lazy_init(const IColumn** columns, ssize_t row_num,
55+
const DataTypes& argument_types) const {
56+
auto get_param = [](size_t idx, const DataTypes& data_types,
57+
const IColumn** columns) -> uint64_t {
58+
const auto& data_type = data_types.at(idx);
59+
const IColumn* column = columns[idx];
60+
61+
const auto* type = data_type.get();
62+
if (type->is_nullable()) {
63+
type = assert_cast<const DataTypeNullable*, TypeCheckOnRelease::DISABLE>(type)
64+
->get_nested_type()
65+
.get();
66+
}
67+
int64_t value = 0;
68+
WhichDataType which(type);
69+
if (which.idx == TypeIndex::Int8) {
70+
value = assert_cast<const ColumnInt8*, TypeCheckOnRelease::DISABLE>(column)
71+
->get_element(0);
72+
} else if (which.idx == TypeIndex::Int16) {
73+
value = assert_cast<const ColumnInt16*, TypeCheckOnRelease::DISABLE>(column)
74+
->get_element(0);
75+
} else if (which.idx == TypeIndex::Int32) {
76+
value = assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(column)
77+
->get_element(0);
78+
}
79+
if (value <= 0) {
80+
throw Exception(ErrorCode::INVALID_ARGUMENT,
81+
"The parameter cannot be less than or equal to 0.");
82+
}
83+
return value;
84+
};
85+
86+
_threshold =
87+
std::min(get_param(_column_names.size(), argument_types, columns), (uint64_t)4096);
88+
_reserved = std::min(
89+
std::max(get_param(_column_names.size() + 1, argument_types, columns), _threshold),
90+
(uint64_t)4096);
91+
92+
if (_threshold == 0 || _reserved == 0 || _threshold > 4096 || _reserved > 4096) {
93+
throw Exception(ErrorCode::INTERNAL_ERROR,
94+
"approx_top_sum param error, _threshold: {}, _reserved: {}", _threshold,
95+
_reserved);
96+
}
97+
98+
_init_flag = true;
99+
}
100+
26101
static inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
102+
103+
mutable std::vector<std::string> _column_names;
104+
mutable bool _init_flag = false;
105+
mutable uint64_t _threshold = 10;
106+
mutable uint64_t _reserved = 30;
27107
};
28108

29109
} // namespace doris::vectorized

be/src/vec/aggregate_functions/aggregate_function_approx_top_k.cpp

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,58 +24,16 @@
2424

2525
namespace doris::vectorized {
2626

27-
int32_t is_valid_const_columns(const std::vector<bool>& is_const_columns) {
28-
int32_t true_count = 0;
29-
bool found_false_after_true = false;
30-
for (int32_t i = is_const_columns.size() - 1; i >= 0; --i) {
31-
if (is_const_columns[i]) {
32-
true_count++;
33-
if (found_false_after_true) {
34-
return false;
35-
}
36-
} else {
37-
if (true_count > 2) {
38-
return false;
39-
}
40-
found_false_after_true = true;
41-
}
42-
}
43-
if (true_count > 2) {
44-
throw Exception(ErrorCode::INVALID_ARGUMENT, "Invalid is_const_columns configuration");
45-
}
46-
return true_count;
47-
}
48-
4927
AggregateFunctionPtr create_aggregate_function_approx_top_k(const std::string& name,
5028
const DataTypes& argument_types,
5129
const bool result_is_nullable,
5230
const AggregateFunctionAttr& attr) {
53-
if (argument_types.empty()) {
31+
if (argument_types.size() < 3) {
5432
return nullptr;
5533
}
5634

57-
std::vector<bool> is_const_columns;
58-
std::vector<std::string> column_names;
59-
for (const auto& [name, is_const] : attr.column_infos) {
60-
is_const_columns.push_back(is_const);
61-
if (!is_const) {
62-
column_names.push_back(name);
63-
}
64-
}
65-
66-
int32_t true_count = is_valid_const_columns(is_const_columns);
67-
if (true_count == 0) {
68-
return creator_without_type::create<AggregateFunctionApproxTopK<0>>(
69-
argument_types, result_is_nullable, column_names);
70-
} else if (true_count == 1) {
71-
return creator_without_type::create<AggregateFunctionApproxTopK<1>>(
72-
argument_types, result_is_nullable, column_names);
73-
} else if (true_count == 2) {
74-
return creator_without_type::create<AggregateFunctionApproxTopK<2>>(
75-
argument_types, result_is_nullable, column_names);
76-
} else {
77-
return nullptr;
78-
}
35+
return creator_without_type::create<AggregateFunctionApproxTopK>(
36+
argument_types, result_is_nullable, attr.column_names);
7937
}
8038

8139
void register_aggregate_function_approx_top_k(AggregateFunctionSimpleFactory& factory) {

be/src/vec/aggregate_functions/aggregate_function_approx_top_k.h

Lines changed: 5 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -45,28 +45,25 @@
4545

4646
namespace doris::vectorized {
4747

48-
inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;
49-
5048
struct AggregateFunctionTopKGenericData {
5149
using Set = SpaceSaving<StringRef, StringRefHash>;
5250

5351
Set value;
5452
};
5553

56-
template <int32_t ArgsSize>
5754
class AggregateFunctionApproxTopK final
5855
: public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData,
59-
AggregateFunctionApproxTopK<ArgsSize>>,
56+
AggregateFunctionApproxTopK>,
6057
AggregateFunctionApproxTop {
6158
private:
6259
using State = AggregateFunctionTopKGenericData;
6360

6461
public:
65-
AggregateFunctionApproxTopK(std::vector<std::string> column_names,
62+
AggregateFunctionApproxTopK(const std::vector<std::string>& column_names,
6663
const DataTypes& argument_types_)
6764
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData,
68-
AggregateFunctionApproxTopK<ArgsSize>>(argument_types_),
69-
_column_names(std::move(column_names)) {}
65+
AggregateFunctionApproxTopK>(argument_types_),
66+
AggregateFunctionApproxTop(column_names) {}
7067

7168
String get_name() const override { return "approx_top_k"; }
7269

@@ -141,7 +138,7 @@ class AggregateFunctionApproxTopK final
141138
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
142139
Arena* arena) const override {
143140
if (!_init_flag) {
144-
lazy_init(columns, row_num);
141+
lazy_init(columns, row_num, this->get_argument_types());
145142
}
146143

147144
auto& set = this->data(place).value;
@@ -227,64 +224,6 @@ class AggregateFunctionApproxTopK final
227224
std::string res = buffer.GetString();
228225
data_to.insert_data(res.data(), res.size());
229226
}
230-
231-
private:
232-
void lazy_init(const IColumn** columns, ssize_t row_num) const {
233-
auto get_param = [](size_t idx, const DataTypes& data_types,
234-
const IColumn** columns) -> uint64_t {
235-
const auto& data_type = data_types.at(idx);
236-
const IColumn* column = columns[idx];
237-
238-
const auto* type = data_type.get();
239-
if (type->is_nullable()) {
240-
type = assert_cast<const DataTypeNullable*, TypeCheckOnRelease::DISABLE>(type)
241-
->get_nested_type()
242-
.get();
243-
}
244-
int64_t value = 0;
245-
WhichDataType which(type);
246-
if (which.idx == TypeIndex::Int8) {
247-
value = assert_cast<const ColumnInt8*, TypeCheckOnRelease::DISABLE>(column)
248-
->get_element(0);
249-
} else if (which.idx == TypeIndex::Int16) {
250-
value = assert_cast<const ColumnInt16*, TypeCheckOnRelease::DISABLE>(column)
251-
->get_element(0);
252-
} else if (which.idx == TypeIndex::Int32) {
253-
value = assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(column)
254-
->get_element(0);
255-
}
256-
if (value <= 0) {
257-
throw Exception(ErrorCode::INVALID_ARGUMENT,
258-
"The parameter cannot be less than or equal to 0.");
259-
}
260-
return value;
261-
};
262-
263-
const auto& data_types = this->get_argument_types();
264-
if (ArgsSize == 1) {
265-
_threshold =
266-
std::min(get_param(_column_names.size(), data_types, columns), (uint64_t)1000);
267-
} else if (ArgsSize == 2) {
268-
_threshold =
269-
std::min(get_param(_column_names.size(), data_types, columns), (uint64_t)1000);
270-
_reserved = std::min(
271-
std::max(get_param(_column_names.size() + 1, data_types, columns), _threshold),
272-
(uint64_t)1000);
273-
}
274-
275-
if (_threshold == 0 || _reserved == 0 || _threshold > 1000 || _reserved > 1000) {
276-
throw Exception(ErrorCode::INTERNAL_ERROR,
277-
"approx_top_k param error, _threshold: {}, _reserved: {}", _threshold,
278-
_reserved);
279-
}
280-
281-
_init_flag = true;
282-
}
283-
284-
mutable std::vector<std::string> _column_names;
285-
mutable bool _init_flag = false;
286-
mutable uint64_t _threshold = 10;
287-
mutable uint64_t _reserved = 300;
288227
};
289228

290229
} // namespace doris::vectorized
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "vec/aggregate_functions/aggregate_function_approx_top_sum.h"
19+
20+
#include "common/exception.h"
21+
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
22+
#include "vec/aggregate_functions/helpers.h"
23+
#include "vec/data_types/data_type.h"
24+
25+
namespace doris::vectorized {
26+
27+
template <size_t N>
28+
AggregateFunctionPtr create_aggregate_function_multi_top_sum_impl(
29+
const DataTypes& argument_types, const bool result_is_nullable,
30+
const std::vector<std::string>& column_names) {
31+
if (N == argument_types.size() - 3) {
32+
return creator_with_type_base<true, false, false, N>::template create<
33+
AggregateFunctionApproxTopSumSimple>(argument_types, result_is_nullable,
34+
column_names);
35+
} else {
36+
return create_aggregate_function_multi_top_sum_impl<N - 1>(
37+
argument_types, result_is_nullable, column_names);
38+
}
39+
}
40+
41+
template <>
42+
AggregateFunctionPtr create_aggregate_function_multi_top_sum_impl<0>(
43+
const DataTypes& argument_types, const bool result_is_nullable,
44+
const std::vector<std::string>& column_names) {
45+
return creator_with_type_base<true, false, false, 0>::template create<
46+
AggregateFunctionApproxTopSumSimple>(argument_types, result_is_nullable, column_names);
47+
}
48+
49+
// Factory entry point registered for "approx_top_sum".
//
// Returns nullptr when fewer than three argument types are supplied and throws
// Exception(INTERNAL_ERROR) when more than ten are supplied. Otherwise the
// creation is delegated to the compile-time countdown helper, passing the column
// names carried by `attr`. `name` is not used.
AggregateFunctionPtr create_aggregate_function_approx_top_sum(const std::string& name,
                                                              const DataTypes& argument_types,
                                                              const bool result_is_nullable,
                                                              const AggregateFunctionAttr& attr) {
    // Serves both as the maximum accepted argument count and as the starting
    // point of the template countdown.
    constexpr size_t arg_limit = 10;
    const size_t arg_count = argument_types.size();

    if (arg_count < 3) {
        return nullptr;
    }
    if (arg_count > arg_limit) {
        throw Exception(ErrorCode::INTERNAL_ERROR,
                        "Argument types size exceeds the supported limit.");
    }

    const std::vector<std::string>& key_names = attr.column_names;
    return create_aggregate_function_multi_top_sum_impl<arg_limit>(argument_types,
                                                                   result_is_nullable, key_names);
}
66+
67+
// Registers the creator above with `factory` under the name "approx_top_sum",
// using the factory's `register_function_both` entry point.
void register_aggregate_function_approx_top_sum(AggregateFunctionSimpleFactory& factory) {
    const auto creator = create_aggregate_function_approx_top_sum;
    factory.register_function_both("approx_top_sum", creator);
}
70+
71+
} // namespace doris::vectorized

0 commit comments

Comments
 (0)