Commit 0c44067

[SYCL] Optimize gradients calculations. (#10325)

Co-authored-by: Dmitry Razdoburdin <>

1 parent c9f5fca

File tree: 3 files changed, +386 -91 lines

plugin/sycl/common/linalg_op.h

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
/**
 * Copyright 2021-2024, XGBoost Contributors
 * \file linalg_op.h
 */
#ifndef PLUGIN_SYCL_COMMON_LINALG_OP_H_
#define PLUGIN_SYCL_COMMON_LINALG_OP_H_

#include <array>
#include <type_traits>
#include <utility>
#include <vector>

#include "../data.h"

#include <CL/sycl.hpp>

namespace xgboost {
namespace sycl {
namespace linalg {

struct WorkGroupsParams {
  size_t n_workgroups;
  size_t workgroup_size;
};

template <typename Fn>
::sycl::event GroupWiseKernel(::sycl::queue* qu, int* flag_ptr,
                              const std::vector<::sycl::event>& events,
                              const WorkGroupsParams& wg, Fn &&fn) {
  ::sycl::buffer<int, 1> flag_buf(flag_ptr, 1);
  auto event = qu->submit([&](::sycl::handler& cgh) {
    cgh.depends_on(events);
    auto flag = flag_buf.get_access<::sycl::access::mode::write>(cgh);
    cgh.parallel_for_work_group<>(::sycl::range<1>(wg.n_workgroups),
                                  ::sycl::range<1>(wg.workgroup_size),
                                  [=](::sycl::group<1> group) {
      group.parallel_for_work_item([&](::sycl::h_item<1> item) {
        const size_t idx = item.get_global_id()[0];
        fn(idx, flag);
      });
    });
  });
  return event;
}

struct Argument {
  template <typename T>
  operator T&&() const;
};

template <typename Fn, typename Is, typename = void>
struct ArgumentsPassedImpl
  : std::false_type {};

template <typename Fn, size_t ...Is>
struct ArgumentsPassedImpl<Fn, std::index_sequence<Is...>,
                           decltype(std::declval<Fn>()(((void)Is, Argument{})...), void())>
  : std::true_type {};

template <typename Fn, size_t N>
struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<N>> {};

template <typename OutputDType, typename InputDType,
          size_t BatchSize, size_t MaxNumInputs>
class BatchProcessingHelper {
 public:
  static constexpr size_t kBatchSize = BatchSize;
  using InputType = HostDeviceVector<InputDType>;
  using OutputType = HostDeviceVector<OutputDType>;

 private:
  template <size_t NumInput = 0>
  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input) {
    /*
     * Some inputs may have fewer than 1 sample per output symbol.
     */
    const size_t sub_sample_rate = ndata_ * sample_rates_[NumInput+1] / input.Size();
    const size_t n_samples = batch_size_ * sample_rates_[NumInput+1] / sub_sample_rate;

    const InputDType* in_host_ptr = input.HostPointer() +
                                    batch_begin_ * sample_rates_[NumInput+1] / sub_sample_rate;

    events_[NumInput] =
      qu_->memcpy(in_buffer_ptr, in_host_ptr, n_samples * sizeof(InputDType),
                  events_[MaxNumInputs - 2]);
  }

  template <size_t NumInput = 0, class... InputTypes>
  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input,
                    const InputTypes&... other_inputs) {
    // Make a copy for the first input in the list
    Host2Buffers<NumInput>(in_buffer_ptr, input);
    // Recurrent call for the next inputs
    InputDType* next_input = in_buffer_.Data() + in_buff_offsets_[NumInput + 1];
    Host2Buffers<NumInput+1>(next_input, other_inputs...);
  }

  void Buffers2Host(OutputType* output) {
    const size_t n_samples = batch_size_ * sample_rates_[0];
    OutputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[0];
    events_[MaxNumInputs - 1] =
      qu_->memcpy(out_host_ptr, out_buffer_.DataConst(), n_samples * sizeof(OutputDType),
                  events_[MaxNumInputs - 2]);
  }

  void Buffers2Host(InputType* output) {
    const size_t n_samples = batch_size_ * sample_rates_[1];
    InputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[1];
    events_[MaxNumInputs - 1] =
      qu_->memcpy(out_host_ptr, in_buffer_.DataConst(), n_samples * sizeof(InputDType),
                  events_[MaxNumInputs - 2]);
  }

  template <size_t NumInputs = 1, typename Fn, class... InputTypes>
  void Call(Fn &&fn, const InputDType* input, const InputTypes*... other_inputs) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
    /* The passed lambda may have fewer inputs than MaxNumInputs,
     * so only the required number of arguments is passed.
     */
    // 1 for events, 1 for batch_size, 1 for output
    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1 + 1>::value) {
      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
                                     out_buffer_.Data(), input, other_inputs...);
    } else {
      const InputDType* next_input = in_buffer_.DataConst() +
                                     in_buff_offsets_[MaxNumInputs - 1 - NumInputs];
      Call<NumInputs+1>(std::forward<Fn>(fn), next_input, input, other_inputs...);
    }
  }

  template <size_t NumInputs = 1, typename Fn, class... InputTypes>
  void Call(Fn &&fn, InputDType* io, const InputDType* input, const InputTypes*... other_inputs) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
                                     io, input, other_inputs...);
    } else {
      const InputDType* next_input = in_buffer_.DataConst() +
                                     in_buff_offsets_[MaxNumInputs - NumInputs];
      Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input, input, other_inputs...);
    }
  }

  template <size_t NumInputs = 1, typename Fn>
  void Call(Fn &&fn, InputDType* io) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
      events_[MaxNumInputs - 2] = fn(events_, batch_size_, io);
    } else {
      const InputDType* next_input = in_buffer_.DataConst() +
                                     in_buff_offsets_[MaxNumInputs - 1];
      Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input);
    }
  }

 public:
  BatchProcessingHelper() = default;

  // The first element of sample_rate always corresponds to the output sample rate
  void InitBuffers(::sycl::queue* qu, const std::vector<int>& sample_rate) {
    assert(sample_rate.size() == MaxNumInputs + 1);
    sample_rates_ = sample_rate;
    qu_ = qu;
    events_.resize(MaxNumInputs + 2);
    out_buffer_.Resize(qu, kBatchSize * sample_rate.front());

    in_buff_offsets_[0] = 0;
    for (size_t i = 1; i < MaxNumInputs; ++i) {
      in_buff_offsets_[i] = in_buff_offsets_[i - 1] + kBatchSize * sample_rate[i];
    }
    const size_t in_buff_size = in_buff_offsets_.back() + kBatchSize * sample_rate.back();
    in_buffer_.Resize(qu, in_buff_size);
  }

  /*
   * Batch-wise processing on a SYCL device:
   * output = fn(inputs)
   */
  template <typename Fn, class... InputTypes>
  void Calculate(Fn &&fn, OutputType* output, const InputTypes&... inputs) {
    ndata_ = output->Size() / sample_rates_.front();
    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
    for (size_t batch = 0; batch < nBatch; ++batch) {
      batch_begin_ = batch * kBatchSize;
      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
      batch_size_ = batch_end_ - batch_begin_;

      // Iteratively copy all inputs to device buffers
      Host2Buffers(in_buffer_.Data(), inputs...);
      // Pack buffers and call the function.
      // The input pointer is shifted to keep the same order of inputs after packing.
      Call(std::forward<Fn>(fn), in_buffer_.DataConst() + in_buff_offsets_.back());
      // Copy results back to the host
      Buffers2Host(output);
    }
  }

  /*
   * Batch-wise processing on a SYCL device:
   * input = fn(input, other_inputs)
   */
  template <typename Fn, class... InputTypes>
  void Calculate(Fn &&fn, InputType* input, const InputTypes&... other_inputs) {
    ndata_ = input->Size();
    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
    for (size_t batch = 0; batch < nBatch; ++batch) {
      batch_begin_ = batch * kBatchSize;
      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
      batch_size_ = batch_end_ - batch_begin_;

      // Iteratively copy all inputs to device buffers;
      // the inputs are passed by const reference.
      Host2Buffers(in_buffer_.Data(), *(input), other_inputs...);
      // Pack buffers and call the function.
      // The input pointer is shifted to keep the same order of inputs after packing.
      Call(std::forward<Fn>(fn), in_buffer_.Data());
      // Copy results back to the host
      Buffers2Host(input);
    }
  }

 private:
  std::array<int, MaxNumInputs> in_buff_offsets_;
  std::vector<int> sample_rates_;
  size_t ndata_;
  size_t batch_begin_;
  size_t batch_end_;
  // not equal to kBatchSize for the last batch
  size_t batch_size_;
  ::sycl::queue* qu_;
  std::vector<::sycl::event> events_;
  USMVector<InputDType, MemoryType::on_device> in_buffer_;
  USMVector<OutputDType, MemoryType::on_device> out_buffer_;
};

}  // namespace linalg
}  // namespace sycl
}  // namespace xgboost
#endif  // PLUGIN_SYCL_COMMON_LINALG_OP_H_
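
The Call overloads above choose between invoking the user lambda and recursing with one more packed input by probing, at compile time, how many arguments the lambda accepts. The following standalone sketch (standard C++17, no SYCL or XGBoost dependencies; the lambdas and printed values are purely illustrative) shows the same Argument/ArgumentsPassed detection trick in isolation:

    #include <cstddef>
    #include <iostream>
    #include <type_traits>
    #include <utility>

    // A placeholder that converts to any parameter type, so a call expression
    // with N of them is well-formed exactly when the callable takes N arguments.
    struct Argument {
      template <typename T>
      operator T&&() const;
    };

    template <typename Fn, typename Is, typename = void>
    struct ArgumentsPassedImpl : std::false_type {};

    template <typename Fn, std::size_t... Is>
    struct ArgumentsPassedImpl<Fn, std::index_sequence<Is...>,
        decltype(std::declval<Fn>()(((void)Is, Argument{})...), void())>
      : std::true_type {};

    template <typename Fn, std::size_t N>
    struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<N>> {};

    int main() {
      auto two_args = [](int, float) {};
      auto three_args = [](int, float, double) {};

      std::cout << ArgumentsPassed<decltype(two_args), 2>::value << '\n';    // 1
      std::cout << ArgumentsPassed<decltype(two_args), 3>::value << '\n';    // 0
      std::cout << ArgumentsPassed<decltype(three_args), 3>::value << '\n';  // 1
    }

This is how BatchProcessingHelper can accept lambdas that use fewer than MaxNumInputs inputs without any explicit arity tag from the caller.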

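For reference, Calculate splits the work into ceil(ndata / kBatchSize) batches, and only the last batch may be shorter than kBatchSize. A minimal host-only sketch of that boundary arithmetic (the ndata value is made up for illustration; kBatchSize mirrors the default used by the objective below):

    #include <cstddef>
    #include <iostream>

    int main() {
      const std::size_t kBatchSize = 1u << 22;   // same default as SoftmaxMultiClassObj below
      const std::size_t ndata = 10'000'000;      // hypothetical number of output rows
      // Ceil division, exactly as in Calculate()
      const std::size_t n_batches = ndata / kBatchSize + (ndata % kBatchSize > 0);
      for (std::size_t b = 0; b < n_batches; ++b) {
        const std::size_t begin = b * kBatchSize;
        const std::size_t end = (b == n_batches - 1) ? ndata : begin + kBatchSize;
        std::cout << "batch " << b << ": [" << begin << ", " << end << ")\n";
      }
    }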
plugin/sycl/objective/multiclass_obj.cc

Lines changed: 59 additions & 31 deletions
@@ -22,7 +22,10 @@
 
 #include "../../../src/objective/multiclass_param.h"
 
+#include "../common/linalg_op.h"
+
 #include "../device_manager.h"
+#include "../data.h"
 #include <CL/sycl.hpp>
 
 namespace xgboost {
@@ -32,6 +35,15 @@ namespace obj {
 DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);
 
 class SoftmaxMultiClassObj : public ObjFunction {
+  mutable bool are_buffs_init = false;
+
+  void InitBuffers(const std::vector<int>& sample_rate) const {
+    if (!are_buffs_init) {
+      batch_processor_.InitBuffers(&qu_, sample_rate);
+      are_buffs_init = true;
+    }
+  }
+
  public:
   explicit SoftmaxMultiClassObj(bool output_prob)
       : output_prob_(output_prob) {}
@@ -44,7 +56,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo& info,
                    int iter,
-                   linalg::Matrix<GradientPair>* out_gpair) override {
+                   xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
     if (preds.Size() == 0) return;
     if (info.labels.Size() == 0) return;
 
@@ -66,54 +78,68 @@ class SoftmaxMultiClassObj : public ObjFunction {
         << "Number of weights should be equal to number of data points.";
     }
 
-    ::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
-    ::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
-    ::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
-                                                  out_gpair->Size());
-    ::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
-                                             is_null_weight ? 1 : info.weights_.Size());
-
     int flag = 1;
-    {
-      ::sycl::buffer<int, 1> flag_buf(&flag, 1);
-      qu_.submit([&](::sycl::handler& cgh) {
-        auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
-        auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
-        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-          int idx = pid[0];
-
-          bst_float const * point = &preds_acc[idx * nclass];
+    auto objective_fn = [=, &flag]
+                        (const std::vector<::sycl::event>& events,
+                         size_t ndata,
+                         GradientPair* out_gpair,
+                         const bst_float* preds,
+                         const bst_float* labels,
+                         const bst_float* weights) {
+      const size_t wg_size = 32;
+      const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
+      return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
+                                     [=] (size_t idx, auto flag) {
+        const bst_float* pred = preds + idx * nclass;
 
         // Part of Softmax function
         bst_float wmax = std::numeric_limits<bst_float>::min();
-        for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
-        float wsum = 0.0f;
-        for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
-        auto label = labels_acc[idx];
+        for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(pred[k], wmax); }
+        bst_float wsum = 0.0f;
+        for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(pred[k] - wmax); }
+        bst_float label = labels[idx];
+
         if (label < 0 || label >= nclass) {
-          flag_buf_acc[0] = 0;
+          AtomicRef<int> flag_ref(flag[0]);
+          flag_ref = 0;
           label = 0;
         }
-        bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];
+
+        bst_float wt = is_null_weight ? 1.0f : weights[idx];
         for (int k = 0; k < nclass; ++k) {
-          bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
+          bst_float p = expf(pred[k] - wmax) / static_cast<float>(wsum);
           const float eps = 1e-16f;
           const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
           p = label == k ? p - 1.0f : p;
-          out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
+          out_gpair[idx * nclass + k] = GradientPair(p * wt, h);
         }
-        });
-      }).wait();
+      });
+    };
+
+    // out_gpair and preds have nclass points per sample;
+    // labels and weights have 1 point per sample
+    InitBuffers({nclass, nclass, 1, 1});
+    if (is_null_weight) {
+      // Output is passed by pointer
+      // Inputs are passed by const reference
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()));
+    } else {
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()),
+                                 info.weights_);
     }
-    // flag_buf is destroyed, content is copyed to the "flag"
+    qu_.wait_and_throw();
 
     if (flag == 0) {
       LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
     }
   }
+
   void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
     this->Transform(io_preds, output_prob_);
   }
@@ -190,6 +216,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
   sycl::DeviceManager device_manager;
 
   mutable ::sycl::queue qu_;
+  static constexpr size_t kBatchSize = 1u << 22;
+  mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
 };
 
 XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
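
The per-sample work inside objective_fn is the usual numerically stable softmax followed by the multi-class gradient and hessian. A host-only, single-sample sketch of the same arithmetic (no SYCL; the scores, label, and weight below are made-up example values, not taken from the commit):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const int nclass = 3;
      std::vector<float> pred = {0.5f, 1.5f, -0.2f};  // raw scores for one sample
      const int label = 1;                            // ground-truth class
      const float wt = 1.0f;                          // instance weight

      // Numerically stable softmax: subtract the max score before exponentiating,
      // as the kernel does with wmax.
      const float wmax = *std::max_element(pred.begin(), pred.end());
      float wsum = 0.0f;
      for (int k = 0; k < nclass; ++k) wsum += std::exp(pred[k] - wmax);

      for (int k = 0; k < nclass; ++k) {
        float p = std::exp(pred[k] - wmax) / wsum;
        const float eps = 1e-16f;
        const float h = std::max(2.0f * p * (1.0f - p) * wt, eps);  // hessian, clamped away from 0
        const float g = (label == k ? p - 1.0f : p) * wt;           // gradient
        std::printf("class %d: grad=%f hess=%f\n", k, g, h);
      }
    }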
