Commit f7c315b

kimishpatel authored and facebook-github-bot committed
Add fallback kernel and interface for rhs only quantized matmul
Summary: as the title.

Reviewed By: metascroy

Differential Revision: D71370602
1 parent 97d6d74 commit f7c315b

3 files changed: +302 −0 lines changed
torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <cassert>
#include <cstdint> // int8_t / int16_t

// TODO: Remove all ::kernels. No need for extra namespace.
namespace torchao::kernels::cpu::fallback::quantized_matmul {
namespace fp32_a_input_channelwise_8bit_b_fp32 {
// Reference (non-vectorized) kernel computing
//   output = beta * output + lhs (fp32, m x k) * dequant(rhs) (int8, k x n),
// where rhs is quantized per input channel: one scale/zero point per k index.
template <bool b_has_zeros, bool a_transposed, bool b_transposed>
void kernel(
    int m,
    int n,
    int k,
    const float* lhs,
    int lhs_stride_m,
    const int8_t* rhs,
    int rhs_stride_n,
    float* output,
    int out_stride_m,
    const int8_t* rhs_zero_points,
    const float* rhs_scales,
    const float beta,
    const int rhs_qparams_stride) {
  assert(a_transposed == false);
  for (int m_idx = 0; m_idx < m; m_idx++) {
    for (int n_idx = 0; n_idx < n; n_idx++) {
      float res = 0.0;
      for (int k_idx = 0; k_idx < k; k_idx++) {
        int lhs_idx = m_idx * lhs_stride_m + k_idx;
        int rhs_idx = k_idx * rhs_stride_n + n_idx;
        if (b_transposed) {
          rhs_idx = n_idx * rhs_stride_n + k_idx;
        }
        // Dequantize one rhs element with its input-channel scale/zero point.
        float rhs_dequant = rhs_scales[k_idx * rhs_qparams_stride] *
            (static_cast<int16_t>(rhs[rhs_idx]) -
             static_cast<int16_t>(rhs_zero_points[k_idx * rhs_qparams_stride]));

        res += lhs[lhs_idx] * rhs_dequant;
      }
      output[m_idx * n + n_idx] = output[m_idx * n + n_idx] * beta + res;
    }
  }
}
} // namespace fp32_a_input_channelwise_8bit_b_fp32
} // namespace torchao::kernels::cpu::fallback::quantized_matmul
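
As an illustrative aside (not part of the commit), here is a minimal sketch of driving this fallback kernel directly on a tiny problem. The namespace, template arguments, and parameter order come from the header above; the shapes, literal values, and the standalone main() are invented for the example.

// Illustrative example only: call the fallback kernel on a 1 x 4 x 2 problem.
// The include path below is the one added to the interface header in this diff.
#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>

#include <cstdint>
#include <vector>

int main() {
  const int m = 1, k = 4, n = 2;
  std::vector<float> lhs = {1.f, 2.f, 3.f, 4.f};  // m x k activations
  std::vector<int8_t> rhs(k * n, 1);              // k x n quantized weights
  std::vector<float> rhs_scales(k, 0.5f);         // one scale per input channel
  std::vector<int8_t> rhs_zeros(k, 0);            // one zero point per input channel
  std::vector<float> out(m * n, 0.f);

  torchao::kernels::cpu::fallback::quantized_matmul::
      fp32_a_input_channelwise_8bit_b_fp32::kernel<true, false, false>(
          m, n, k,
          lhs.data(), /*lhs_stride_m=*/k,
          rhs.data(), /*rhs_stride_n=*/n,
          out.data(), /*out_stride_m=*/n,
          rhs_zeros.data(), rhs_scales.data(),
          /*beta=*/0.0f, /*rhs_qparams_stride=*/1);
  // With beta = 0 the prior contents of out are ignored; both outputs equal
  // 0.5 * (1 + 2 + 3 + 4) = 5.0f.
  return 0;
}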

torchao/experimental/kernels/cpu/interface/quantized_matmul.h

Lines changed: 70 additions & 0 deletions
@@ -9,6 +9,8 @@
#include <cassert>

#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>
#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
@@ -85,4 +87,72 @@ get_int8_a_int8_b_channelwise_qmatmul(
        channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, false>;
  }
}

/*
a_stride_m: stride of a in memory, indicating how far apart each row is.
b_stride_n: stride of b in memory, indicating how far apart each row is.
If b is transposed (n x k), this is the number of elements to skip to get to
the next row (i.e. k). If b is not transposed (k x n), this is the number of
elements to skip to get to the next row (i.e. n).

The function also returns, via a_stride_m and b_stride_n, the strides of a and
b that should be passed to the kernel.

Will need to think of a better way to find the right ukernel. Perhaps via a
ukernel config + registry?
*/
using fp32_a_input_channelwise_8bit_b_f32_c_matmul_type = void (*)(
    int,
    int,
    int,
    const float*,
    int,
    const int8_t*,
    int,
    float*,
    int,
    const int8_t*,
    const float*,
    const float,
    const int);

fp32_a_input_channelwise_8bit_b_f32_c_matmul_type
get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
    int m,
    int n,
    int k,
    bool a_transposed,
    bool b_transposed,
    int& a_stride_m,
    int& b_stride_n);

// Dispatch: prefer the NEON micro-kernel when available and n is large enough;
// otherwise fall back to the reference kernel added in this commit.
fp32_a_input_channelwise_8bit_b_f32_c_matmul_type
get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
    int m,
    int n,
    int k,
    bool a_transposed,
    bool b_transposed,
    int& a_stride_m,
    int& b_stride_n) {
#if defined(__aarch64__) || defined(__ARM_NEON)
  if (!a_transposed && !b_transposed && n >= 16) {
    a_stride_m = k;
    b_stride_n = n;
    return aarch64::quantized_matmul::
        fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel<true, false, false>;
  }
#endif // defined(__aarch64__) || defined(__ARM_NEON)
  assert(!a_transposed);
  if (b_transposed) {
    a_stride_m = k;
    b_stride_n = k;
    return torchao::kernels::cpu::fallback::quantized_matmul::
        fp32_a_input_channelwise_8bit_b_fp32::kernel<true, false, true>;
  } else {
    a_stride_m = k;
    b_stride_n = n;
    return torchao::kernels::cpu::fallback::quantized_matmul::
        fp32_a_input_channelwise_8bit_b_fp32::kernel<true, false, false>;
  }
}
} // namespace torchao::kernels::cpu::quantized_matmul
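
To make the stride contract described in the comment above concrete, here is a hedged usage sketch (not part of the commit): request a ukernel for an untransposed b and call it with the strides the interface reports. The function, type, and namespace names come from this diff; the shapes, buffer contents, and the run_example wrapper are placeholders.

// Illustrative example only: obtain and invoke a ukernel via the new interface.
#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>

#include <cstdint>
#include <vector>

void run_example() {
  const int m = 4, k = 37, n = 3;
  std::vector<float> a(m * k, 1.0f);      // fp32 activations, m x k
  std::vector<int8_t> b(k * n, 2);        // int8 weights, k x n (not transposed)
  std::vector<int8_t> b_zeros(k, 0);      // per-input-channel zero points
  std::vector<float> b_scales(k, 0.25f);  // per-input-channel scales
  std::vector<float> c(m * n, 0.0f);      // fp32 output, m x n

  int a_stride_m, b_stride_n;
  auto kernel = torchao::kernels::cpu::quantized_matmul::
      get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
          m, n, k, /*a_transposed=*/false, /*b_transposed=*/false,
          a_stride_m, b_stride_n);
  // a_stride_m == k and b_stride_n == n here, per the stride contract above.
  kernel(
      m, n, k,
      a.data(), a_stride_m,
      b.data(), b_stride_n,
      c.data(), /*out_stride_m=*/n,
      b_zeros.data(), b_scales.data(),
      /*beta=*/0.0f, /*rhs_qparams_stride=*/1);
}

With n = 3 the NEON path (which requires n >= 16) is skipped, so this shape exercises exactly the fallback kernel added in this commit.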

torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp

Lines changed: 182 additions & 0 deletions
@@ -446,3 +446,185 @@ TEST(
  Run(
      /*m=*/4, /*k=*/2, /*n=*/1, 32);
}

class FP32A_QuantizedB_FP32C_Interface_Test
    : public ::testing::TestWithParam<float> {
 public:
  int m;
  int k;
  int n;
  int stride;

  bool rhs_has_zeros;
  bool lhs_is_transposed;
  bool rhs_is_transposed;

  std::vector<float> init_output;
  std::vector<float> expected_output;

  std::vector<float> lhs;

  std::vector<float> rhs;
  std::vector<int8_t> rhs_qvals;
  std::vector<float> rhs_scales;
  std::vector<int8_t> rhs_zeros;

  void generate(
      int m_,
      int k_,
      int n_,
      bool rhs_has_zeros_,
      bool lhs_is_transposed_,
      bool rhs_is_transposed_,
      int stride_ = 1) {
    assert(!lhs_is_transposed_);
    assert(rhs_has_zeros_);
    m = m_;
    k = k_;
    n = n_;
    stride = stride_;
    rhs_has_zeros = rhs_has_zeros_;
    lhs_is_transposed = lhs_is_transposed_;
    rhs_is_transposed = rhs_is_transposed_;

    assert(!rhs_is_transposed || stride == 1);

    // Generate activations
    lhs = get_random_vector(m * k, -1.0, 1.0);

    // Note: instead of quantizing each output channel separately, we quantize
    // each input channel. We generate a k x n matrix (rather than n x k)
    // because each input channel is quantized separately.
    std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) =
        generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed);

    // Initialize the output with random values; the expected output is
    // computed in execute().
    init_output = get_random_vector(m * n, -1.0, 1.0);

    assert(init_output.size() == m * n);
    assert(lhs.size() == m * k);
    assert(rhs.size() == n * stride * k);
    assert(rhs_qvals.size() == n * stride * k);
    assert(rhs_scales.size() == k * stride);
    assert(rhs_zeros.size() == k * stride);
  }

  void execute(float beta) {
    // Compute expected output with a straightforward reference loop.
    expected_output = init_output;

    for (int m_idx = 0; m_idx < m; m_idx++) {
      for (int n_idx = 0; n_idx < n; n_idx++) {
        float res = 0.0;
        for (int k_idx = 0; k_idx < k; k_idx++) {
          int lhs_idx = m_idx * k + k_idx;
          int rhs_idx = k_idx * stride * n + n_idx;
          if (rhs_is_transposed) {
            rhs_idx = n_idx * k * stride + k_idx * stride;
          }
          float rhs_dequant = rhs_scales[k_idx * stride] *
              (static_cast<int16_t>(rhs_qvals[rhs_idx]) -
               static_cast<int16_t>(rhs_zeros[k_idx * stride]));

          res += lhs[lhs_idx] * rhs_dequant;
        }
        expected_output[m_idx * n + n_idx] =
            expected_output[m_idx * n + n_idx] * beta + res;
      }
    }
  }

  float beta() const {
    return GetParam();
  }
};

static void test_fp32_a_input_channelwise_8bit_b(
    int m,
    int k,
    int n,
    float beta,
    FP32A_QuantizedB_FP32C_Interface_Test& test_case,
    int stride = 1) {
  test_case.execute(beta);

  int a_stride_m, b_stride_n;
  auto kernel = torchao::kernels::cpu::quantized_matmul::
      get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
          m, n, k, false, false, a_stride_m, b_stride_n);
  b_stride_n = b_stride_n * stride;

  std::vector<float> output(test_case.init_output);
  kernel(
      m,
      n,
      k,
      test_case.lhs.data(),
      a_stride_m /*lhs_stride_m*/,
      test_case.rhs_qvals.data(),
      b_stride_n /*rhs_stride_n*/,
      output.data(),
      n /*out_stride_m*/,
      test_case.rhs_zeros.data(),
      test_case.rhs_scales.data(),
      beta,
      stride /*rhs qparams stride*/);

  for (int i = 0; i < m * n; i++) {
    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
  }
}

TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTransposedWithZeroPoints) {
  generate(3, 128, 16, true, false, false);
  test_fp32_a_input_channelwise_8bit_b(
      /*m=*/3, /*k=*/128, /*n=*/16, beta(), *this);
}

TEST_P(
    FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizes) {
  generate(4, 37, 19, true, false, false);
  test_fp32_a_input_channelwise_8bit_b(
      /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this);
}

// Test shapes for which we have to use the fallback kernel.
TEST_P(
    FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizesFallback) {
  generate(4, 37, 3, true, false, false);
  test_fp32_a_input_channelwise_8bit_b(
      /*m=*/4, /*k=*/37, /*n=*/3, beta(), *this);
}

TEST_P(
    FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizes2Fallback) {
  generate(4, 1, 3, true, false, false);
  test_fp32_a_input_channelwise_8bit_b(
      /*m=*/4, /*k=*/1, /*n=*/3, beta(), *this);
}

TEST_P(
    FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizesStrided) {
  generate(4, 37, 19, true, false, false, 32);
  test_fp32_a_input_channelwise_8bit_b(
      /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32);
}

TEST_P(
    FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizes2FallbackStrided) {
  generate(4, 5, 3, true, false, false, 32);
  test_fp32_a_input_channelwise_8bit_b(
      /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32);
}

INSTANTIATE_TEST_SUITE_P(
    F32AInt8BFP32CTest,
    FP32A_QuantizedB_FP32C_Interface_Test,
    ::testing::Values(0.0, 1.0, 3.1));
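
The tests rely on two pre-existing helpers, get_random_vector and generate_per_token_quantized_tensor, which are outside this diff. For readers unfamiliar with the per-token scheme the comment in generate() refers to, below is a rough, hypothetical sketch of such a per-row asymmetric int8 quantizer. The name per_row_quantize, the seeding, the zero-range guard, and the exact rounding are assumptions for illustration, not the actual test utility.

// Hypothetical sketch of a per-row 8-bit asymmetric quantizer matching how the
// test consumes (rhs, rhs_qvals, rhs_scales, rhs_zeros). Not the real helper.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <random>
#include <tuple>
#include <vector>

static std::tuple<
    std::vector<float>, std::vector<int8_t>, std::vector<float>, std::vector<int8_t>>
per_row_quantize(int rows, int cols) {
  std::mt19937 gen(42);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  std::vector<float> vals(rows * cols);
  for (auto& v : vals) v = dist(gen);

  std::vector<int8_t> qvals(rows * cols);
  std::vector<float> scales(rows);
  std::vector<int8_t> zeros(rows);
  for (int r = 0; r < rows; ++r) {
    auto row_begin = vals.begin() + r * cols;
    float lo = *std::min_element(row_begin, row_begin + cols);
    float hi = *std::max_element(row_begin, row_begin + cols);
    // Map [lo, hi] onto [-128, 127]; guard against a zero range.
    float scale = std::max((hi - lo) / 255.0f, 1e-8f);
    int zero = static_cast<int>(std::round(-128.0f - lo / scale));
    scales[r] = scale;
    zeros[r] = static_cast<int8_t>(zero);
    for (int c = 0; c < cols; ++c) {
      int q = static_cast<int>(std::round(vals[r * cols + c] / scale)) + zero;
      qvals[r * cols + c] = static_cast<int8_t>(std::clamp(q, -128, 127));
    }
  }
  return {vals, qvals, scales, zeros};
}

Laid out as a (k x n) matrix with one (scale, zero point) pair per row, this is exactly the per-input-channel quantization the kernel's rhs_scales / rhs_zero_points arguments expect, with rhs_qparams_stride selecting how those pairs are strided in memory.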
