@@ -446,3 +446,185 @@ TEST(
  Run(
      /*m=*/4, /*k=*/2, /*n=*/1, 32);
}
+
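+// Parameterized (over beta) fixture for the fp32-activation,
+// channelwise-quantized-int8-weight matmul interface:
+// C (m x n, fp32) = beta * C + A (m x k, fp32) * dequantize(B) (k x n).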
+class FP32A_QuantizedB_FP32C_Interface_Test
+    : public ::testing::TestWithParam<float> {
+ public:
+  int m;
+  int k;
+  int n;
+  // Stride (in elements) between successive per-input-channel
+  // quantization parameters of B; > 1 exercises non-contiguous layouts.
+  int stride;
+
+  bool rhs_has_zeros;
+  bool lhs_is_transposed;
+  bool rhs_is_transposed;
+
+  std::vector<float> init_output;
+  std::vector<float> expected_output;
+
+  std::vector<float> lhs;
+
+  std::vector<float> rhs;
+  std::vector<int8_t> rhs_qvals;
+  std::vector<float> rhs_scales;
+  std::vector<int8_t> rhs_zeros;
+
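+  // Fills the fixture with random inputs for an (m x k) * (k x n) matmul;
+  // stride_ > 1 spaces out B's quantization parameters in memory.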
+  void generate(
+      int m_,
+      int k_,
+      int n_,
+      bool rhs_has_zeros_,
+      bool lhs_is_transposed_,
+      bool rhs_is_transposed_,
+      int stride_ = 1) {
+    assert(!lhs_is_transposed_);
+    assert(rhs_has_zeros_);
+    m = m_;
+    k = k_;
+    n = n_;
+    stride = stride_;
+    rhs_has_zeros = rhs_has_zeros_;
+    lhs_is_transposed = lhs_is_transposed_;
+    rhs_is_transposed = rhs_is_transposed_;
+
+    assert(!rhs_is_transposed || stride == 1);
+
+    // Generate activations.
+    lhs = get_random_vector(m * k, -1.0, 1.0);
+
+    // Note that B is quantized per *input* channel rather than per output
+    // channel: generate_per_token_quantized_tensor quantizes per row, so
+    // we want a k x n matrix (not n x k) when B is untransposed, giving
+    // each of the k input channels its own scale and zero point.
+    std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) =
+        generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed);
+
+    // Random initial output, so that beta scaling is exercised.
+    init_output = get_random_vector(m * n, -1.0, 1.0);
+
+    assert(init_output.size() == m * n);
+    assert(lhs.size() == m * k);
+    assert(rhs.size() == n * stride * k);
+    assert(rhs_qvals.size() == n * stride * k);
+    assert(rhs_scales.size() == k * stride);
+    assert(rhs_zeros.size() == k * stride);
+  }
+
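+  // Reference implementation: expected_output = beta * init_output +
+  // lhs * dequantize(rhs). Row i of B dequantizes with
+  // rhs_scales[i * stride] and rhs_zeros[i * stride].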
+  void execute(float beta) {
+    // Compute expected output.
+    expected_output = init_output;
+
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        float res = 0.0;
+        for (int k_idx = 0; k_idx < k; k_idx++) {
+          int lhs_idx = m_idx * k + k_idx;
+          int rhs_idx = k_idx * stride * n + n_idx;
+          if (rhs_is_transposed) {
+            rhs_idx = n_idx * k * stride + k_idx * stride;
+          }
+          float rhs_dequant = rhs_scales[k_idx * stride] *
+              (static_cast<int16_t>(rhs_qvals[rhs_idx]) -
+               static_cast<int16_t>(rhs_zeros[k_idx * stride]));
+
+          res += lhs[lhs_idx] * rhs_dequant;
+        }
+        expected_output[m_idx * n + n_idx] =
+            expected_output[m_idx * n + n_idx] * beta + res;
+      }
+    }
+  }
+
+  float beta() const {
+    return GetParam();
+  }
+};
+
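+// Runs the kernel under test on the fixture's inputs and compares the
+// result against the reference computed by execute(). b_stride_n is
+// scaled by stride to match the strided layout B was generated with.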
+static void test_fp32_a_input_channelwise_8bit_b(
+    int m,
+    int k,
+    int n,
+    float beta,
+    FP32A_QuantizedB_FP32C_Interface_Test& test_case,
+    int stride = 1) {
+  test_case.execute(beta);
+
+  int a_stride_m, b_stride_n;
+  auto kernel = torchao::kernels::cpu::quantized_matmul::
+      get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
+          m, n, k, false, false, a_stride_m, b_stride_n);
+  b_stride_n = b_stride_n * stride;
+
+  std::vector<float> output(test_case.init_output);
+  kernel(
+      m,
+      n,
+      k,
+      test_case.lhs.data(),
+      a_stride_m /*lhs_stride_m*/,
+      test_case.rhs_qvals.data(),
+      b_stride_n /*rhs_stride_n*/,
+      output.data(),
+      n /*out_stride_n*/,
+      test_case.rhs_zeros.data(),
+      test_case.rhs_scales.data(),
+      beta,
+      stride /*rhs qparams stride*/);
+
+  for (int i = 0; i < m * n; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTransposedWithZeroPoints) {
+  generate(3, 128, 16, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/3, /*k=*/128, /*n=*/16, beta(), *this);
+}
+
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTransposedWithZeroPointsOddSizes) {
+  generate(4, 37, 19, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this);
+}
+
+// Test shapes that require the fallback kernel.
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTransposedWithZeroPointsOddSizesFallback) {
+  generate(4, 37, 3, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/37, /*n=*/3, beta(), *this);
+}
+
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTransposedWithZeroPointsOddSizes2Fallback) {
+  generate(4, 1, 3, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/1, /*n=*/3, beta(), *this);
+}
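+
+// The strided variants below space B's quantization parameters 32
+// elements apart, exercising the kernel's rhs qparams stride argument.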
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTransposedWithZeroPointsOddSizesStrided) {
+  generate(4, 37, 19, true, false, false, 32);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32);
+}
+
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTransposedWithZeroPointsOddSizes2FallbackStrided) {
+  generate(4, 5, 3, true, false, false, 32);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32);
+}
+
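+// beta = 0.0 overwrites the initial output, 1.0 accumulates into it, and
+// 3.1 checks an arbitrary scaling of C.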
+INSTANTIATE_TEST_SUITE_P(
+    F32AInt8BFP32CTest,
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    ::testing::Values(0.0, 1.0, 3.1));