Skip to content

Commit 7a2d080

Browse files
kimishpatel and liangel-02
authored and committed
add fallback kernel and interface
Differential Revision: D71370598 Pull Request resolved: #2010
1 parent 91c034e commit 7a2d080

File tree

4 files changed

+670
-0
lines changed

4 files changed

+670
-0
lines changed

torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ struct test_channelwise_8bit_channelwise_8bit_b<
7070
false,
7171
false> {
7272
static void Run(int m, int k, int n, int stride = 1) {
73+
// TODO: make use of stride for this kernel
7374
auto test_case =
7475
torchao::channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::
7576
generate(m, k, n, a_has_zeros, a_has_zeros, false, false);
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <cstdint>
10+
11+
namespace torchao::kernels::cpu::fallback::quantized_matmul {
namespace channelwise_8bit_a_channelwise_8bit_b::internal {

// Scalar reference (fallback) implementation of a quantized matmul with
// channelwise 8-bit quantized lhs and rhs and fp32 output.
// The primary template is declared only; each supported combination of
// layout/zero-point flags is provided as a partial specialization.
template <
    bool a_has_zeros,
    bool b_has_zeros,
    bool a_transposed,
    bool b_transposed>
struct KernelImpl {
  static void run(
      int m,
      int n,
      int k,
      const void* lhs,
      int lhs_stride_m,
      const void* rhs,
      int rhs_stride_n,
      float* output,
      int out_stride_m,
      const int8_t* lhs_zero_points,
      const int8_t* rhs_zero_points,
      const float* lhs_scales,
      const float* rhs_scales,
      const int lhs_qparams_stride,
      const int rhs_qparams_stride);
};

// Specialization: both operands carry zero points; lhs is (m x k) row-major
// (not transposed); rhs is (k x n), or (n x k) when b_transposed is true.
template <bool b_transposed>
struct KernelImpl<true, true, false, b_transposed> {
  /**
   * Computes output[m][n] = sum_k dequant(lhs[m][k]) * dequant(rhs[k][n]).
   *
   * lhs/rhs hold int8 quantized values; dequantization is
   * scale * (qval - zero_point), with per-row qparams for lhs (indexed by
   * m_idx * lhs_qparams_stride) and per-column qparams for rhs (indexed by
   * n_idx * rhs_qparams_stride).
   *
   * lhs_stride_m / rhs_stride_n / out_stride_m are element (not byte)
   * strides between consecutive rows of lhs, rows of rhs (rows of the
   * transposed rhs when b_transposed), and rows of output respectively.
   */
  static void run(
      int m,
      int n,
      int k,
      const void* lhs,
      int lhs_stride_m,
      const void* rhs,
      int rhs_stride_n,
      float* output,
      int out_stride_m,
      const int8_t* lhs_zero_points,
      const int8_t* rhs_zero_points,
      const float* lhs_scales,
      const float* rhs_scales,
      const int lhs_qparams_stride,
      const int rhs_qparams_stride) {
    const int8_t* lhs_qvals = static_cast<const int8_t*>(lhs);
    const int8_t* rhs_qvals = static_cast<const int8_t*>(rhs);
    for (int m_idx = 0; m_idx < m; m_idx++) {
      for (int n_idx = 0; n_idx < n; n_idx++) {
        float res = 0.0;
        for (int k_idx = 0; k_idx < k; k_idx++) {
          int lhs_idx = m_idx * lhs_stride_m + k_idx;
          int rhs_idx = k_idx * rhs_stride_n + n_idx;
          if (b_transposed) {
            // rhs is (n x k); rhs_stride_n is the distance between its rows.
            rhs_idx = n_idx * rhs_stride_n + k_idx;
          }

          // int16_t is wide enough to hold (int8 qval - int8 zero point)
          // without overflow.
          float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] *
              (static_cast<int16_t>(lhs_qvals[lhs_idx]) -
               static_cast<int16_t>(
                   lhs_zero_points[m_idx * lhs_qparams_stride]));

          float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] *
              (static_cast<int16_t>(rhs_qvals[rhs_idx]) -
               static_cast<int16_t>(
                   rhs_zero_points[n_idx * rhs_qparams_stride]));

          res += lhs_dequant * rhs_dequant;
        }
        // Bug fix: honor the out_stride_m parameter instead of assuming a
        // densely-packed (stride == n) output; matches the shared kernel
        // function-pointer contract that passes an explicit output stride.
        output[m_idx * out_stride_m + n_idx] = res;
      }
    }
  }
};

} // namespace
  // channelwise_8bit_a_channelwise_8bit_b::internal
} // namespace torchao::kernels::cpu::fallback::quantized_matmul
89+
90+
// TODO: Remove all ::kernels. No need for extra namespace.
91+
namespace torchao::kernels::cpu::fallback::quantized_matmul {
92+
namespace channelwise_8bit_a_channelwise_8bit_b {
93+
template <
94+
bool a_has_zeros,
95+
bool b_has_zeros,
96+
bool a_transposed,
97+
bool b_transposed>
98+
void kernel(
99+
int m,
100+
int n,
101+
int k,
102+
const void* lhs,
103+
int lhs_stride_m,
104+
const void* rhs,
105+
int rhs_stride_n,
106+
float* output,
107+
int out_stride_m,
108+
const int8_t* lhs_zero_points,
109+
const int8_t* rhs_zero_points,
110+
const float* lhs_scales,
111+
const float* rhs_scales,
112+
const int lhs_qparams_stride,
113+
const int rhs_qparams_stride) {
114+
channelwise_8bit_a_channelwise_8bit_b::internal::
115+
KernelImpl<a_has_zeros, b_has_zeros, a_transposed, b_transposed>::run(
116+
m,
117+
n,
118+
k,
119+
lhs,
120+
lhs_stride_m,
121+
rhs,
122+
rhs_stride_n,
123+
output,
124+
out_stride_m,
125+
lhs_zero_points,
126+
rhs_zero_points,
127+
lhs_scales,
128+
rhs_scales,
129+
lhs_qparams_stride,
130+
rhs_qparams_stride);
131+
}
132+
} // namespace channelwise_8bit_a_channelwise_8bit_b
133+
} // namespace torchao::kernels::cpu::fallback::quantized_matmul
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <cassert>
10+
11+
#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>
12+
#if defined(__aarch64__) || defined(__ARM_NEON)
13+
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
14+
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
15+
#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>
16+
#endif // defined(__aarch64__) || defined(__ARM_NEON)
17+
18+
namespace torchao::kernels::cpu::quantized_matmul {
19+
20+
/*
21+
a_stride_m: stride of a in memory to indiciate how far apart each row is.
22+
b_stride_n: stride of b in memory to indiciate how far apart each row is.
23+
If b is transposed (n x k), then this is how many bytes to skip to get to the
24+
next row. If b is not transposed (k x n), then this is how many bytes to skip to
25+
get to the next row.
26+
27+
It also returns the stride of a and b, that should be used in the kernel.
28+
29+
Will need to think of a better way to find the right
30+
ukernel. Perhaps via ukernelconfig + registry?.
31+
*/
32+
using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)(
33+
int,
34+
int,
35+
int,
36+
const void*,
37+
int,
38+
const void*,
39+
int,
40+
float*,
41+
int,
42+
const int8_t*,
43+
const int8_t*,
44+
const float*,
45+
const float*,
46+
const int,
47+
const int);
48+
49+
int8_a_int8_b_channelwise_fp32_c_qmatmul_type
50+
get_int8_a_int8_b_channelwise_qmatmul(
51+
int m,
52+
int n,
53+
int k,
54+
bool a_transposed,
55+
bool b_transposed,
56+
int& a_stride_m,
57+
int& b_stride_n);
58+
59+
int8_a_int8_b_channelwise_fp32_c_qmatmul_type
60+
get_int8_a_int8_b_channelwise_qmatmul(
61+
int m,
62+
int n,
63+
int k,
64+
bool a_transposed,
65+
bool b_transposed,
66+
int& a_stride_m,
67+
int& b_stride_n) {
68+
#if defined(__aarch64__) || defined(__ARM_NEON)
69+
if (!a_transposed && b_transposed && n >= 8) {
70+
a_stride_m = k;
71+
b_stride_n = k;
72+
return aarch64::quantized_matmul::
73+
channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::
74+
kernel<true, true, false, true>;
75+
}
76+
#endif // defined(__aarch64__) || defined(__ARM_NEON)
77+
assert(!a_transposed);
78+
if (b_transposed) {
79+
a_stride_m = k;
80+
b_stride_n = k;
81+
return torchao::kernels::cpu::fallback::quantized_matmul::
82+
channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, true>;
83+
} else {
84+
return torchao::kernels::cpu::fallback::quantized_matmul::
85+
channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, false>;
86+
}
87+
}
88+
} // namespace torchao::kernels::cpu::quantized_matmul

0 commit comments

Comments
 (0)