Committed by JohannesGaessler and Diego Devesa
Commit f328957 · 1 parent: 9ed1962

CUDA: use mma PTX instructions for FlashAttention (llama/11583)


* CUDA: use mma PTX instructions for FlashAttention

* __shfl_sync workaround for movmatrix

* add __shfl_sync to HIP

Co-authored-by: Diego Devesa <[email protected]>
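
For readers unfamiliar with the PTX-level tensor core interface this commit moves to, here is a minimal, editorial sketch (not code from this commit) of how a Turing-class m16n8k8 FP16 multiply-accumulate can be issued with inline PTX; the fragment register counts follow the PTX ISA documentation for this shape. The "__shfl_sync workaround for movmatrix" bullet refers to emulating the warp-level 8x8 transpose instruction with shuffles on targets (e.g. HIP, or older toolchains) where `movmatrix` is not available.

```cuda
// Editorial sketch, assuming sm_75+ and the PTX ISA register layouts for
// mma.sync.aligned.m16n8k8: D (16x8, fp32) and C use 4 registers per thread,
// A (16x8, fp16) uses 2 x .b32, B (8x8, fp16) uses 1 x .b32.
__device__ __forceinline__ void mma_m16n8k8_f16_f32acc(
        float (&D)[4], const unsigned (&A)[2], const unsigned B, const float (&C)[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#endif // __CUDA_ARCH__ >= 750
}
```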

Files changed (28)
  1. ggml/include/ggml.h +1 -1
  2. ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  3. ggml/src/ggml-cuda/common.cuh +4 -2
  4. ggml/src/ggml-cuda/fattn-common.cuh +154 -25
  5. ggml/src/ggml-cuda/fattn-mma-f16.cuh +637 -0
  6. ggml/src/ggml-cuda/fattn-tile-f16.cu +18 -6
  7. ggml/src/ggml-cuda/fattn-tile-f32.cu +13 -6
  8. ggml/src/ggml-cuda/fattn-vec-f16.cuh +8 -1
  9. ggml/src/ggml-cuda/fattn-vec-f32.cuh +7 -1
  10. ggml/src/ggml-cuda/fattn-wmma-f16.cu +648 -0
  11. ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -541
  12. ggml/src/ggml-cuda/fattn.cu +50 -124
  13. ggml/src/ggml-cuda/mma.cuh +286 -49
  14. ggml/src/ggml-cuda/mmq.cu +1 -1
  15. ggml/src/ggml-cuda/mmq.cuh +176 -173
  16. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu +10 -0
  17. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu +10 -0
  18. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu +10 -0
  19. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu +10 -0
  20. ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
  21. ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
  22. ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
  23. ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
  24. ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
  25. ggml/src/ggml-cuda/template-instances/generate_cu_files.py +8 -16
  26. ggml/src/ggml-cuda/vendors/hip.h +1 -0
  27. ggml/src/ggml-hip/CMakeLists.txt +1 -1
  28. ggml/src/ggml-musa/CMakeLists.txt +1 -1
ggml/include/ggml.h CHANGED
@@ -1775,7 +1775,7 @@ extern "C" {
     struct ggml_tensor * a,
     int k);

-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64

     // q: [n_embd, n_batch, n_head, 1]
     // k: [n_embd, n_kv, n_head_kv, 1]
ggml/src/ggml-cuda/CMakeLists.txt CHANGED
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")

     file(GLOB GGML_SOURCES_CUDA "*.cu")
-    file(GLOB SRCS "template-instances/fattn-wmma*.cu")
+    file(GLOB SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -148,7 +148,7 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define INT8_MMA_AVAILABLE
+#define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }

+// Any FP16 tensor cores are available.
 static constexpr bool fp16_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }

-static constexpr bool int8_mma_available(const int cc) {
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static constexpr bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
 }
 
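As an editorial illustration of how these two gates are meant to work together (the actual dispatch lives in ggml/src/ggml-cuda/fattn.cu, which this commit also rewrites but is not shown here): `NEW_MMA_AVAILABLE` guards device code that emits Turing-and-newer mma PTX at compile time, while `new_mma_available(cc)` lets host code pick a kernel at run time from the device's compute capability. The function names below other than the two availability checks are placeholders, not ggml API.

```cpp
// Hypothetical dispatch sketch; launch_* are placeholders for illustration.
static void launch_mma_f16_kernel()     { /* fattn-mma-f16.cuh path */ }
static void launch_wmma_f16_kernel()    { /* fattn-wmma-f16.cu path (Volta) */ }
static void launch_tile_or_vec_kernel() { /* tensor-core-free fallbacks */ }

static void choose_flash_attn_kernel(const int cc) {
    if (new_mma_available(cc)) {
        launch_mma_f16_kernel();        // Turing and newer: new mma PTX kernels
    } else if (fp16_mma_available(cc)) {
        launch_wmma_f16_kernel();       // Volta: deprecated WMMA kernels
    } else {
        launch_tile_or_vec_kernel();    // no usable FP16 tensor cores
    }
}
```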
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
516
  nullptr;
517
  }
518
 
519
  template<int D, int parallel_blocks> // D == head size
520
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
521
  __launch_bounds__(D, 1)
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
581
  }
582
  }
583
 
584
- template <int D, int parallel_blocks>
 
585
  void launch_fattn(
586
  ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
587
- const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
588
  ) {
589
  const ggml_tensor * Q = dst->src[0];
590
  const ggml_tensor * K = dst->src[1];
@@ -603,20 +702,23 @@ void launch_fattn(
603
 
604
  GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
605
 
 
 
606
  ggml_cuda_pool & pool = ctx.pool();
607
  cudaStream_t main_stream = ctx.stream();
 
608
 
609
  ggml_cuda_pool_alloc<half> K_f16(pool);
610
  ggml_cuda_pool_alloc<half> V_f16(pool);
611
  ggml_cuda_pool_alloc<float> dst_tmp(pool);
612
  ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
613
 
614
- char * K_data = (char *) K->data;
615
  size_t nb11 = K->nb[1];
616
  size_t nb12 = K->nb[2];
617
  size_t nb13 = K->nb[3];
618
 
619
- char * V_data = (char *) V->data;
620
  size_t nb21 = V->nb[1];
621
  size_t nb22 = V->nb[2];
622
  size_t nb23 = V->nb[3];
@@ -649,39 +751,60 @@ void launch_fattn(
649
  nb23 = nb23*bs*sizeof(half)/ts;
650
  }
651
 
652
- if (parallel_blocks > 1) {
653
- dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
654
- dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
655
- }
656
 
657
  const dim3 block_dim(WARP_SIZE, nwarps, 1);
658
- const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
659
- const int shmem = 0;
660
 
661
  float scale = 1.0f;
662
  float max_bias = 0.0f;
663
  float logit_softcap = 0.0f;
664
 
665
- memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
666
- memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
667
- memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
668
 
669
  if (logit_softcap != 0.0f) {
670
  scale /= logit_softcap;
671
  }
672
 
673
  const uint32_t n_head = Q->ne[2];
674
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
675
 
676
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
677
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
678
 
679
- fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
680
  (const char *) Q->data,
681
  K_data,
682
  V_data,
683
  mask ? ((const char *) mask->data) : nullptr,
684
- (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
685
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
686
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
687
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -693,16 +816,22 @@ void launch_fattn(
693
  );
694
  CUDA_CHECK(cudaGetLastError());
695
 
696
- if ((parallel_blocks) == 1) {
697
- return;
698
- }
 
699
 
700
- const dim3 block_dim_combine(D, 1, 1);
701
- const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
702
- const int shmem_combine = 0;
 
 
 
 
703
 
704
- flash_attn_combine_results<D, parallel_blocks>
705
- <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
706
- (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
 
707
  CUDA_CHECK(cudaGetLastError());
708
  }
 
516
  nullptr;
517
  }
518
 
519
+ template<int D, int ncols, int KQ_stride> // D == head size
520
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
521
+ __launch_bounds__(D, 1)
522
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
523
+ static __global__ void flash_attn_stream_k_fixup(
524
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
525
+ const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
526
+
527
+ const int iter_k = ne11 / KQ_stride;
528
+ const int iter_j = (ne01 + (ncols - 1)) / ncols;
529
+
530
+ const int bidx0 = blockIdx.x;
531
+
532
+ const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
533
+ const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
534
+
535
+ const bool did_not_have_any_data = kbc0 == kbc0_stop;
536
+ const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
537
+ const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
538
+ if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
539
+ return;
540
+ }
541
+
542
+ const int channel = kbc0 / (iter_k*iter_j);
543
+ const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
544
+
545
+ dst += jt*ncols*ne02*D + channel*D;
546
+
547
+ // Load the partial result that needs a fixup:
548
+ float dst_val[ncols] = {0.0f};
549
+ float max_val[ncols] = {0.0f};
550
+ float rowsum[ncols] = {0.0f};
551
+ #pragma unroll
552
+ for (int j = 0; j < ncols; ++j) {
553
+ if (jt*ncols + j >= ne01) {
554
+ break;
555
+ }
556
+ dst_val[j] = dst[j*ne02*D + threadIdx.x];
557
+
558
+ const float2 tmp = dst_fixup[bidx0*ncols + j];
559
+ max_val[j] = tmp.x;
560
+ rowsum[j] = tmp.y;
561
+ }
562
+
563
+ // Iterate over previous blocks and compute the combined results.
564
+ // All CUDA blocks that get here must have a previous block that needs a fixup.
565
+ int bidx = bidx0 - 1;
566
+ int kbc_stop = kbc0;
567
+ while(true) {
568
+ const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
569
+ if (kbc == kbc_stop) { // Did not have any data.
570
+ bidx--;
571
+ kbc_stop = kbc;
572
+ continue;
573
+ }
574
+
575
+ #pragma unroll
576
+ for (int j = 0; j < ncols; ++j) {
577
+ if (jt*ncols + j >= ne01) {
578
+ break;
579
+ }
580
+ const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
581
+
582
+ const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
583
+
584
+ // Scale the current and new value accumulators depending on the max. values.
585
+ const float max_val_new = fmaxf(max_val[j], tmp.x);
586
+
587
+ const float diff_val = max_val[j] - max_val_new;
588
+ const float diff_add = tmp.x - max_val_new;
589
+
590
+ const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
591
+ const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
592
+
593
+ dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
594
+ rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
595
+
596
+ max_val[j] = max_val_new;
597
+ }
598
+
599
+ // If this block started in a previous tile we are done and don't need to combine additional partial results.
600
+ if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
601
+ break;
602
+ }
603
+ bidx--;
604
+ kbc_stop = kbc;
605
+ }
606
+
607
+ // Write back final result:
608
+ #pragma unroll
609
+ for (int j = 0; j < ncols; ++j) {
610
+ if (jt*ncols + j >= ne01) {
611
+ return;
612
+ }
613
+ dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
614
+ }
615
+ }
616
+
617
  template<int D, int parallel_blocks> // D == head size
618
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
619
  __launch_bounds__(D, 1)
 
679
  }
680
  }
681
 
682
+ // parallel_blocks == 0 is stream-k decomposition
683
+ template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
684
  void launch_fattn(
685
  ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
686
+ const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
687
  ) {
688
  const ggml_tensor * Q = dst->src[0];
689
  const ggml_tensor * K = dst->src[1];
 
702
 
703
  GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
704
 
705
+ GGML_ASSERT(Q->ne[3] == 1);
706
+
707
  ggml_cuda_pool & pool = ctx.pool();
708
  cudaStream_t main_stream = ctx.stream();
709
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
710
 
711
  ggml_cuda_pool_alloc<half> K_f16(pool);
712
  ggml_cuda_pool_alloc<half> V_f16(pool);
713
  ggml_cuda_pool_alloc<float> dst_tmp(pool);
714
  ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
715
 
716
+ const char * K_data = (const char *) K->data;
717
  size_t nb11 = K->nb[1];
718
  size_t nb12 = K->nb[2];
719
  size_t nb13 = K->nb[3];
720
 
721
+ const char * V_data = (const char *) V->data;
722
  size_t nb21 = V->nb[1];
723
  size_t nb22 = V->nb[2];
724
  size_t nb23 = V->nb[3];
 
751
  nb23 = nb23*bs*sizeof(half)/ts;
752
  }
753
 
754
+ const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
755
+ const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
 
756
 
757
  const dim3 block_dim(WARP_SIZE, nwarps, 1);
758
+ dim3 blocks_num;
759
+ if (parallel_blocks == 0) {
760
+ // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
761
+ const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
762
+ const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
763
+ const bool short_context = K->ne[1] < 4096;
764
+
765
+ const int nblocks_stream_k = 2*nsm;
766
+
767
+ blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
768
+ blocks_num.y = 1;
769
+ blocks_num.z = 1;
770
+
771
+ dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
772
+ } else {
773
+ blocks_num.x = parallel_blocks*ntiles_x;
774
+ blocks_num.y = Q->ne[2];
775
+ blocks_num.z = Q->ne[3];
776
+
777
+ if (parallel_blocks > 1) {
778
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
779
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
780
+ }
781
+ }
782
+
783
 
784
  float scale = 1.0f;
785
  float max_bias = 0.0f;
786
  float logit_softcap = 0.0f;
787
 
788
+ memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
789
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
790
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
791
 
792
  if (logit_softcap != 0.0f) {
793
  scale /= logit_softcap;
794
  }
795
 
796
  const uint32_t n_head = Q->ne[2];
797
+ const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
798
 
799
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
800
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
801
 
802
+ fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
803
  (const char *) Q->data,
804
  K_data,
805
  V_data,
806
  mask ? ((const char *) mask->data) : nullptr,
807
+ (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
808
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
809
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
810
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
 
816
  );
817
  CUDA_CHECK(cudaGetLastError());
818
 
819
+ if constexpr (parallel_blocks == 0) {
820
+ if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
821
+ const dim3 block_dim_combine(D, 1, 1);
822
+ const dim3 blocks_num_combine = blocks_num;
823
 
824
+ flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
825
+ <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
826
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
827
+ }
828
+ } else if constexpr (parallel_blocks > 1) {
829
+ const dim3 block_dim_combine(D, 1, 1);
830
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
831
 
832
+ flash_attn_combine_results<D, parallel_blocks>
833
+ <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
834
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
835
+ }
836
  CUDA_CHECK(cudaGetLastError());
837
  }
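
The stream-k fixup kernel added above combines partial results from CUDA blocks whose work slices split a K/V tile. As an editorial aside, the per-column rescaling it performs is the usual online-softmax merge; a minimal host-side sketch of that rule follows (names are illustrative, not from the commit).

```cpp
#include <math.h>

// One partial FlashAttention accumulator for a single output element:
// m = running maximum of the KQ logits, s = softmax denominator (rowsum),
// o = un-normalized output accumulated so far.
struct Partial { float m, s, o; };

// Merge two partials computed over disjoint K/V ranges.
static Partial merge(const Partial a, const Partial b) {
    const float m  = fmaxf(a.m, b.m);
    const float ca = expf(a.m - m);   // scale applied to the older accumulator
    const float cb = expf(b.m - m);   // scale applied to the incoming one
    return { m, ca*a.s + cb*b.s, ca*a.o + cb*b.o };
}

// The final value is o/s once every chunk has been merged; the kernel above
// additionally flushes scales to zero below SOFTMAX_FTZ_THRESHOLD.
```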
ggml/src/ggml-cuda/fattn-mma-f16.cuh ADDED
@@ -0,0 +1,637 @@
1
+ #include "common.cuh"
2
+ #include "mma.cuh"
3
+ #include "fattn-common.cuh"
4
+
5
+ template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
6
+ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
7
+ const float2 * const __restrict__ Q_f2,
8
+ const half2 * const __restrict__ K_h2,
9
+ const half2 * const __restrict__ V_h2,
10
+ const half * const __restrict__ maskh,
11
+ float2 * const __restrict__ dstk,
12
+ float2 * const __restrict__ dstk_fixup,
13
+ const float scale,
14
+ const float slope,
15
+ const float logit_softcap,
16
+ const int ne00,
17
+ const int ne01,
18
+ const int ne02,
19
+ const int ne03,
20
+ const int ne10,
21
+ const int ne11,
22
+ const int ne12,
23
+ const int ne13,
24
+ const int ne31,
25
+ const int nb31,
26
+ const int nb01,
27
+ const int nb02,
28
+ const int nb03,
29
+ const int nb11,
30
+ const int nb12,
31
+ const int nb13,
32
+ const int nb21,
33
+ const int nb22,
34
+ const int nb23,
35
+ const int ne0,
36
+ const int ne1,
37
+ const int ne2,
38
+ const int ne3,
39
+ const int jt,
40
+ const int kb0_start,
41
+ const int kb0_stop) {
42
+ #ifdef NEW_MMA_AVAILABLE
43
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
44
+
45
+ typedef mma_A_I16K8<half2> mma_A;
46
+ typedef mma_B_J8K8<half2> mma_B;
47
+ typedef mma_C_I16J8<float> mma_C_KQ;
48
+ typedef mma_C_I16J8<half2> mma_C_VKQ;
49
+
50
+ static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
51
+ constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
52
+
53
+ static_assert(D % nwarps == 0, "bad D");
54
+ static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
55
+
56
+ constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
57
+ extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
58
+
59
+ const int stride_Q = nb01 / sizeof(float2);
60
+ const int stride_KV = nb11 / sizeof(half2);
61
+ const int stride_mask = nb31 / sizeof(half);
62
+
63
+ mma_B Q_B[D/(2*mma_B::K)];
64
+ mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
65
+
66
+ float2 KQ_rowsum = {0.0f, 0.0f};
67
+ float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
68
+ float2 KQ_max_scale = {0.0f, 0.0f};
69
+
70
+ // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
71
+ // The loading is done with decreasing granularity for D for better memory bandwidth.
72
+ const half2 scale_h2 = make_half2(scale, scale);
73
+ #pragma unroll
74
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
75
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
76
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
77
+ const int stride_j = WARP_SIZE / stride_k;
78
+
79
+ if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
80
+ break;
81
+ }
82
+
83
+ #pragma unroll
84
+ for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
85
+ const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
86
+
87
+ if (jt*ncols + j < ne01) {
88
+ #pragma unroll
89
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
90
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
91
+
92
+ const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
93
+ tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
94
+ }
95
+ } else {
96
+ #pragma unroll
97
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
98
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
99
+
100
+ tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
101
+ }
102
+ }
103
+ }
104
+ }
105
+
106
+ __syncthreads();
107
+
108
+ {
109
+ const int j0 = (threadIdx.y / np) * mma_B::J;
110
+
111
+ #pragma unroll
112
+ for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
113
+ Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
114
+ }
115
+ }
116
+
117
+ __syncthreads();
118
+
119
+ // Iterate over ne11 == previous tokens:
120
+ for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
121
+ const int k_VKQ_0 = kb0*KQ_stride;
122
+ mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
123
+
124
+ // Load K data into tile with decreasing granularity for D for better memory bandwidth:
125
+ static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
126
+ #pragma unroll
127
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
128
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
129
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
130
+ const int stride_i = WARP_SIZE / stride_k;
131
+
132
+ #pragma unroll
133
+ for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
134
+ const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
135
+
136
+ #pragma unroll
137
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
138
+ const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
139
+
140
+ tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
141
+ }
142
+ }
143
+ }
144
+
145
+ __syncthreads();
146
+
147
+ // Calculate tile of KQ:
148
+ #pragma unroll
149
+ for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
150
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
151
+ #pragma unroll
152
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
153
+ mma_A K_A;
154
+ K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
155
+ KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
156
+ }
157
+ }
158
+
159
+ __syncthreads();
160
+
161
+ if (use_logit_softcap) {
162
+ static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
163
+ #pragma unroll
164
+ for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
165
+ #pragma unroll
166
+ for (int l = 0; l < mma_C_KQ::ne; ++l) {
167
+ KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
168
+ }
169
+ }
170
+ }
171
+
172
+ if (maskh) {
173
+ static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
174
+ static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
175
+ #pragma unroll
176
+ for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
177
+ const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
178
+ #pragma unroll
179
+ for (int l = 0; l < mma_C_KQ::ne; ++l) {
180
+ const int i = i0 + mma_C_KQ::get_i(l);
181
+ const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
182
+
183
+ KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
184
+ }
185
+ }
186
+ }
187
+
188
+ // Calculate softmax for each KQ column using the current max. value.
189
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
190
+ float2 KQ_max_new = KQ_max;
191
+ static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
192
+ #pragma unroll
193
+ for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
194
+ #pragma unroll
195
+ for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
196
+ KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
197
+ KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
198
+ }
199
+ }
200
+
201
+ // Values per KQ column are spread across 8 threads, does not need full warp reduce:
202
+ #pragma unroll
203
+ for (int offset = 16; offset > 2; offset >>= 1) {
204
+ KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
205
+ KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
206
+ }
207
+
208
+ {
209
+ const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
210
+ KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
211
+ if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
212
+ KQ_max_scale.x = 0.0f;
213
+ }
214
+ if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
215
+ KQ_max_scale.y = 0.0f;
216
+ }
217
+ KQ_max = KQ_max_new;
218
+ }
219
+
220
+ float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
221
+ static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
222
+ #pragma unroll
223
+ for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
224
+ #pragma unroll
225
+ for (int l = 0; l < mma_C_KQ::ne; ++l) {
226
+ const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
227
+ const float diff = KQ_C[k].x[l] - KQ_max_l;
228
+ KQ_C[k].x[l] = expf(diff);
229
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
230
+ KQ_C[k].x[l] = 0.0f;
231
+ }
232
+
233
+ if (l % 2 == 0) {
234
+ KQ_rowsum_add.x += KQ_C[k].x[l];
235
+ } else {
236
+ KQ_rowsum_add.y += KQ_C[k].x[l];
237
+ }
238
+ }
239
+ }
240
+
241
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
242
+ KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
243
+ KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
244
+
245
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
246
+ #pragma unroll
247
+ for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
248
+ #pragma unroll
249
+ for (int l = 0; l < mma_C_VKQ::ne; ++l) {
250
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
251
+ }
252
+ }
253
+
254
+ // Convert KQ C tiles into B tiles for VKQ calculation:
255
+ mma_B B[KQ_stride/(np*2*mma_B::K)];
256
+ static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
257
+ #pragma unroll
258
+ for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
259
+ B[k] = KQ_C[k].to_mma_B();
260
+ }
261
+
262
+ // Load V data into tile with decreasing granularity for D for better memory bandwidth:
263
+ static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
264
+ #pragma unroll
265
+ for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
266
+ const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
267
+ const int i0_stop = D/2 - (D/2) % (1*stride_i);
268
+ const int stride_k = WARP_SIZE / stride_i;
269
+
270
+ #pragma unroll
271
+ for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
272
+ const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
273
+
274
+ #pragma unroll
275
+ for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
276
+ const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
277
+
278
+ tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
279
+ }
280
+ }
281
+ }
282
+
283
+ __syncthreads();
284
+
285
+ // Calculate VKQ tile:
286
+ #pragma unroll
287
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
288
+ static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
289
+ #pragma unroll
290
+ for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
291
+ const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
292
+
293
+ mma_A A;
294
+ A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
295
+ VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
296
+ }
297
+ }
298
+
299
+ __syncthreads();
300
+ }
301
+
302
+ // Finally, sum up partial KQ rowsums.
303
+ // The partial sums are spread across 8 threads each, does not need full reduce.
304
+ #pragma unroll
305
+ for (int offset = 16; offset > 2; offset >>= 1) {
306
+ KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
307
+ KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
308
+ }
309
+
310
+ // Write VKQ accumulators to shared memory in column-major format.
311
+ // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
312
+ // Also for np > 1 the combination is done via these values in shared memory.
313
+ const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
314
+ #pragma unroll
315
+ for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
316
+ const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
317
+
318
+ #pragma unroll
319
+ for (int l = 0; l < mma_B::ne; ++l) {
320
+ const int k = k0 + mma_B::get_k(l);
321
+
322
+ tile_KV[j_cwd*D2_padded + k] = B.x[l];
323
+ }
324
+ }
325
+
326
+ const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
327
+ const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
328
+ const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
329
+
330
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
331
+ // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
332
+ ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
333
+ }
334
+
335
+ __syncthreads();
336
+
337
+ static_assert(np == 1 || np == 2 || np == 4, "bad np");
338
+ if (np == 1) {
339
+ // No combination is needed, the meta data can be directly written from registers to VRAM.
340
+ if (needs_fixup && threadIdx.x < mma_B::J) {
341
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
342
+ dstk_fixup_meta[j_cwm] = KQ_cmr;
343
+ }
344
+ if (is_fixup && threadIdx.x < mma_B::J) {
345
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
346
+ dstk_fixup_meta[j_cwm] = KQ_cmr;
347
+ }
348
+ } else if (threadIdx.y % np == 0) {
349
+ // Combine the meta data for parallel warps via shared memory.
350
+ // Warps with threadIdx.y % np != 0 must NOT return early.
351
+ // All threads must return simultaneously to avoid race conditions with work on the next tile.
352
+
353
+ float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
354
+
355
+ float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
356
+ if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
357
+ KQ_cm = meta_j[0];
358
+ }
359
+
360
+ float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
361
+ #pragma unroll
362
+ for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
363
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
364
+ }
365
+
366
+ const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
367
+ float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
368
+ if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
369
+ KQ_crs = KQ_cms*meta_j[1];
370
+ }
371
+ #pragma unroll
372
+ for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
373
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
374
+ }
375
+
376
+ // Write back combined meta data:
377
+ if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
378
+ meta_j[0] = KQ_cmn; // Combined max. KQ values.
379
+ meta_j[1] = KQ_crs; // Combined KQ rowsums.
380
+ meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
381
+ }
382
+ if (needs_fixup && threadIdx.x < mma_B::J) {
383
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
384
+ dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
385
+ }
386
+ if (is_fixup && threadIdx.x < mma_B::J) {
387
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
388
+ dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
389
+ }
390
+ }
391
+
392
+ if (np > 1) {
393
+ __syncthreads();
394
+ }
395
+
396
+ if (np == 1 || threadIdx.y % np == 0) {
397
+ // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
398
+ // The values after that are for the partial results of the individual blocks.
399
+ float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
400
+
401
+ #pragma unroll
402
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
403
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
404
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
405
+ const int stride_j = WARP_SIZE / stride_k;
406
+
407
+ if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
408
+ break;
409
+ }
410
+
411
+ #pragma unroll
412
+ for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
413
+ const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
414
+ const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
415
+
416
+ if (!is_fixup && jt*ncols + j_dst >= ne01) {
417
+ continue;
418
+ }
419
+ const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
420
+ #pragma unroll
421
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
422
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
423
+
424
+ float2 dstk_val = make_float2(0.0f, 0.0f);
425
+ #pragma unroll
426
+ for (int ip = 0; ip < np; ++ip) {
427
+ const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
428
+ const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
429
+ dstk_val.x += dstk_val_add.x*KQ_crs;
430
+ dstk_val.y += dstk_val_add.y*KQ_crs;
431
+ }
432
+
433
+ if (!needs_fixup && !is_fixup) {
434
+ const float KQ_rowsum_j = meta_j[1];
435
+ dstk_val.x /= KQ_rowsum_j;
436
+ dstk_val.y /= KQ_rowsum_j;
437
+ }
438
+
439
+ if (is_fixup) {
440
+ dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
441
+ } else {
442
+ dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
443
+ }
444
+ }
445
+ }
446
+ }
447
+ }
448
+
449
+ if (np > 1) {
450
+ __syncthreads();
451
+ }
452
+ #else
453
+ NO_DEVICE_CODE;
454
+ #endif // NEW_MMA_AVAILABLE
455
+ }
456
+
457
+ template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
458
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
459
+ __launch_bounds__(nwarps*WARP_SIZE, 2)
460
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
461
+ static __global__ void flash_attn_ext_f16(
462
+ const char * __restrict__ Q,
463
+ const char * __restrict__ K,
464
+ const char * __restrict__ V,
465
+ const char * __restrict__ mask,
466
+ float * __restrict__ dst,
467
+ float2 * __restrict__ dst_meta,
468
+ const float scale,
469
+ const float max_bias,
470
+ const float m0,
471
+ const float m1,
472
+ const uint32_t n_head_log2,
473
+ const float logit_softcap,
474
+ const int ne00,
475
+ const int ne01,
476
+ const int ne02,
477
+ const int ne03,
478
+ const int ne10,
479
+ const int ne11,
480
+ const int ne12,
481
+ const int ne13,
482
+ const int ne31,
483
+ const int nb31,
484
+ const int nb01,
485
+ const int nb02,
486
+ const int nb03,
487
+ const int nb11,
488
+ const int nb12,
489
+ const int nb13,
490
+ const int nb21,
491
+ const int nb22,
492
+ const int nb23,
493
+ const int ne0,
494
+ const int ne1,
495
+ const int ne2,
496
+ const int ne3) {
497
+ // Skip unused kernel variants for faster compilation:
498
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
499
+ NO_DEVICE_CODE;
500
+ return;
501
+ }
502
+
503
+ static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
504
+
505
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
506
+
507
+ const int iter_k = ne11 / KQ_stride;
508
+ const int iter_j = (ne01 + (ncols - 1)) / ncols;
509
+
510
+ // kbc == k block continuous, current index in continuous ijk space.
511
+ int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
512
+ const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
513
+
514
+ // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
515
+ // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
516
+ // In the most general case >2 seams can fall into the same tile.
517
+
518
+ // kb0 == k start index when in the output tile.
519
+ int kb0_start = kbc % iter_k;
520
+ int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
521
+ while (kbc < kbc_stop && kb0_stop == iter_k) {
522
+ const int channel = kbc / (iter_k*iter_j);
523
+ const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
524
+
525
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
526
+ const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
527
+ const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
528
+ const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
529
+ float2 * dstk = ((float2 *) dst) + channel*(D/2);
530
+
531
+ const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
532
+
533
+ constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
534
+ if (kb0_start == 0) {
535
+ constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
536
+ flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
537
+ (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
538
+ ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
539
+ jt, kb0_start, kb0_stop);
540
+ } else {
541
+ constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
542
+ flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
543
+ (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
544
+ ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
545
+ jt, kb0_start, kb0_stop);
546
+ }
547
+
548
+ kbc += iter_k;
549
+ kbc -= kbc % iter_k;
550
+
551
+ kb0_start = 0;
552
+ kb0_stop = min(iter_k, kbc_stop - kbc);
553
+ }
554
+
555
+ if (kbc >= kbc_stop) {
556
+ return;
557
+ }
558
+
559
+ const int channel = kbc / (iter_k*iter_j);
560
+ const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
561
+
562
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
563
+ const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
564
+ const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
565
+ const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
566
+ float2 * dstk = ((float2 *) dst) + channel*(D/2);
567
+
568
+ const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
569
+
570
+ constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
571
+ constexpr bool needs_fixup = false;
572
+ flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
573
+ (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
574
+ ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
575
+ jt, kb0_start, kb0_stop);
576
+ }
577
+
578
+ template <int D, int cols_per_block>
579
+ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
580
+ typedef mma_A_I16K8<half2> mma_A;
581
+ typedef mma_B_J8K8<half2> mma_B;
582
+
583
+ static_assert(D % mma_B::K == 0, "bad D");
584
+ static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
585
+
586
+ const ggml_tensor * KQV = dst;
587
+
588
+ constexpr int KQ_stride = D <= 128 ? 64 : 32;
589
+ constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
590
+ cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
591
+ constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
592
+
593
+ float logit_softcap;
594
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
595
+
596
+ fattn_kernel_t fattn_kernel;
597
+ if (logit_softcap == 0.0f) {
598
+ constexpr bool use_logit_softcap = false;
599
+ fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
600
+ } else {
601
+ constexpr bool use_logit_softcap = true;
602
+ fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
603
+ }
604
+ launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
605
+ }
606
+
607
+ #define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \
608
+ template void ggml_cuda_flash_attn_ext_mma_f16_case \
609
+ <D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
610
+
611
+ extern DECL_FATTN_MMA_F16_CASE( 64, 8);
612
+ extern DECL_FATTN_MMA_F16_CASE( 80, 8);
613
+ extern DECL_FATTN_MMA_F16_CASE( 96, 8);
614
+ extern DECL_FATTN_MMA_F16_CASE(112, 8);
615
+ extern DECL_FATTN_MMA_F16_CASE(128, 8);
616
+ extern DECL_FATTN_MMA_F16_CASE(256, 8);
617
+
618
+ extern DECL_FATTN_MMA_F16_CASE( 64, 16);
619
+ extern DECL_FATTN_MMA_F16_CASE( 80, 16);
620
+ extern DECL_FATTN_MMA_F16_CASE( 96, 16);
621
+ extern DECL_FATTN_MMA_F16_CASE(112, 16);
622
+ extern DECL_FATTN_MMA_F16_CASE(128, 16);
623
+ extern DECL_FATTN_MMA_F16_CASE(256, 16);
624
+
625
+ extern DECL_FATTN_MMA_F16_CASE( 64, 32);
626
+ extern DECL_FATTN_MMA_F16_CASE( 80, 32);
627
+ extern DECL_FATTN_MMA_F16_CASE( 96, 32);
628
+ extern DECL_FATTN_MMA_F16_CASE(112, 32);
629
+ extern DECL_FATTN_MMA_F16_CASE(128, 32);
630
+ extern DECL_FATTN_MMA_F16_CASE(256, 32);
631
+
632
+ extern DECL_FATTN_MMA_F16_CASE( 64, 64);
633
+ extern DECL_FATTN_MMA_F16_CASE( 80, 64);
634
+ extern DECL_FATTN_MMA_F16_CASE( 96, 64);
635
+ extern DECL_FATTN_MMA_F16_CASE(112, 64);
636
+ extern DECL_FATTN_MMA_F16_CASE(128, 64);
637
+ extern DECL_FATTN_MMA_F16_CASE(256, 64);
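
To make the scheduling in `flash_attn_ext_f16` above easier to follow: each CUDA block is handed a contiguous slice `[kbc, kbc_stop)` of the flattened (head, column-tile, K-chunk) iteration space, and only slices that start or end inside a K iteration take the needs_fixup/is_fixup paths. A small host-side sketch with assumed example sizes (not values from the commit):

```cpp
#include <cstdio>

int main() {
    // Assumed example sizes: 8192 KV tokens, KQ_stride = 64, 100 Q columns,
    // ncols = 64, 32 heads, and 2 blocks per SM on an 80-SM GPU.
    const int iter_k  = 8192 / 64;
    const int iter_j  = (100 + 64 - 1) / 64;
    const int ne02    = 32;
    const int nblocks = 2 * 80;

    for (int b = 0; b < 4; ++b) {  // print the first few slices
        const int kbc      = (b + 0) * iter_k*iter_j*ne02 / nblocks;
        const int kbc_stop = (b + 1) * iter_k*iter_j*ne02 / nblocks;
        // A slice starting mid K iteration (kbc % iter_k != 0) sets needs_fixup for
        // its first tile; a slice ending mid iteration writes that tile via is_fixup.
        printf("block %3d: kbc = [%5d, %5d), starts at K chunk %d of tile %d\n",
               b, kbc, kbc_stop, kbc % iter_k, kbc / iter_k);
    }
    return 0;
}
```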
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16(
45
  const int ne2,
46
  const int ne3) {
47
  #ifdef FP16_AVAILABLE
48
  // Skip unused kernel variants for faster compilation:
49
  if (use_logit_softcap && !(D == 128 || D == 256)) {
50
  NO_DEVICE_CODE;
51
  return;
@@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
288
  const ggml_tensor * Q = dst->src[0];
289
  switch (Q->ne[0]) {
290
  case 64: {
291
- constexpr int D = 64;
292
- constexpr int nwarps = 8;
 
293
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
294
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
295
  } break;
296
  case 128: {
297
- constexpr int D = 128;
298
- constexpr int nwarps = 8;
 
299
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
300
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
301
  } break;
302
  default: {
303
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
45
  const int ne2,
46
  const int ne3) {
47
  #ifdef FP16_AVAILABLE
48
+
49
+ #ifndef FLASH_ATTN_AVAILABLE
50
+ NO_DEVICE_CODE;
51
+ return;
52
+ #endif // FLASH_ATTN_AVAILABLE
53
+
54
  // Skip unused kernel variants for faster compilation:
55
+ #ifdef FP16_MMA_AVAILABLE
56
+ NO_DEVICE_CODE;
57
+ return;
58
+ #endif // FP16_MMA_AVAILABLE
59
  if (use_logit_softcap && !(D == 128 || D == 256)) {
60
  NO_DEVICE_CODE;
61
  return;
 
298
  const ggml_tensor * Q = dst->src[0];
299
  switch (Q->ne[0]) {
300
  case 64: {
301
+ constexpr int D = 64;
302
+ constexpr int nwarps = 8;
303
+ constexpr size_t nbytes_shared = 0;
304
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
305
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
306
  } break;
307
  case 128: {
308
+ constexpr int D = 128;
309
+ constexpr int nwarps = 8;
310
+ constexpr size_t nbytes_shared = 0;
311
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
312
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
313
  } break;
314
  default: {
315
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32(
48
  NO_DEVICE_CODE;
49
  return;
50
  #endif // FLASH_ATTN_AVAILABLE
 
51
  // Skip unused kernel variants for faster compilation:
52
  if (use_logit_softcap && !(D == 128 || D == 256)) {
53
  NO_DEVICE_CODE;
54
  return;
@@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
287
  const ggml_tensor * Q = dst->src[0];
288
  switch (Q->ne[0]) {
289
  case 64: {
290
- constexpr int D = 64;
291
- constexpr int nwarps = 8;
 
292
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
293
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
294
  } break;
295
  case 128: {
296
- constexpr int D = 128;
297
- constexpr int nwarps = 8;
 
298
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
299
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
300
  } break;
301
  default: {
302
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
48
  NO_DEVICE_CODE;
49
  return;
50
  #endif // FLASH_ATTN_AVAILABLE
51
+
52
  // Skip unused kernel variants for faster compilation:
53
+ #ifdef FP16_MMA_AVAILABLE
54
+ NO_DEVICE_CODE;
55
+ return;
56
+ #endif // FP16_MMA_AVAILABLE
57
  if (use_logit_softcap && !(D == 128 || D == 256)) {
58
  NO_DEVICE_CODE;
59
  return;
 
292
  const ggml_tensor * Q = dst->src[0];
293
  switch (Q->ne[0]) {
294
  case 64: {
295
+ constexpr int D = 64;
296
+ constexpr int nwarps = 8;
297
+ constexpr size_t nbytes_shared = 0;
298
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
299
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
300
  } break;
301
  case 128: {
302
+ constexpr int D = 128;
303
+ constexpr int nwarps = 8;
304
+ constexpr size_t nbytes_shared = 0;
305
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
306
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
307
  } break;
308
  default: {
309
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
42
  const int ne2,
43
  const int ne3) {
44
  #ifdef FP16_AVAILABLE
45
  // Skip unused kernel variants for faster compilation:
46
  if (use_logit_softcap && !(D == 128 || D == 256)) {
47
  NO_DEVICE_CODE;
@@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
303
  fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
304
  constexpr bool need_f16_K = D != 128;
305
  constexpr bool need_f16_V = D != 128 && D != 64;
306
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 
307
  }
308
 
309
  template <int D, ggml_type type_K, ggml_type type_V>
 
42
  const int ne2,
43
  const int ne3) {
44
  #ifdef FP16_AVAILABLE
45
+
46
+ #ifndef FLASH_ATTN_AVAILABLE
47
+ NO_DEVICE_CODE;
48
+ return;
49
+ #endif // FLASH_ATTN_AVAILABLE
50
+
51
  // Skip unused kernel variants for faster compilation:
52
  if (use_logit_softcap && !(D == 128 || D == 256)) {
53
  NO_DEVICE_CODE;
 
309
  fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
310
  constexpr bool need_f16_K = D != 128;
311
  constexpr bool need_f16_V = D != 128 && D != 64;
312
+ constexpr size_t nbytes_shared = 0;
313
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
314
  }
315
 
316
  template <int D, ggml_type type_K, ggml_type type_V>
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32(
41
  const int ne1,
42
  const int ne2,
43
  const int ne3) {
44
  // Skip unused kernel variants for faster compilation:
45
  if (use_logit_softcap && !(D == 128 || D == 256)) {
46
  NO_DEVICE_CODE;
@@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
284
  fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
285
  constexpr bool need_f16_K = D != 128;
286
  constexpr bool need_f16_V = D != 128 && D != 64;
287
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 
288
  }
289
 
290
  template <int D, ggml_type type_K, ggml_type type_V>
 
41
  const int ne1,
42
  const int ne2,
43
  const int ne3) {
44
+ #ifndef FLASH_ATTN_AVAILABLE
45
+ NO_DEVICE_CODE;
46
+ return;
47
+ #endif // FLASH_ATTN_AVAILABLE
48
+
49
  // Skip unused kernel variants for faster compilation:
50
  if (use_logit_softcap && !(D == 128 || D == 256)) {
51
  NO_DEVICE_CODE;
 
289
  fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
290
  constexpr bool need_f16_K = D != 128;
291
  constexpr bool need_f16_V = D != 128 && D != 64;
292
+ constexpr size_t nbytes_shared = 0;
293
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
294
  }
295
 
296
  template <int D, ggml_type type_K, ggml_type type_V>
ggml/src/ggml-cuda/fattn-wmma-f16.cu ADDED
@@ -0,0 +1,648 @@
1
+ // Old and deprecated WMMA FlashAttention implementation.
2
+ // It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
3
+ // Long-term the WMMA code should be replaced with a dedicated Volta implementation.
4
+
5
+ #include "common.cuh"
6
+ #include "fattn-common.cuh"
7
+ #include "fattn-wmma-f16.cuh"
8
+
9
+ #ifdef FP16_MMA_AVAILABLE
10
+ #include <mma.h>
11
+ #endif // FP16_MMA_AVAILABLE
12
+
13
+ // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
14
+ template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
15
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
16
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
17
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
18
+ static __global__ void flash_attn_ext_f16(
19
+ const char * __restrict__ Q,
20
+ const char * __restrict__ K,
21
+ const char * __restrict__ V,
22
+ const char * __restrict__ mask,
23
+ float * __restrict__ dst,
24
+ float2 * __restrict__ dst_meta,
25
+ const float scale,
26
+ const float max_bias,
27
+ const float m0,
28
+ const float m1,
29
+ const uint32_t n_head_log2,
30
+ const float logit_softcap,
31
+ const int ne00,
32
+ const int ne01,
33
+ const int ne02,
34
+ const int ne03,
35
+ const int ne10,
36
+ const int ne11,
37
+ const int ne12,
38
+ const int ne13,
39
+ const int ne31,
40
+ const int nb31,
41
+ const int nb01,
42
+ const int nb02,
43
+ const int nb03,
44
+ const int nb11,
45
+ const int nb12,
46
+ const int nb13,
47
+ const int nb21,
48
+ const int nb22,
49
+ const int nb23,
50
+ const int ne0,
51
+ const int ne1,
52
+ const int ne2,
53
+ const int ne3) {
54
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
55
+ // Skip unused kernel variants for faster compilation:
56
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
57
+ NO_DEVICE_CODE;
58
+ return;
59
+ }
60
+
61
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
62
+
63
+ const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
64
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
65
+
66
+ static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
67
+ static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
68
+ constexpr int frag_m = ncols == 8 ? 32 : 16;
69
+ constexpr int frag_n = ncols == 8 ? 8 : 16;
70
+ static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
71
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
72
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
73
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
74
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
75
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
76
+
77
+ constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
78
+ constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
79
+ static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
80
+
81
+ // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
82
+ constexpr int D_padded = D + 8;
83
+ constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
84
+ constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
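+ // kqar == 2 when KQ is accumulated as float, 1 when it is accumulated as half.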
85
+
86
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
87
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
88
+ const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
89
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
90
+ const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
91
+ const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
92
+
93
+ const int stride_Q = nb01 / sizeof(float);
94
+ const int stride_KV = nb11 / sizeof(half);
95
+
96
+ const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
97
+ const half slopeh = __float2half(slopef);
98
+ const half2 slope2 = make_half2(slopef, slopef);
99
+
100
+ const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
101
+
102
+ frag_b Q_b[D/16][ncols/frag_n];
103
+
104
+ // A single buffer for temporarily holding tiles of KQ and VKQ parts:
105
+ constexpr int mem_KQ = ncols*kqs_padded*kqar;
106
+ constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
107
+ __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
108
+ float * KQ_f = (float *) KQ;
109
+ half2 * KQ2 = (half2 *) KQ;
110
+
111
+ float KQ_rowsum_f[ncols/nwarps] = {0.0f};
112
+ float KQ_max_f[ncols/nwarps];
113
+ float KQ_max_scale_f[ncols/nwarps] = {0.0f};
114
+
115
+ #pragma unroll
116
+ for (int j = 0; j < ncols/nwarps; ++j) {
117
+ KQ_max_f[j] = -FLT_MAX/2.0f;
118
+ }
119
+
120
+ half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
121
+ half2 KQ_max_h2[ncols/nwarps];
122
+ half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
123
+
124
+ #pragma unroll
125
+ for (int j = 0; j < ncols/nwarps; ++j) {
126
+ KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
127
+ }
128
+
129
+ __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
130
+ half2 * VKQ2 = (half2 *) VKQ;
131
+ #pragma unroll
132
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
133
+ const int j = j0 + threadIdx.y;
134
+ #pragma unroll
135
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
136
+ const int i = i0 + threadIdx.x;
137
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
138
+ break;
139
+ }
140
+ VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
141
+ }
142
+ }
143
+
144
+ // Convert Q to half and apply scale, temporarily store in KQ:
145
+ #pragma unroll
146
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
147
+ const int j = j0 + threadIdx.y;
148
+ #pragma unroll
149
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
150
+ const int i = i0 + threadIdx.x;
151
+ if (i0 + WARP_SIZE > D && i >= D) {
152
+ break;
153
+ }
154
+ KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
155
+ }
156
+ }
157
+
158
+ __syncthreads();
159
+
160
+ // Load Q into tensor core fragments/registers since it will be used frequently:
161
+ #pragma unroll
162
+ for (int i0 = 0; i0 < D; i0 += 16) {
163
+ #pragma unroll
164
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
165
+ nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
166
+ }
167
+ }
168
+
169
+ __syncthreads();
170
+
171
+ // Iterate over ne11 == previous tokens:
172
+ for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
173
+ // Calculate tile of KQ:
174
+ #pragma unroll
175
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
176
+ frag_c_KQ KQ_c[ncols/frag_n];
177
+ #pragma unroll
178
+ for (int j = 0; j < ncols/frag_n; ++j) {
179
+ nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
180
+ }
181
+ #pragma unroll
182
+ for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
183
+ frag_a_K K_a;
184
+ nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
185
+ #pragma unroll
186
+ for (int j = 0; j < ncols/frag_n; ++j) {
187
+ nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
188
+ }
189
+ }
190
+ #pragma unroll
191
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
192
+ nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
193
+ }
194
+ }
195
+
196
+ __syncthreads();
197
+
198
+ // Calculate softmax for each KQ column using the current max. value.
199
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
200
+ #pragma unroll
201
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
202
+ const int j = j0 + threadIdx.y;
203
+
204
+ if (std::is_same<KQ_acc_t, float>::value) {
205
+ float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
206
+ #pragma unroll
207
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
208
+ const int k = k0 + threadIdx.x;
209
+
210
+ KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
211
+
212
+ if (use_logit_softcap) {
213
+ KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
214
+ }
215
+ }
216
+
217
+ float KQ_max_new = KQ_max_f[j0/nwarps];
218
+ #pragma unroll
219
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
220
+ const int k = k0 + threadIdx.x;
221
+
222
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
223
+ KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
224
+ }
225
+ KQ_max_new = warp_reduce_max(KQ_max_new);
226
+
227
+ const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
228
+ KQ_max_scale_f[j0/nwarps] = expf(diff);
229
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
230
+ KQ_max_scale_f[j0/nwarps] = 0.0f;
231
+ }
232
+ KQ_max_f[j0/nwarps] = KQ_max_new;
233
+
234
+ float KQ_rowsum_add = 0.0f;
235
+ #pragma unroll
236
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
237
+ const int k = k0 + threadIdx.x;
238
+
239
+ const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
240
+ KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
241
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
242
+ KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
243
+ }
244
+ KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
245
+ KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
246
+ }
247
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
248
+
249
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
250
+ KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
251
+ } else {
252
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
253
+ #pragma unroll
254
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
255
+ const int k = k0 + threadIdx.x;
256
+
257
+ KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
258
+
259
+ if (use_logit_softcap) {
260
+ // There is no dedicated hyperbolic tangent (tanh) function for half2.
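+ // tanh(x) is instead computed as (exp(2*x) - 1) / (exp(2*x) + 1) using h2exp below.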
261
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
262
+ KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
263
+ /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
264
+
265
+ KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
266
+ }
267
+ }
268
+
269
+ half2 KQ_max_new = KQ_max_h2[j0/nwarps];
270
+ #pragma unroll
271
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
272
+ const int k = k0 + threadIdx.x;
273
+
274
+ KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
275
+ KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
276
+ }
277
+ KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
278
+ const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
279
+ KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
280
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
281
+ *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
282
+ KQ_max_h2[j0/nwarps] = KQ_max_new;
283
+
284
+ half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
285
+ #pragma unroll
286
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
287
+ const int k = k0 + threadIdx.x;
288
+
289
+ const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
290
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
291
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
292
+ *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
293
+ KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
294
+ KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
295
+ }
296
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
297
+
298
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
299
+ KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
300
+ }
301
+ }
302
+
303
+ __syncthreads();
304
+
305
+ frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
306
+ #pragma unroll
307
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
308
+ #pragma unroll
309
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
310
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
311
+ nvcuda::wmma::load_matrix_sync(
312
+ KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
313
+ KQ + j0*(kqar*kqs_padded) + k,
314
+ kqar*kqs_padded);
315
+ }
316
+ }
317
+
318
+ frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
319
+ #pragma unroll
320
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
321
+ #pragma unroll
322
+ for (int j = 0; j < ncols/frag_n; ++j) {
323
+ nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
324
+ }
325
+
326
+ #pragma unroll
327
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
328
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
329
+
330
+ frag_a_V v_a;
331
+ nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
332
+ #pragma unroll
333
+ for (int j = 0; j < ncols/frag_n; ++j) {
334
+ nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
335
+ }
336
+ }
337
+ }
338
+
339
+ __syncthreads();
340
+
341
+ const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
342
+ #pragma unroll
343
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
344
+ #pragma unroll
345
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
346
+ nvcuda::wmma::store_matrix_sync(
347
+ KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
348
+ VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
349
+ D_padded, nvcuda::wmma::mem_col_major);
350
+ }
351
+ }
352
+
353
+ __syncthreads();
354
+
355
+ #pragma unroll
356
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
357
+ const int j = j0 + threadIdx.y;
358
+
359
+ half2 VKQ_scale;
360
+ if (std::is_same<KQ_acc_t, float>::value) {
361
+ VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
362
+ } else {
363
+ VKQ_scale = KQ_max_scale_h2[j0/nwarps];
364
+ }
365
+
366
+ #pragma unroll
367
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
368
+ const int i = i0 + threadIdx.x;
369
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
370
+ break;
371
+ }
372
+
373
+ half2 VKQ_add = make_half2(0.0f, 0.0f);
374
+ #pragma unroll
375
+ for (int l = 0; l < VKQ_ratio; ++l) {
376
+ VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
377
+ }
378
+ VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
379
+ }
380
+ }
381
+
382
+ __syncthreads();
383
+ }
384
+
385
+ #pragma unroll
386
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
387
+ const int j_VKQ = j0 + threadIdx.y;
388
+ if (ic0 + j_VKQ >= ne01) {
389
+ return;
390
+ }
391
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
392
+
393
+ float KQ_rowsum_j;
394
+ if (std::is_same<KQ_acc_t, float>::value) {
395
+ KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
396
+ } else {
397
+ KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
398
+ }
399
+
400
+ #pragma unroll
401
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
402
+ const int i = i0 + threadIdx.x;
403
+ if (i0 + WARP_SIZE > D && i >= D) {
404
+ break;
405
+ }
406
+ float dst_val = VKQ[j_VKQ*D_padded + i];
407
+ if (parallel_blocks == 1) {
408
+ dst_val /= KQ_rowsum_j;
409
+ }
410
+ dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
411
+ }
412
+
413
+ if (parallel_blocks == 1 || threadIdx.x != 0) {
414
+ continue;
415
+ }
416
+
417
+ float2 dst_meta_val;
418
+ if (std::is_same<KQ_acc_t, float>::value) {
419
+ dst_meta_val.x = KQ_max_f[j0/nwarps];
420
+ } else {
421
+ dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
422
+ }
423
+ dst_meta_val.y = KQ_rowsum_j;
424
+ dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
425
+ }
426
+ #else
427
+ NO_DEVICE_CODE;
428
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
429
+ }
430
+
431
+ constexpr int get_max_power_of_2(int x) {
432
+ return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
433
+ }
434
+
435
+ static_assert(get_max_power_of_2(1) == 1, "Test failed.");
436
+ static_assert(get_max_power_of_2(2) == 2, "Test failed.");
437
+ static_assert(get_max_power_of_2(4) == 4, "Test failed.");
438
+ static_assert(get_max_power_of_2(6) == 2, "Test failed.");
439
+
440
+ // Number of VKQ rows calculated in parallel:
441
+ constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
442
+ return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
443
+ }
444
+
445
+ static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
446
+ static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
447
+ static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
448
+ static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
449
+ static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
450
+ static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
451
+ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
452
+ static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
453
+ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
454
+
455
+ template <int D, int cols_per_block, typename KQ_acc_t>
456
+ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
457
+ const ggml_tensor * KQV = dst;
458
+ const ggml_tensor * Q = dst->src[0];
459
+
460
+ constexpr int nwarps = 4;
461
+
462
+ constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
463
+ const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
464
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
465
+
466
+ float logit_softcap;
467
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
468
+
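+ // Heuristic: use more parallel blocks per Q column if there would otherwise be too few CUDA blocks to keep all streaming multiprocessors busy.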
469
+ if (4*blocks_num_pb1 < 2*nsm) {
470
+ constexpr int parallel_blocks = 4;
471
+ fattn_kernel_t fattn_kernel;
472
+ if (logit_softcap == 0.0f) {
473
+ constexpr bool use_logit_softcap = false;
474
+ fattn_kernel = flash_attn_ext_f16<
475
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
476
+ } else {
477
+ constexpr bool use_logit_softcap = true;
478
+ fattn_kernel = flash_attn_ext_f16<
479
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
480
+ }
481
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
482
+ return;
483
+ }
484
+ if (2*blocks_num_pb1 < 2*nsm) {
485
+ constexpr int parallel_blocks = 2;
486
+ fattn_kernel_t fattn_kernel;
487
+ if (logit_softcap == 0.0f) {
488
+ constexpr bool use_logit_softcap = false;
489
+ fattn_kernel = flash_attn_ext_f16<
490
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
491
+ } else {
492
+ constexpr bool use_logit_softcap = true;
493
+ fattn_kernel = flash_attn_ext_f16<
494
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
495
+ }
496
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
497
+ return;
498
+ }
499
+ constexpr int parallel_blocks = 1;
500
+ fattn_kernel_t fattn_kernel;
501
+ if (logit_softcap == 0.0f) {
502
+ constexpr bool use_logit_softcap = false;
503
+ fattn_kernel = flash_attn_ext_f16<
504
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
505
+ } else {
506
+ constexpr bool use_logit_softcap = true;
507
+ fattn_kernel = flash_attn_ext_f16<
508
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
509
+ }
510
+ launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
511
+ }
512
+
513
+ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
514
+ const ggml_tensor * KQV = dst;
515
+ const ggml_tensor * Q = dst->src[0];
516
+
517
+ const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
518
+
519
+ if (prec != GGML_PREC_DEFAULT) {
520
+ if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
521
+ constexpr int cols_per_block = 16;
522
+ switch (Q->ne[0]) {
523
+ case 64:
524
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
525
+ break;
526
+ case 80:
527
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
528
+ break;
529
+ case 96:
530
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
531
+ break;
532
+ case 112:
533
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
534
+ break;
535
+ case 128:
536
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
537
+ break;
538
+ case 256:
539
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
540
+ break;
541
+ default:
542
+ GGML_ABORT("fatal error");
543
+ break;
544
+ }
545
+ } else {
546
+ constexpr int cols_per_block = 32;
547
+ switch (Q->ne[0]) {
548
+ case 64:
549
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
550
+ break;
551
+ case 80:
552
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
553
+ break;
554
+ case 96:
555
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
556
+ break;
557
+ case 112:
558
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
559
+ break;
560
+ case 128:
561
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
562
+ break;
563
+ // case 256:
564
+ // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
565
+ // break;
566
+ default:
567
+ GGML_ABORT("fatal error");
568
+ break;
569
+ }
570
+ }
571
+ return;
572
+ }
573
+
574
+ if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
575
+ constexpr int cols_per_block = 8;
576
+ switch (Q->ne[0]) {
577
+ case 64:
578
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
579
+ break;
580
+ case 96:
581
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
582
+ break;
583
+ case 128:
584
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
585
+ break;
586
+ case 256:
587
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
588
+ break;
589
+ default:
590
+ GGML_ABORT("fatal error");
591
+ break;
592
+ }
593
+ return;
594
+ }
595
+
596
+ if (Q->ne[1] <= 32) {
597
+ constexpr int cols_per_block = 16;
598
+ switch (Q->ne[0]) {
599
+ case 64:
600
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
601
+ break;
602
+ case 80:
603
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
604
+ break;
605
+ case 96:
606
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
607
+ break;
608
+ case 112:
609
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
610
+ break;
611
+ case 128:
612
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
613
+ break;
614
+ case 256:
615
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
616
+ break;
617
+ default:
618
+ GGML_ABORT("fatal error");
619
+ break;
620
+ }
621
+ return;
622
+ }
623
+
624
+ constexpr int cols_per_block = 32;
625
+ switch (Q->ne[0]) {
626
+ case 64:
627
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
628
+ break;
629
+ case 80:
630
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
631
+ break;
632
+ case 96:
633
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
634
+ break;
635
+ case 112:
636
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
637
+ break;
638
+ case 128:
639
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
640
+ break;
641
+ case 256:
642
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
643
+ break;
644
+ default:
645
+ GGML_ABORT("fatal error");
646
+ break;
647
+ }
648
+ }
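For orientation, the kernel above keeps three running quantities per Q column: the row maximum (KQ_max_*), a rescale factor for previously accumulated values (KQ_max_scale_*), and the softmax denominator (KQ_rowsum_*), which is only applied at the very end. The scalar sketch below is not part of the commit; it shows the same online-softmax recurrence in plain C++ for one row processed tile by tile.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Softmax denominator of one row, processed in tiles of tile_size elements,
    // rescaling the partial sum whenever a later tile raises the running maximum.
    float online_softmax_denominator(const std::vector<float> & scores, std::size_t tile_size) {
        float row_max = -1e30f; // running maximum, like KQ_max_f
        float row_sum = 0.0f;   // running denominator, like KQ_rowsum_f
        for (std::size_t start = 0; start < scores.size(); start += tile_size) {
            const std::size_t end = std::min(scores.size(), start + tile_size);
            float tile_max = row_max;
            for (std::size_t k = start; k < end; ++k) {
                tile_max = std::max(tile_max, scores[k]);
            }
            const float scale = std::exp(row_max - tile_max); // like KQ_max_scale_f
            float tile_sum = 0.0f;
            for (std::size_t k = start; k < end; ++k) {
                tile_sum += std::exp(scores[k] - tile_max);
            }
            row_sum = scale*row_sum + tile_sum; // same recurrence as KQ_rowsum_f above
            row_max = tile_max;
        }
        return row_sum; // the kernel divides the VKQ accumulator by this value at the end
    }

The device code additionally flushes the rescale factor and the exponentials to zero once the exponent drops below SOFTMAX_FTZ_THRESHOLD, and applies the same rescale factor to the VKQ accumulator.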
ggml/src/ggml-cuda/fattn-wmma-f16.cuh CHANGED
@@ -1,543 +1,3 @@
1
  #include "common.cuh"
2
- #include "fattn-common.cuh"
3
 
4
- #ifdef FP16_MMA_AVAILABLE
5
- #include <mma.h>
6
- #endif // FP16_MMA_AVAILABLE
7
-
8
- // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
9
- template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
10
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
- __launch_bounds__(nwarps*WARP_SIZE, 1)
12
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
13
- static __global__ void flash_attn_ext_f16(
14
- const char * __restrict__ Q,
15
- const char * __restrict__ K,
16
- const char * __restrict__ V,
17
- const char * __restrict__ mask,
18
- float * __restrict__ dst,
19
- float2 * __restrict__ dst_meta,
20
- const float scale,
21
- const float max_bias,
22
- const float m0,
23
- const float m1,
24
- const uint32_t n_head_log2,
25
- const float logit_softcap,
26
- const int ne00,
27
- const int ne01,
28
- const int ne02,
29
- const int ne03,
30
- const int ne10,
31
- const int ne11,
32
- const int ne12,
33
- const int ne13,
34
- const int ne31,
35
- const int nb31,
36
- const int nb01,
37
- const int nb02,
38
- const int nb03,
39
- const int nb11,
40
- const int nb12,
41
- const int nb13,
42
- const int nb21,
43
- const int nb22,
44
- const int nb23,
45
- const int ne0,
46
- const int ne1,
47
- const int ne2,
48
- const int ne3) {
49
- #ifdef FP16_MMA_AVAILABLE
50
- // Skip unused kernel variants for faster compilation:
51
- if (use_logit_softcap && !(D == 128 || D == 256)) {
52
- NO_DEVICE_CODE;
53
- return;
54
- }
55
-
56
- //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
57
-
58
- const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
59
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
60
-
61
- static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
62
- static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
63
- constexpr int frag_m = ncols == 8 ? 32 : 16;
64
- constexpr int frag_n = ncols == 8 ? 8 : 16;
65
- static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
66
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
67
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
68
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
69
- typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
70
- typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
71
-
72
- constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
73
- constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
74
- static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
75
-
76
- // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
77
- constexpr int D_padded = D + 8;
78
- constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
79
- constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
80
-
81
- const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
82
- const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
83
- const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
84
- const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
85
- const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
86
- const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
87
-
88
- const int stride_Q = nb01 / sizeof(float);
89
- const int stride_KV = nb11 / sizeof(half);
90
-
91
- const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
92
- const half slopeh = __float2half(slopef);
93
- const half2 slope2 = make_half2(slopef, slopef);
94
-
95
- const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
96
-
97
- frag_b Q_b[D/16][ncols/frag_n];
98
-
99
- // A single buffer for temporarily holding tiles of KQ and VKQ parts:
100
- constexpr int mem_KQ = ncols*kqs_padded*kqar;
101
- constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
102
- __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
103
- float * KQ_f = (float *) KQ;
104
- half2 * KQ2 = (half2 *) KQ;
105
-
106
- float KQ_rowsum_f[ncols/nwarps] = {0.0f};
107
- float KQ_max_f[ncols/nwarps];
108
- float KQ_max_scale_f[ncols/nwarps] = {0.0f};
109
-
110
- #pragma unroll
111
- for (int j = 0; j < ncols/nwarps; ++j) {
112
- KQ_max_f[j] = -FLT_MAX/2.0f;
113
- }
114
-
115
- half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
116
- half2 KQ_max_h2[ncols/nwarps];
117
- half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
118
-
119
- #pragma unroll
120
- for (int j = 0; j < ncols/nwarps; ++j) {
121
- KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
122
- }
123
-
124
- __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
125
- half2 * VKQ2 = (half2 *) VKQ;
126
- #pragma unroll
127
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
128
- const int j = j0 + threadIdx.y;
129
- #pragma unroll
130
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
131
- const int i = i0 + threadIdx.x;
132
- if (i0 + WARP_SIZE > D/2 && i >= D/2) {
133
- break;
134
- }
135
- VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
136
- }
137
- }
138
-
139
- // Convert Q to half and apply scale, temporarily store in KQ:
140
- #pragma unroll
141
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
142
- const int j = j0 + threadIdx.y;
143
- #pragma unroll
144
- for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
145
- const int i = i0 + threadIdx.x;
146
- if (i0 + WARP_SIZE > D && i >= D) {
147
- break;
148
- }
149
- KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
150
- }
151
- }
152
-
153
- __syncthreads();
154
-
155
- // Load Q into tensor core fragments/registers since it will be used frequently:
156
- #pragma unroll
157
- for (int i0 = 0; i0 < D; i0 += 16) {
158
- #pragma unroll
159
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
160
- nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
161
- }
162
- }
163
-
164
- __syncthreads();
165
-
166
- // Iterate over ne11 == previous tokens:
167
- for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
168
- // Calculate tile of KQ:
169
- #pragma unroll
170
- for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
171
- frag_c_KQ KQ_c[ncols/frag_n];
172
- #pragma unroll
173
- for (int j = 0; j < ncols/frag_n; ++j) {
174
- nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
175
- }
176
- #pragma unroll
177
- for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
178
- frag_a_K K_a;
179
- nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
180
- #pragma unroll
181
- for (int j = 0; j < ncols/frag_n; ++j) {
182
- nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
183
- }
184
- }
185
- #pragma unroll
186
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
187
- nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
188
- }
189
- }
190
-
191
- __syncthreads();
192
-
193
- // Calculate softmax for each KQ column using the current max. value.
194
- // The divisor is stored in KQ_rowsum and will be applied at the end.
195
- #pragma unroll
196
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
197
- const int j = j0 + threadIdx.y;
198
-
199
- if (std::is_same<KQ_acc_t, float>::value) {
200
- float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
201
- #pragma unroll
202
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
203
- const int k = k0 + threadIdx.x;
204
-
205
- KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
206
-
207
- if (use_logit_softcap) {
208
- KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
209
- }
210
- }
211
-
212
- float KQ_max_new = KQ_max_f[j0/nwarps];
213
- #pragma unroll
214
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
215
- const int k = k0 + threadIdx.x;
216
-
217
- KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
218
- KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
219
- }
220
- KQ_max_new = warp_reduce_max(KQ_max_new);
221
-
222
- const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
223
- KQ_max_scale_f[j0/nwarps] = expf(diff);
224
- if (diff <= SOFTMAX_FTZ_THRESHOLD) {
225
- KQ_max_scale_f[j0/nwarps] = 0.0f;
226
- }
227
- KQ_max_f[j0/nwarps] = KQ_max_new;
228
-
229
- float KQ_rowsum_add = 0.0f;
230
- #pragma unroll
231
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
232
- const int k = k0 + threadIdx.x;
233
-
234
- const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
235
- KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
236
- if (diff <= SOFTMAX_FTZ_THRESHOLD) {
237
- KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
238
- }
239
- KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
240
- KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
241
- }
242
- KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
243
-
244
- // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
245
- KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
246
- } else {
247
- half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
248
- #pragma unroll
249
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
250
- const int k = k0 + threadIdx.x;
251
-
252
- KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
253
-
254
- if (use_logit_softcap) {
255
- // There is no dedicated tangens hyperbolicus function for half2.
256
- KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
257
- KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
258
- /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
259
-
260
- KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
261
- }
262
- }
263
-
264
- half2 KQ_max_new = KQ_max_h2[j0/nwarps];
265
- #pragma unroll
266
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
267
- const int k = k0 + threadIdx.x;
268
-
269
- KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
270
- KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
271
- }
272
- KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
273
- const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
274
- KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
275
- const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
276
- *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
277
- KQ_max_h2[j0/nwarps] = KQ_max_new;
278
-
279
- half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
280
- #pragma unroll
281
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
282
- const int k = k0 + threadIdx.x;
283
-
284
- const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
285
- KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
286
- const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
287
- *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
288
- KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
289
- KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
290
- }
291
- KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
292
-
293
- // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
294
- KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
295
- }
296
- }
297
-
298
- __syncthreads();
299
-
300
- frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
301
- #pragma unroll
302
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
303
- #pragma unroll
304
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
305
- const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
306
- nvcuda::wmma::load_matrix_sync(
307
- KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
308
- KQ + j0*(kqar*kqs_padded) + k,
309
- kqar*kqs_padded);
310
- }
311
- }
312
-
313
- frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
314
- #pragma unroll
315
- for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
316
- #pragma unroll
317
- for (int j = 0; j < ncols/frag_n; ++j) {
318
- nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
319
- }
320
-
321
- #pragma unroll
322
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
323
- const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
324
-
325
- frag_a_V v_a;
326
- nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
327
- #pragma unroll
328
- for (int j = 0; j < ncols/frag_n; ++j) {
329
- nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
330
- }
331
- }
332
- }
333
-
334
- __syncthreads();
335
-
336
- const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
337
- #pragma unroll
338
- for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
339
- #pragma unroll
340
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
341
- nvcuda::wmma::store_matrix_sync(
342
- KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
343
- VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
344
- D_padded, nvcuda::wmma::mem_col_major);
345
- }
346
- }
347
-
348
- __syncthreads();
349
-
350
- #pragma unroll
351
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
352
- const int j = j0 + threadIdx.y;
353
-
354
- half2 VKQ_scale;
355
- if (std::is_same<KQ_acc_t, float>::value) {
356
- VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
357
- } else {
358
- VKQ_scale = KQ_max_scale_h2[j0/nwarps];
359
- }
360
-
361
- #pragma unroll
362
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
363
- const int i = i0 + threadIdx.x;
364
- if (i0 + WARP_SIZE > D/2 && i >= D/2) {
365
- break;
366
- }
367
-
368
- half2 VKQ_add = make_half2(0.0f, 0.0f);
369
- #pragma unroll
370
- for (int l = 0; l < VKQ_ratio; ++l) {
371
- VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
372
- }
373
- VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
374
- }
375
- }
376
-
377
- __syncthreads();
378
- }
379
-
380
- #pragma unroll
381
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
382
- const int j_VKQ = j0 + threadIdx.y;
383
- if (ic0 + j_VKQ >= ne01) {
384
- return;
385
- }
386
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
387
-
388
- float KQ_rowsum_j;
389
- if (std::is_same<KQ_acc_t, float>::value) {
390
- KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
391
- } else {
392
- KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
393
- }
394
-
395
- #pragma unroll
396
- for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
397
- const int i = i0 + threadIdx.x;
398
- if (i0 + WARP_SIZE > D && i >= D) {
399
- break;
400
- }
401
- float dst_val = VKQ[j_VKQ*D_padded + i];
402
- if (parallel_blocks == 1) {
403
- dst_val /= KQ_rowsum_j;
404
- }
405
- dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
406
- }
407
-
408
- if (parallel_blocks == 1 || threadIdx.x != 0) {
409
- continue;
410
- }
411
-
412
- float2 dst_meta_val;
413
- if (std::is_same<KQ_acc_t, float>::value) {
414
- dst_meta_val.x = KQ_max_f[j0/nwarps];
415
- } else {
416
- dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
417
- }
418
- dst_meta_val.y = KQ_rowsum_j;
419
- dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
420
- }
421
- #else
422
- NO_DEVICE_CODE;
423
- #endif // FP16_MMA_AVAILABLE
424
- }
425
-
426
- constexpr int get_max_power_of_2(int x) {
427
- return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
428
- }
429
-
430
- static_assert(get_max_power_of_2(1) == 1, "Test failed.");
431
- static_assert(get_max_power_of_2(2) == 2, "Test failed.");
432
- static_assert(get_max_power_of_2(4) == 4, "Test failed.");
433
- static_assert(get_max_power_of_2(6) == 2, "Test failed.");
434
-
435
- // Number of VKQ rows calculated in parallel:
436
- constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
437
- return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
438
- }
439
-
440
- static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
441
- static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
442
- static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
443
- static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
444
- static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
445
- static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
446
- static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
447
- static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
448
- static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
449
-
450
- template <int D, int cols_per_block, typename KQ_acc_t>
451
- void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
452
- const ggml_tensor * KQV = dst;
453
- const ggml_tensor * Q = dst->src[0];
454
-
455
- constexpr int nwarps = 4;
456
-
457
- constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
458
- const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
459
- const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
460
-
461
- float logit_softcap;
462
- memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
463
-
464
- if (4*blocks_num_pb1 < 2*nsm) {
465
- constexpr int parallel_blocks = 4;
466
- fattn_kernel_t fattn_kernel;
467
- if (logit_softcap == 0.0f) {
468
- constexpr bool use_logit_softcap = false;
469
- fattn_kernel = flash_attn_ext_f16<
470
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
471
- } else {
472
- constexpr bool use_logit_softcap = true;
473
- fattn_kernel = flash_attn_ext_f16<
474
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
475
- }
476
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
477
- return;
478
- }
479
- if (2*blocks_num_pb1 < 2*nsm) {
480
- constexpr int parallel_blocks = 2;
481
- fattn_kernel_t fattn_kernel;
482
- if (logit_softcap == 0.0f) {
483
- constexpr bool use_logit_softcap = false;
484
- fattn_kernel = flash_attn_ext_f16<
485
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
486
- } else {
487
- constexpr bool use_logit_softcap = true;
488
- fattn_kernel = flash_attn_ext_f16<
489
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
490
- }
491
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
492
- return;
493
- }
494
- constexpr int parallel_blocks = 1;
495
- fattn_kernel_t fattn_kernel;
496
- if (logit_softcap == 0.0f) {
497
- constexpr bool use_logit_softcap = false;
498
- fattn_kernel = flash_attn_ext_f16<
499
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
500
- } else {
501
- constexpr bool use_logit_softcap = true;
502
- fattn_kernel = flash_attn_ext_f16<
503
- D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
504
- }
505
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
506
- }
507
-
508
- #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
509
- template void ggml_cuda_flash_attn_ext_wmma_f16_case \
510
- <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
511
-
512
- extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
513
- extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
514
- extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
515
- extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
516
- extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
517
- extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
518
-
519
- extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
520
- extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
521
- extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
522
- extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
523
- extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
524
- // extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
525
-
526
- extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half);
527
- extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half);
528
- extern DECL_FATTN_WMMA_F16_CASE(128, 8, half);
529
- extern DECL_FATTN_WMMA_F16_CASE(256, 8, half);
530
-
531
- extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
532
- extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
533
- extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
534
- extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
535
- extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
536
- extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
537
-
538
- extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
539
- extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
540
- extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
541
- extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
542
- extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
543
- extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
 
1
  #include "common.cuh"
 
2
 
3
+ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -1,5 +1,6 @@
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
 
3
  #include "fattn-tile-f16.cuh"
4
  #include "fattn-tile-f32.cuh"
5
  #include "fattn-vec-f16.cuh"
@@ -7,144 +8,56 @@
7
  #include "fattn-wmma-f16.cuh"
8
  #include "fattn.cuh"
9
 
10
- #include <cstdint>
 
 
11
 
12
- static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
- const ggml_tensor * KQV = dst;
14
- const ggml_tensor * Q = dst->src[0];
15
-
16
- const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
17
-
18
- if (prec != GGML_PREC_DEFAULT) {
19
- if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
20
- constexpr int cols_per_block = 16;
21
- switch (Q->ne[0]) {
22
- case 64:
23
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
24
- break;
25
- case 80:
26
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
27
- break;
28
- case 96:
29
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
30
- break;
31
- case 112:
32
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
33
- break;
34
- case 128:
35
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
36
- break;
37
- case 256:
38
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
39
- break;
40
- default:
41
- GGML_ABORT("fatal error");
42
- break;
43
- }
44
- } else {
45
- constexpr int cols_per_block = 32;
46
- switch (Q->ne[0]) {
47
- case 64:
48
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
49
- break;
50
- case 80:
51
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
52
- break;
53
- case 96:
54
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
55
- break;
56
- case 112:
57
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
58
- break;
59
- case 128:
60
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
61
- break;
62
- // case 256:
63
- // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
64
- // break;
65
- default:
66
- GGML_ABORT("fatal error");
67
- break;
68
- }
69
- }
70
- return;
71
- }
72
-
73
- if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
74
- constexpr int cols_per_block = 8;
75
- switch (Q->ne[0]) {
76
- case 64:
77
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
78
- break;
79
- case 96:
80
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
81
- break;
82
- case 128:
83
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
84
- break;
85
- case 256:
86
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
87
- break;
88
- default:
89
- GGML_ABORT("fatal error");
90
- break;
91
- }
92
- return;
93
- }
94
-
95
- if (Q->ne[1] <= 32) {
96
- constexpr int cols_per_block = 16;
97
- switch (Q->ne[0]) {
98
- case 64:
99
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
100
- break;
101
- case 80:
102
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
103
- break;
104
- case 96:
105
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
106
- break;
107
- case 112:
108
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
109
- break;
110
- case 128:
111
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
112
- break;
113
- case 256:
114
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
115
- break;
116
- default:
117
- GGML_ABORT("fatal error");
118
- break;
119
- }
120
- return;
121
- }
122
-
123
- constexpr int cols_per_block = 32;
124
  switch (Q->ne[0]) {
125
  case 64:
126
- ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
127
  break;
128
  case 80:
129
- ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
130
  break;
131
  case 96:
132
- ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
133
  break;
134
  case 112:
135
- ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
136
  break;
137
  case 128:
138
- ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
139
  break;
140
  case 256:
141
- ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
142
  break;
143
  default:
144
  GGML_ABORT("fatal error");
145
  break;
146
  }
147
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
149
  if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
150
  ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
@@ -322,11 +235,19 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
322
  return;
323
  }
324
 
325
- if (!fp16_mma_available(cc)) {
326
- if (Q->ne[1] <= 8) {
327
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
 
 
 
 
328
  } else {
329
- ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
 
 
 
 
330
  }
331
  return;
332
  }
@@ -341,5 +262,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
341
  }
342
  }
343
 
344
- ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
 
 
 
 
 
345
  }
 
1
  #include "common.cuh"
2
  #include "fattn-common.cuh"
3
+ #include "fattn-mma-f16.cuh"
4
  #include "fattn-tile-f16.cuh"
5
  #include "fattn-tile-f32.cuh"
6
  #include "fattn-vec-f16.cuh"
 
8
  #include "fattn-wmma-f16.cuh"
9
  #include "fattn.cuh"
10
 
11
+ template <int cols_per_block>
12
+ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
+ const ggml_tensor * Q = dst->src[0];
14
 
15
  switch (Q->ne[0]) {
16
  case 64:
17
+ ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
18
  break;
19
  case 80:
20
+ ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
21
  break;
22
  case 96:
23
+ ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
24
  break;
25
  case 112:
26
+ ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
27
  break;
28
  case 128:
29
+ ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
30
  break;
31
  case 256:
32
+ ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
33
  break;
34
  default:
35
  GGML_ABORT("fatal error");
36
  break;
37
  }
38
  }
39
+
40
+ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
41
+ const ggml_tensor * Q = dst->src[0];
42
+
43
+ if (Q->ne[1] <= 8) {
44
+ ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
45
+ return;
46
+ }
47
+
48
+ if (Q->ne[1] <= 16) {
49
+ ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
50
+ return;
51
+ }
52
+
53
+ if (Q->ne[1] <= 32) {
54
+ ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
55
+ return;
56
+ }
57
+
58
+ ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
59
+ }
60
+
61
  #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
62
  if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
63
  ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
 
235
  return;
236
  }
237
 
238
+ if (!new_mma_available(cc)) {
239
+ if (prec == GGML_PREC_DEFAULT) {
240
+ if (Q->ne[1] <= 8) {
241
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
242
+ } else {
243
+ ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
244
+ }
245
  } else {
246
+ if (Q->ne[1] <= 8) {
247
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
248
+ } else {
249
+ ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
250
+ }
251
  }
252
  return;
253
  }
 
262
  }
263
  }
264
 
265
+ // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
266
+ if (cc == GGML_CUDA_CC_VOLTA) {
267
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+ return;
268
+ }
269
+
270
+ ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
271
  }
ggml/src/ggml-cuda/mma.cuh CHANGED
@@ -1,11 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #include "common.cuh"
2
 
3
- struct mma_int_A_I16K4 {
4
  static constexpr int I = 16;
5
  static constexpr int K = 4;
6
  static constexpr int ne = 2;
7
 
8
- int x[ne] = {0};
9
 
10
  static __device__ __forceinline__ int get_i(const int l) {
11
  const int ret = (l%2) * (I/2) + threadIdx.x / K;
@@ -21,27 +77,35 @@ struct mma_int_A_I16K4 {
21
  return ret;
22
  }
23
 
24
- __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
25
- #if defined(INT8_MMA_AVAILABLE)
26
- const int * xs = xs0 + (threadIdx.x%I)*stride;
27
- asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
28
- : "+r"(x[0]), "+r"(x[1])
29
- : "l"(xs));
30
- #else
31
  #pragma unroll
32
  for (int l = 0; l < ne; ++l) {
33
  x[l] = xs0[get_i(l)*stride + get_k(l)];
34
  }
35
- #endif // defined(INT8_MMA_AVAILABLE)
 
 
 
 
 
 
 
 
 
 
 
36
  }
37
  };
38
 
39
- struct mma_int_A_I16K8 {
 
 
 
40
  static constexpr int I = 16;
41
  static constexpr int K = 8;
42
  static constexpr int ne = 4;
43
 
44
- int x[ne] = {0};
45
 
46
  static __device__ __forceinline__ int get_i(const int l) {
47
  const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
@@ -57,31 +121,62 @@ struct mma_int_A_I16K8 {
57
  return ret;
58
  }
59
 
60
- __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
61
- #if defined(INT8_MMA_AVAILABLE)
62
- const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
63
- asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
64
- : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
65
- : "l"(xs));
66
- #else
67
  #pragma unroll
68
  for (int l = 0; l < ne; ++l) {
69
  x[l] = xs0[get_i(l)*stride + get_k(l)];
70
  }
71
- #endif // defined(INT8_MMA_AVAILABLE)
72
  }
73
 
74
- __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
75
- ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
  }
77
  };
78
 
79
- struct mma_int_B_J8K4 {
 
 
 
80
  static constexpr int J = 8;
81
  static constexpr int K = 4;
82
  static constexpr int ne = 1;
83
 
84
- int x[ne] = {0};
85
 
86
  static __device__ __forceinline__ int get_j(const int /* l */) {
87
  const int ret = threadIdx.x / K;
@@ -97,27 +192,34 @@ struct mma_int_B_J8K4 {
97
  return ret;
98
  }
99
 
100
- __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
101
- #if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
102
- const int * xs = xs0 + (threadIdx.x%J)*stride;
103
- asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
104
- : "+r"(x[0])
105
- : "l"(xs));
106
- #else
107
  #pragma unroll
108
  for (int l = 0; l < ne; ++l) {
109
  x[l] = xs0[get_j(l)*stride + get_k(l)];
110
  }
111
- #endif // defined(INT8_MMA_AVAILABLE)
 
 
 
 
 
 
 
 
 
 
112
  }
113
  };
114
 
115
- struct mma_int_B_J8K8 {
 
 
 
116
  static constexpr int J = 8;
117
  static constexpr int K = 8;
118
  static constexpr int ne = 2;
119
 
120
- int x[ne] = {0};
121
 
122
  static __device__ __forceinline__ int get_j(const int /* l */) {
123
  const int ret = threadIdx.x / (K/2);
@@ -133,22 +235,31 @@ struct mma_int_B_J8K8 {
133
  return ret;
134
  }
135
 
136
- __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
137
- #if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
138
- const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
139
- asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
140
- : "+r"(x[0]), "+r"(x[1])
141
- : "l"(xs));
142
- #else
143
  #pragma unroll
144
  for (int l = 0; l < ne; ++l) {
145
  x[l] = xs0[get_j(l)*stride + get_k(l)];
146
  }
147
- #endif // defined(INT8_MMA_AVAILABLE)
148
  }
149
  };
150
 
151
- struct mma_int_C_I16J8 {
152
  static constexpr int I = 16;
153
  static constexpr int J = 8;
154
  static constexpr int ne = 4;
@@ -169,8 +280,8 @@ struct mma_int_C_I16J8 {
169
  return ret;
170
  }
171
 
172
- __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
173
- #ifdef INT8_MMA_AVAILABLE
174
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
175
  asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
176
  : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@@ -188,11 +299,11 @@ struct mma_int_C_I16J8 {
188
  GGML_UNUSED(mma_A);
189
  GGML_UNUSED(mma_B);
190
  NO_DEVICE_CODE;
191
- #endif // INT8_MMA_AVAILABLE
192
  }
193
 
194
- __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
195
- #ifdef INT8_MMA_AVAILABLE
196
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
197
  asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
198
  : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@@ -216,6 +327,132 @@ struct mma_int_C_I16J8 {
216
  GGML_UNUSED(mma_A);
217
  GGML_UNUSED(mma_B);
218
  NO_DEVICE_CODE;
219
- #endif // INT8_MMA_AVAILABLE
220
  }
221
  };
 
1
+ // This file contains primitives that expose the tensor core PTX instructions for CUDA code.
2
+ // The primitives can be used similarly to the nvcuda::wmma interface but with a well-defined memory layout.
3
+ // The documentation for the PTX instructions can be found under:
4
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
5
+ //
6
+ // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
7
+ // A is a row-major matrix with shape I x K.
8
+ // B is a column-major matrix with shape K x J.
9
+ // C is a column-major matrix with shape I x J.
10
+ // Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
11
+ // The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
12
+ // All matrix tiles have ne physical 32 bit elements per warp.
13
+ //
14
+ // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
15
+
16
  #include "common.cuh"
17
 
18
+
19
+ #if CUDART_VERSION >= 11800
20
+
21
+ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
22
+ int ret = 0;
23
+
24
+ #ifdef NEW_MMA_AVAILABLE
25
+ asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
26
+ : "+r"(ret) : "r"(x));
27
+ #else
28
+ NO_DEVICE_CODE;
29
+ #endif // defined(NEW_MMA_AVAILABLE)
30
+ return ret;
31
+ }
32
+
33
+ #else
34
+
35
+ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
36
+ // Imagine transposing a row-major matrix to a column-major matrix.
37
+ const int src_i_low = 2 * (threadIdx.x % 4);
38
+ const int src_i_high = src_i_low + 1;
39
+ const int src_j = threadIdx.x / 4;
40
+
41
+ const int src_laneid_low = src_i_low * 4 + src_j / 2;
42
+ const int src_laneid_high = src_i_high * 4 + src_j / 2;
43
+
44
+ const int shift_low = ((src_j + 0) % 2) * 16;
45
+ const int shift_high = ((src_j + 1) % 2) * 16;
46
+
47
+ const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
48
+ const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
49
+
50
+ return ret_low | ret_high;
51
+ }
52
+
53
+ #endif // CUDART_VERSION >= 11800
54
+
55
+
56
+ template <typename T>
57
+ struct mma_A_I16K4 {
58
+ static_assert(sizeof(T) == 4, "bad type size");
59
+
60
  static constexpr int I = 16;
61
  static constexpr int K = 4;
62
  static constexpr int ne = 2;
63
 
64
+ T x[ne];
65
 
66
  static __device__ __forceinline__ int get_i(const int l) {
67
  const int ret = (l%2) * (I/2) + threadIdx.x / K;
 
77
  return ret;
78
  }
79
 
80
+ __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
81
  #pragma unroll
82
  for (int l = 0; l < ne; ++l) {
83
  x[l] = xs0[get_i(l)*stride + get_k(l)];
84
  }
85
+ }
86
+
87
+ __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
88
+ #ifdef NEW_MMA_AVAILABLE
89
+ int * xi = (int *) x;
90
+ const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
91
+ asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
92
+ : "+r"(xi[0]), "+r"(xi[1])
93
+ : "l"(xs));
94
+ #else
95
+ load_generic(xs0, stride);
96
+ #endif // NEW_MMA_AVAILABLE
97
  }
98
  };
99
 
100
+ template <typename T>
101
+ struct mma_A_I16K8 {
102
+ static_assert(sizeof(T) == 4, "bad type size");
103
+
104
  static constexpr int I = 16;
105
  static constexpr int K = 8;
106
  static constexpr int ne = 4;
107
 
108
+ T x[ne];
109
 
110
  static __device__ __forceinline__ int get_i(const int l) {
111
  const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
 
121
  return ret;
122
  }
123
 
124
+ __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
125
  #pragma unroll
126
  for (int l = 0; l < ne; ++l) {
127
  x[l] = xs0[get_i(l)*stride + get_k(l)];
128
  }
 
129
  }
130
 
131
+ __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
132
+ #ifdef NEW_MMA_AVAILABLE
133
+ int * xi = (int * ) x;
134
+ const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
135
+ asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
136
+ : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
137
+ : "l"(xs));
138
+ #else
139
+ GGML_UNUSED(xs0);
140
+ GGML_UNUSED(stride);
141
+ NO_DEVICE_CODE;
142
+ #endif // NEW_MMA_AVAILABLE
143
+ }
144
+
145
+ __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
146
+ #ifdef NEW_MMA_AVAILABLE
147
+ int * xi = (int * ) x;
148
+ const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
149
+ asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
150
+ : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
151
+ : "l"(xs));
152
+ #else
153
+ GGML_UNUSED(xs0);
154
+ GGML_UNUSED(stride);
155
+ NO_DEVICE_CODE;
156
+ #endif // NEW_MMA_AVAILABLE
157
+ }
158
+
159
+ __device__ __forceinline__ void transpose() {
160
+ int * xi = (int *) x;
161
+ xi[0] = ggml_cuda_movmatrix(xi[0]);
162
+
163
+ const int tmp = ggml_cuda_movmatrix(xi[1]);
164
+ xi[1] = ggml_cuda_movmatrix(xi[2]);
165
+ xi[2] = tmp;
166
+
167
+ xi[3] = ggml_cuda_movmatrix(xi[3]);
168
  }
169
  };
170
 
171
+ template <typename T>
172
+ struct mma_B_J8K4 {
173
+ static_assert(sizeof(T) == 4, "bad type size");
174
+
175
  static constexpr int J = 8;
176
  static constexpr int K = 4;
177
  static constexpr int ne = 1;
178
 
179
+ T x[ne];
180
 
181
  static __device__ __forceinline__ int get_j(const int /* l */) {
182
  const int ret = threadIdx.x / K;
 
192
  return ret;
193
  }
194
 
195
+ __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
196
  #pragma unroll
197
  for (int l = 0; l < ne; ++l) {
198
  x[l] = xs0[get_j(l)*stride + get_k(l)];
199
  }
200
+ }
201
+
202
+ __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
203
+ #ifdef NEW_MMA_AVAILABLE
204
+ int * xi = (int *) x;
205
+ const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
206
+ asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
207
+ : "+r"(xi[0]) : "l"(xs));
208
+ #else
209
+ load_generic(xs0, stride);
210
+ #endif // NEW_MMA_AVAILABLE
211
  }
212
  };
213
 
214
+ template <typename T>
215
+ struct mma_B_J8K8 {
216
+ static_assert(sizeof(T) == 4, "bad type size");
217
+
218
  static constexpr int J = 8;
219
  static constexpr int K = 8;
220
  static constexpr int ne = 2;
221
 
222
+ T x[ne];
223
 
224
  static __device__ __forceinline__ int get_j(const int /* l */) {
225
  const int ret = threadIdx.x / (K/2);
 
235
  return ret;
236
  }
237
 
238
+ __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
239
  #pragma unroll
240
  for (int l = 0; l < ne; ++l) {
241
  x[l] = xs0[get_j(l)*stride + get_k(l)];
242
  }
243
+ }
244
+
245
+ __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
246
+ #ifdef NEW_MMA_AVAILABLE
247
+ int * xi = (int *) x;
248
+ const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
249
+ asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
250
+ : "+r"(xi[0]), "+r"(xi[1])
251
+ : "l"(xs));
252
+ #else
253
+ load_generic(xs0, stride);
254
+ #endif // NEW_MMA_AVAILABLE
255
  }
256
  };
257
 
258
+ template <typename T>
259
+ struct mma_C_I16J8 {};
260
+
261
+ template <>
262
+ struct mma_C_I16J8<int> {
263
  static constexpr int I = 16;
264
  static constexpr int J = 8;
265
  static constexpr int ne = 4;
 
280
  return ret;
281
  }
282
 
283
+ __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
284
+ #ifdef NEW_MMA_AVAILABLE
285
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
286
  asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
287
  : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
 
299
  GGML_UNUSED(mma_A);
300
  GGML_UNUSED(mma_B);
301
  NO_DEVICE_CODE;
302
+ #endif // NEW_MMA_AVAILABLE
303
  }
304
 
305
+ __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
306
+ #ifdef NEW_MMA_AVAILABLE
307
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
308
  asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
309
  : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
 
327
  GGML_UNUSED(mma_A);
328
  GGML_UNUSED(mma_B);
329
  NO_DEVICE_CODE;
330
+ #endif // NEW_MMA_AVAILABLE
331
+ }
332
+ };
333
+
334
+ template <>
335
+ struct mma_C_I16J8<half2> {
336
+ static constexpr int I = 16;
337
+ static constexpr int J = 4;
338
+ static constexpr int ne = 2;
339
+
340
+ half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
341
+
342
+ static __device__ __forceinline__ int get_i(const int l) {
343
+ const int ret = l * (I/2) + threadIdx.x / J;
344
+ GGML_CUDA_ASSUME(ret >= 0);
345
+ GGML_CUDA_ASSUME(ret < I);
346
+ return ret;
347
+ }
348
+
349
+ static __device__ __forceinline__ int get_j(const int /* l */) {
350
+ const int ret = threadIdx.x % J;
351
+ GGML_CUDA_ASSUME(ret >= 0);
352
+ GGML_CUDA_ASSUME(ret < J);
353
+ return ret;
354
+ }
355
+
356
+ __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
357
+ #ifdef NEW_MMA_AVAILABLE
358
+ int * Axi = (int *) mma_A.x;
359
+ int * Bxi = (int *) mma_B.x;
360
+ int * xi = (int *) x;
361
+ #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
362
+ asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
363
+ : "+r"(xi[0]), "+r"(xi[1])
364
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
365
+ #else
366
+ // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
367
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
368
+ : "+r"(xi[0]), "+r"(xi[1])
369
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
370
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
371
+ : "+r"(xi[0]), "+r"(xi[1])
372
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
373
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
374
+ #else
375
+ GGML_UNUSED(mma_A);
376
+ GGML_UNUSED(mma_B);
377
+ NO_DEVICE_CODE;
378
+ #endif // NEW_MMA_AVAILABLE
379
+ }
380
+
381
+ __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
382
+ mma_B_J8K8<half2> mma_B;
383
+
384
+ int * xi = (int *) x;
385
+ int * Bxi = (int *) mma_B.x;
386
+ Bxi[0] = ggml_cuda_movmatrix(xi[0]);
387
+ Bxi[1] = ggml_cuda_movmatrix(xi[1]);
388
+
389
+ return mma_B;
390
+ }
391
+ };
392
+
393
+ template <>
394
+ struct mma_C_I16J8<float> {
395
+ static constexpr int I = 16;
396
+ static constexpr int J = 8;
397
+ static constexpr int ne = 4;
398
+
399
+ float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
400
+
401
+ static __device__ __forceinline__ int get_i(const int l) {
402
+ const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
403
+ GGML_CUDA_ASSUME(ret >= 0);
404
+ GGML_CUDA_ASSUME(ret < I);
405
+ return ret;
406
+ }
407
+
408
+ static __device__ __forceinline__ int get_j(const int l) {
409
+ const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
410
+ GGML_CUDA_ASSUME(ret >= 0);
411
+ GGML_CUDA_ASSUME(ret < J);
412
+ return ret;
413
+ }
414
+
415
+ __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
416
+ #ifdef NEW_MMA_AVAILABLE
417
+ int * Axi = (int *) mma_A.x;
418
+ int * Bxi = (int *) mma_B.x;
419
+ int * xi = (int *) x;
420
+ #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
421
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
422
+ : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
423
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
424
+ #else
425
+ // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
426
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
427
+ : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
428
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
429
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
430
+ : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
431
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
432
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
433
+ #else
434
+ GGML_UNUSED(mma_A);
435
+ GGML_UNUSED(mma_B);
436
+ NO_DEVICE_CODE;
437
+ #endif // NEW_MMA_AVAILABLE
438
+ }
439
+
440
+ __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
441
+ mma_B_J8K8<half2> mma_B;
442
+ mma_B.x[0] = make_half2(x[0], x[1]);
443
+ mma_B.x[1] = make_half2(x[2], x[3]);
444
+
445
+ int * Bxi = (int *) mma_B.x;
446
+ Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
447
+ Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
448
+
449
+ return mma_B;
450
+ }
451
+
452
+ __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
453
+ #pragma unroll
454
+ for (int l = 0; l < ne; ++l) {
455
+ x[l] = xs0[get_j(l)*stride + get_i(l)];
456
+ }
457
  }
458
  };
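
To make the interface described in the header comment of mma.cuh concrete, here is a minimal, hypothetical sketch (not part of this commit) of how the tiles compose for a single warp computing C (16x8, f32) = A (16x16, f16, row-major) * B (16x8, f16, column-major). The kernel name, launch shape, shared-memory staging, and output layout are illustrative assumptions; only the mma_A_I16K8 / mma_B_J8K8 / mma_C_I16J8 interface and its stride/alignment rules are taken from the file above, and the ldmatrix/mma paths require NEW_MMA_AVAILABLE (Turing or newer) to emit real instructions.

    #include "common.cuh"
    #include "mma.cuh"

    // Sketch only: launched with a single warp (blockDim.x == 32).
    static __global__ void mma_tile_sketch(const half2 * __restrict__ A_in, const half2 * __restrict__ B_in,
                                           float * __restrict__ dst) {
        // ldmatrix can only read from shared memory and the pointers must be aligned to 16 bytes,
        // so stage both operands there first.
        __shared__ __align__(16) half2 A_sh[16*8]; // 16 rows x 8 half2 = a 16x16 half tile, row-major
        __shared__ __align__(16) half2 B_sh[ 8*8]; //  8 cols x 8 half2 = a 16x8  half tile, column-major

        for (int i = threadIdx.x; i < 16*8; i += WARP_SIZE) {
            A_sh[i] = A_in[i];
        }
        for (int i = threadIdx.x; i < 8*8; i += WARP_SIZE) {
            B_sh[i] = B_in[i];
        }
        __syncwarp();

        mma_A_I16K8<half2> A;
        mma_B_J8K8<half2>  B;
        mma_C_I16J8<float> C; // accumulator, zero-initialized by its default member initializer

        A.load_ldmatrix(A_sh, 8); // strides are measured in physical 32 bit (half2) elements
        B.load_ldmatrix(B_sh, 8);
        C.mma(A, B);

        // Each thread owns mma_C_I16J8<float>::ne values; get_i/get_j give their logical position.
    #pragma unroll
        for (int l = 0; l < mma_C_I16J8<float>::ne; ++l) {
            dst[C.get_i(l)*8 + C.get_j(l)] = C.x[l]; // dst treated as a row-major 16x8 buffer
        }
    }

Transposition between tile layouts goes through ggml_cuda_movmatrix(): mma_A_I16K8<half2>::transpose() and the to_mma_B() helpers on the C tiles use it to turn an accumulator back into a B operand. On toolkits older than CUDA 11.8 the movmatrix instruction is not emitted and the helper instead rebuilds each 16-bit half from the transposed lane with two __shfl_sync() exchanges, so the same tile code works on both paths.
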
ggml/src/ggml-cuda/mmq.cu CHANGED
@@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
132
  return false;
133
  }
134
 
135
- if (int8_mma_available(cc)) {
136
  return true;
137
  }
138
 
 
132
  return false;
133
  }
134
 
135
+ if (new_mma_available(cc)) {
136
  return true;
137
  }
138
 
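
The host-side check above is the counterpart of the NEW_MMA_AVAILABLE macro used throughout device code; the rename away from the INT8_* names presumably reflects that the primitives now cover f16 tiles as well as int8. The guard keeps the same shape everywhere in mmq.cuh below; condensed from the opening of load_tiles_q8_0 (a paraphrase, not a verbatim excerpt):

    #ifdef NEW_MMA_AVAILABLE
        // MMA path: tiles are laid out with the fixed MMQ_MMA_TILE_X_K_* strides
        int   * x_qs = (int   *)  x_tile;
        float * x_df = (float *) (x_qs + 2*WARP_SIZE);
    #else
        // dp4a path: tiles use the per-type sizes from mmq_get_dp4a_tile_x_sizes()
        constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
        int   * x_qs = (int   *)  x_tile;
        float * x_df = (float *) (x_qs + txs.qs);
    #endif // NEW_MMA_AVAILABLE

Host and device have to agree here: new_mma_available(cc) should return true exactly for the architectures on which NEW_MMA_AVAILABLE is defined, otherwise the shared-memory sizing in mmq_get_shmem and the tile layout chosen by the device code would diverge.
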
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -87,7 +87,7 @@ struct tile_x_sizes {
87
  };
88
 
89
  static constexpr int get_mmq_x_max_host(const int cc) {
90
- return int8_mma_available(cc) ? 128 :
91
  #ifdef GGML_CUDA_FORCE_MMQ
92
  cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
93
  #else
@@ -96,9 +96,9 @@ static constexpr int get_mmq_x_max_host(const int cc) {
96
  }
97
 
98
  static constexpr __device__ int get_mmq_x_max_device() {
99
- #ifdef INT8_MMA_AVAILABLE
100
  return 128;
101
- #else // INT8_MMA_AVAILABLE
102
 
103
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
104
  return 128;
@@ -116,7 +116,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
116
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
117
 
118
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
119
- #endif // INT8_MMA_AVAILABLE
120
  }
121
 
122
  static constexpr int get_mmq_y_host(const int cc) {
@@ -209,10 +209,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
209
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
210
 
211
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
212
- return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
213
  }
214
 
215
- #ifdef INT8_MMA_AVAILABLE
216
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
217
  return mmq_x >= 48 ? 16 : 8;
218
  }
@@ -220,21 +220,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
220
  static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
221
  return 8;
222
  }
223
- #endif // INT8_MMA_AVAILABLE
224
 
225
  // ------------------------------------------------------------
226
 
227
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
228
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
229
 
230
- #ifdef INT8_MMA_AVAILABLE
231
  int * x_qs = (int *) x_tile;
232
  float * x_df = (float *) (x_qs + 2*WARP_SIZE);
233
  #else
234
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
235
  int * x_qs = (int *) x_tile;
236
  float * x_df = (float *) (x_qs + txs.qs);
237
- #endif // INT8_MMA_AVAILABLE
238
 
239
  const int kbx = threadIdx.x / QI4_0;
240
  const int kqsx = threadIdx.x % QI4_0;
@@ -250,12 +250,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
250
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
251
  const int qs0 = get_int_b2(bxi->qs, kqsx);
252
 
253
- #ifdef INT8_MMA_AVAILABLE
254
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
255
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
256
  #else
257
  x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
258
- #endif // INT8_MMA_AVAILABLE
259
  }
260
 
261
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -271,11 +271,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
271
 
272
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
273
 
274
- #ifdef INT8_MMA_AVAILABLE
275
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
276
  #else
277
  x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
278
- #endif // INT8_MMA_AVAILABLE
279
  }
280
  }
281
 
@@ -322,14 +322,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
322
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
323
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
324
 
325
- #ifdef INT8_MMA_AVAILABLE
326
  int * x_qs = (int *) x_tile;
327
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
328
  #else
329
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
330
  int * x_qs = (int *) x_tile;
331
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
332
- #endif // INT8_MMA_AVAILABLE
333
 
334
  const int kbx = threadIdx.x / QI4_1;
335
  const int kqsx = threadIdx.x % QI4_1;
@@ -345,12 +345,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
345
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
346
  const int qs0 = get_int_b4(bxi->qs, kqsx);
347
 
348
- #ifdef INT8_MMA_AVAILABLE
349
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
350
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
351
  #else
352
  x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
353
- #endif // INT8_MMA_AVAILABLE
354
  }
355
 
356
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -366,11 +366,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
366
 
367
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
368
 
369
- #ifdef INT8_MMA_AVAILABLE
370
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
371
  #else
372
  x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
373
- #endif // INT8_MMA_AVAILABLE
374
  }
375
  }
376
 
@@ -417,14 +417,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
417
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
418
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
419
 
420
- #ifdef INT8_MMA_AVAILABLE
421
  int * x_qs = (int *) x_tile;
422
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
423
  #else
424
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
425
  int * x_qs = (int *) x_tile;
426
  float * x_df = (float *) (x_qs + txs.qs);
427
- #endif // INT8_MMA_AVAILABLE
428
 
429
  const int kbx = threadIdx.x / QI5_0;
430
  const int kqsx = threadIdx.x % QI5_0;
@@ -456,13 +456,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
456
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
457
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
458
 
459
- #ifdef INT8_MMA_AVAILABLE
460
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
461
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
462
  #else
463
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
464
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
465
- #endif // INT8_MMA_AVAILABLE
466
  }
467
 
468
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
@@ -478,25 +478,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
478
 
479
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
480
 
481
- #ifdef INT8_MMA_AVAILABLE
482
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
483
  #else
484
  x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
485
- #endif // INT8_MMA_AVAILABLE
486
  }
487
  }
488
 
489
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
490
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
491
 
492
- #ifdef INT8_MMA_AVAILABLE
493
  int * x_qs = (int *) x_tile;
494
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
495
  #else
496
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
497
  int * x_qs = (int *) x_tile;
498
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
499
- #endif // INT8_MMA_AVAILABLE
500
 
501
  const int kbx = threadIdx.x / QI5_1;
502
  const int kqsx = threadIdx.x % QI5_1;
@@ -526,13 +526,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
526
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
527
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
528
 
529
- #ifdef INT8_MMA_AVAILABLE
530
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
531
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
532
  #else
533
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
534
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
535
- #endif // INT8_MMA_AVAILABLE
536
  }
537
 
538
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -548,25 +548,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
548
 
549
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
550
 
551
- #ifdef INT8_MMA_AVAILABLE
552
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
553
  #else
554
  x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
555
- #endif // INT8_MMA_AVAILABLE
556
  }
557
  }
558
 
559
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
560
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
561
 
562
- #ifdef INT8_MMA_AVAILABLE
563
  int * x_qs = (int *) x_tile;
564
  float * x_df = (float *) (x_tile + 2*WARP_SIZE);
565
  #else
566
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
567
  int * x_qs = (int *) x_tile;
568
  float * x_df = (float *) (x_qs + txs.qs);
569
- #endif // INT8_MMA_AVAILABLE
570
 
571
  const int kbx = threadIdx.x / QI8_0;
572
  const int kqsx = threadIdx.x % QI8_0;
@@ -581,13 +581,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
581
 
582
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
583
 
584
- #ifdef INT8_MMA_AVAILABLE
585
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
586
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
587
  #else
588
  x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
589
  x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
590
- #endif // INT8_MMA_AVAILABLE
591
  }
592
 
593
  const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
@@ -603,11 +603,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
603
 
604
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
605
 
606
- #ifdef INT8_MMA_AVAILABLE
607
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
608
  #else
609
  x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
610
- #endif // INT8_MMA_AVAILABLE
611
  }
612
  }
613
 
@@ -645,9 +645,9 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
645
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
646
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
647
 
648
- typedef mma_int_A_I16K8 mma_A;
649
- typedef mma_int_B_J8K8 mma_B;
650
- typedef mma_int_C_I16J8 mma_C;
651
 
652
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
653
  constexpr int rows_per_warp = 2 * granularity;
@@ -672,7 +672,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
672
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
673
  const int k0 = k00 + k01;
674
 
675
- A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
676
  }
677
 
678
  #pragma unroll
@@ -695,7 +695,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
695
  mma_B B;
696
  float dB[mma_C::ne/2];
697
 
698
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
699
 
700
  #pragma unroll
701
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -711,7 +711,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
711
  #pragma unroll
712
  for (int n = 0; n < ntx; ++n) {
713
  mma_C C;
714
- C.mma_K8(A[n][k01/QI8_0], B);
715
 
716
  #pragma unroll
717
  for (int l = 0; l < mma_C::ne; ++l) {
@@ -756,9 +756,9 @@ template <int mmq_x, int mmq_y, int nwarps>
756
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
757
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
758
 
759
- typedef mma_int_A_I16K8 mma_A;
760
- typedef mma_int_B_J8K8 mma_B;
761
- typedef mma_int_C_I16J8 mma_C;
762
 
763
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
764
  constexpr int rows_per_warp = 2 * granularity;
@@ -782,7 +782,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
782
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
783
  const int k0 = k00 + k01;
784
 
785
- A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
786
  }
787
 
788
  #pragma unroll
@@ -805,7 +805,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
805
  mma_B B;
806
  float2 dsB[mma_C::ne/2];
807
 
808
- B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
809
 
810
  #pragma unroll
811
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -817,7 +817,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
817
  #pragma unroll
818
  for (int n = 0; n < ntx; ++n) {
819
  mma_C C;
820
- C.mma_K8(A[n][k01/QI8_1], B);
821
 
822
  #pragma unroll
823
  for (int l = 0; l < mma_C::ne; ++l) {
@@ -864,12 +864,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
864
  template <int mmq_x, int mmq_y, int nwarps>
865
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
866
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
867
- #ifdef INT8_MMA_AVAILABLE
868
 
869
- typedef mma_int_A_I16K4 mma_A;
870
- typedef mma_int_A_I16K8 mma_A_K8;
871
- typedef mma_int_B_J8K4 mma_B;
872
- typedef mma_int_C_I16J8 mma_C;
873
 
874
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
875
  constexpr int rows_per_warp = 2 * granularity;
@@ -893,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
893
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
894
  const int k0 = k00 + k01;
895
 
896
- ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
897
  }
898
 
899
  #pragma unroll
@@ -916,8 +916,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
916
  mma_B B[2];
917
  float dB[mma_C::ne/2];
918
 
919
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
920
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
 
921
 
922
  #pragma unroll
923
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -929,8 +930,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
929
  #pragma unroll
930
  for (int n = 0; n < ntx; ++n) {
931
  mma_C C[2];
932
- C[0].mma_K4(A[n][k01/4 + 0], B[0]);
933
- C[1].mma_K4(A[n][k01/4 + 1], B[1]);
934
 
935
  #pragma unroll
936
  for (int l = 0; l < mma_C::ne; ++l) {
@@ -942,20 +943,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
942
  #else
943
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
944
  NO_DEVICE_CODE;
945
- #endif // INT8_MMA_AVAILABLE
946
  }
947
 
948
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
949
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
950
 
951
- #ifdef INT8_MMA_AVAILABLE
952
  int * x_qs = (int *) x_tile;
953
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
954
  #else
955
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
956
  int * x_qs = (int *) x_tile;
957
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
958
- #endif // INT8_MMA_AVAILABLE
959
 
960
  const int kqsx = threadIdx.x % QI2_K;
961
 
@@ -977,11 +978,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
977
 
978
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
979
 
980
- #ifdef INT8_MMA_AVAILABLE
981
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
982
  #else
983
  x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
984
- #endif // INT8_MMA_AVAILABLE
985
  }
986
 
987
  const int sc_m = bxi->scales[kqsx];
@@ -992,11 +993,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
992
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
993
  #endif // FAST_FP16_AVAILABLE
994
 
995
- #ifdef INT8_MMA_AVAILABLE
996
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
997
  #else
998
  x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
999
- #endif // INT8_MMA_AVAILABLE
1000
  }
1001
  }
1002
 
@@ -1051,12 +1052,12 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1051
  template <int mmq_x, int mmq_y, int nwarps>
1052
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1053
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1054
- #ifdef INT8_MMA_AVAILABLE
1055
 
1056
- typedef mma_int_A_I16K4 mma_A;
1057
- typedef mma_int_A_I16K8 mma_A_K8;
1058
- typedef mma_int_B_J8K4 mma_B;
1059
- typedef mma_int_C_I16J8 mma_C;
1060
 
1061
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1062
  constexpr int rows_per_warp = 2 * granularity;
@@ -1081,7 +1082,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1081
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1082
  const int k0 = k00 + k01;
1083
 
1084
- ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1085
  }
1086
  }
1087
 
@@ -1118,24 +1119,25 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1118
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1119
  mma_B B[2];
1120
 
1121
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1122
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
 
1123
 
1124
  mma_C Cm[2];
1125
  if (k01 >= WARP_SIZE * 3/4) {
1126
  mma_A A1;
1127
  A1.x[0] = 0x01010101;
1128
  A1.x[1] = 0x01010101;
1129
- Cm[0].mma_K4(A1, B[0]);
1130
- Cm[1].mma_K4(A1, B[1]);
1131
  }
1132
 
1133
  #pragma unroll
1134
  for (int n = 0; n < ntx; ++n) {
1135
  mma_C Cd[2];
1136
 
1137
- Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
1138
- Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
1139
 
1140
  #pragma unroll
1141
  for (int l = 0; l < mma_C::ne; ++l) {
@@ -1172,13 +1174,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1172
  #else
1173
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
1174
  NO_DEVICE_CODE;
1175
- #endif // INT8_MMA_AVAILABLE
1176
  }
1177
 
1178
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1179
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1180
 
1181
- #ifdef INT8_MMA_AVAILABLE
1182
  int * x_qs = (int *) x_tile;
1183
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1184
  #else
@@ -1186,7 +1188,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1186
  int * x_qs = (int *) x_tile;
1187
  float * x_df = (float *) (x_qs + txs.qs);
1188
  int * x_sc = (int *) (x_df + txs.dm);
1189
- #endif // INT8_MMA_AVAILABLE
1190
 
1191
  const int kqsx = threadIdx.x % QI3_K;
1192
 
@@ -1212,11 +1214,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1212
 
1213
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1214
 
1215
- #ifdef INT8_MMA_AVAILABLE
1216
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1217
  #else
1218
  x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
1219
- #endif // INT8_MMA_AVAILABLE
1220
  }
1221
  }
1222
 
@@ -1242,7 +1244,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1242
 
1243
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1244
 
1245
- #ifdef INT8_MMA_AVAILABLE
1246
  const int8_t * sc8 = (const int8_t *) &sc;
1247
  const float d = bxi->d;
1248
 
@@ -1252,10 +1254,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1252
  }
1253
  #else
1254
  x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1255
- #endif // INT8_MMA_AVAILABLE
1256
  }
1257
 
1258
- #ifndef INT8_MMA_AVAILABLE
1259
  #pragma unroll
1260
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1261
  int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
@@ -1268,7 +1270,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1268
 
1269
  x_df[i] = bxi->d;
1270
  }
1271
- #endif // INT8_MMA_AVAILABLE
1272
  }
1273
 
1274
  template <int mmq_x, int mmq_y, int nwarps>
@@ -1317,7 +1319,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
1317
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1318
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1319
 
1320
- #ifdef INT8_MMA_AVAILABLE
1321
  int * x_qs = (int *) x_tile;
1322
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1323
  #else
@@ -1325,7 +1327,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1325
  int * x_qs = (int *) x_tile;
1326
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1327
  int * x_sc = (int *) (x_dm + txs.dm);
1328
- #endif // INT8_MMA_AVAILABLE
1329
 
1330
  #pragma unroll
1331
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1338,15 +1340,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1338
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1339
  const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
1340
 
1341
- #ifdef INT8_MMA_AVAILABLE
1342
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1343
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1344
  #else
1345
  x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
1346
- #endif // INT8_MMA_AVAILABLE
1347
  }
1348
 
1349
- #ifdef INT8_MMA_AVAILABLE
1350
 
1351
  #pragma unroll
1352
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1407,7 +1409,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1407
 
1408
  x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1409
  }
1410
- #endif // INT8_MMA_AVAILABLE
1411
  }
1412
 
1413
  template <int mmq_x, int mmq_y, int nwarps>
@@ -1446,7 +1448,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1446
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1447
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1448
 
1449
- #ifdef INT8_MMA_AVAILABLE
1450
  int * x_qs = (int *) x_tile;
1451
  half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
1452
  #else
@@ -1454,7 +1456,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1454
  int * x_qs = (int *) x_tile;
1455
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1456
  int * x_sc = (int *) (x_dm + txs.dm);
1457
- #endif // INT8_MMA_AVAILABLE
1458
 
1459
  #pragma unroll
1460
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1478,16 +1480,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1478
  const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1479
  const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
1480
 
1481
- #ifdef INT8_MMA_AVAILABLE
1482
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1483
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1484
  #else
1485
  x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1486
  x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1487
- #endif // INT8_MMA_AVAILABLE
1488
  }
1489
 
1490
- #ifdef INT8_MMA_AVAILABLE
1491
 
1492
  #pragma unroll
1493
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1548,7 +1550,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1548
 
1549
  x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1550
  }
1551
- #endif // INT8_MMA_AVAILABLE
1552
  }
1553
 
1554
  template <int mmq_x, int mmq_y, int nwarps>
@@ -1587,7 +1589,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1587
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1588
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1589
 
1590
- #ifdef INT8_MMA_AVAILABLE
1591
  int * x_qs = (int *) x_tile;
1592
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1593
  int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
@@ -1596,7 +1598,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1596
  int * x_qs = (int *) x_tile;
1597
  float * x_df = (float *) (x_qs + txs.qs);
1598
  int * x_sc = (int *) (x_df + txs.dm);
1599
- #endif // INT8_MMA_AVAILABLE
1600
 
1601
  #pragma unroll
1602
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1619,13 +1621,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1619
  const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
1620
  const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
1621
 
1622
- #ifdef INT8_MMA_AVAILABLE
1623
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1624
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1625
  #else
1626
  x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1627
  x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1628
- #endif // INT8_MMA_AVAILABLE
1629
  }
1630
 
1631
  const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
@@ -1641,11 +1643,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1641
 
1642
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
1643
 
1644
- #ifdef INT8_MMA_AVAILABLE
1645
  x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
1646
  #else
1647
  x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
1648
- #endif // INT8_MMA_AVAILABLE
1649
  }
1650
 
1651
  #pragma unroll
@@ -1658,11 +1660,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1658
 
1659
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
1660
 
1661
- #ifdef INT8_MMA_AVAILABLE
1662
  x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
1663
  #else
1664
  x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
1665
- #endif // INT8_MMA_AVAILABLE
1666
  }
1667
  }
1668
 
@@ -1702,11 +1704,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1702
  template <int mmq_x, int mmq_y, int nwarps>
1703
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1704
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1705
- #ifdef INT8_MMA_AVAILABLE
1706
 
1707
- typedef mma_int_A_I16K4 mma_A;
1708
- typedef mma_int_B_J8K4 mma_B;
1709
- typedef mma_int_C_I16J8 mma_C;
1710
 
1711
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1712
  constexpr int rows_per_warp = 2 * granularity;
@@ -1732,8 +1734,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1732
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1733
  const int k0 = k00 + k01;
1734
 
1735
- A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
1736
- A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
1737
  }
1738
 
1739
  #pragma unroll
@@ -1771,8 +1773,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1771
  mma_B B[2];
1772
  float dB[mma_C::ne/2];
1773
 
1774
- B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
1775
- B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
 
1776
 
1777
  #pragma unroll
1778
  for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -1784,8 +1787,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1784
  #pragma unroll
1785
  for (int n = 0; n < ntx; ++n) {
1786
  mma_C C[2];
1787
- C[0].mma_K4(A[n][k01/4 + 0], B[0]);
1788
- C[1].mma_K4(A[n][k01/4 + 1], B[1]);
1789
 
1790
  #pragma unroll
1791
  for (int l = 0; l < mma_C::ne; ++l) {
@@ -1805,20 +1808,20 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1805
  #else
1806
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
1807
  NO_DEVICE_CODE;
1808
- #endif // INT8_MMA_AVAILABLE
1809
  }
1810
 
1811
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
1812
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1813
 
1814
- #ifdef INT8_MMA_AVAILABLE
1815
  int * x_qs = (int *) x_tile;
1816
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1817
  #else
1818
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
1819
  int * x_qs = (int *) x_tile;
1820
  float * x_df = (float *) (x_qs + txs.qs);
1821
- #endif // INT8_MMA_AVAILABLE
1822
 
1823
  const int kbx = threadIdx.x / QI4_NL;
1824
  const int kqsx = threadIdx.x % QI4_NL;
@@ -1836,13 +1839,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1836
  const int aux_q4 = get_int_b2(bxi->qs, kqsx);
1837
  const int2 v = get_int_from_table_16(aux_q4);
1838
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
1839
- #ifdef INT8_MMA_AVAILABLE
1840
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
1841
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
1842
  #else
1843
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
1844
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
1845
- #endif // INT8_MMA_AVAILABLE
1846
  }
1847
 
1848
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
@@ -1858,25 +1861,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1858
 
1859
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
1860
 
1861
- #ifdef INT8_MMA_AVAILABLE
1862
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
1863
  #else
1864
  x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
1865
- #endif // INT8_MMA_AVAILABLE
1866
  }
1867
  }
1868
 
1869
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1870
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1871
 
1872
- #ifdef INT8_MMA_AVAILABLE
1873
  int * x_qs = (int *) x_tile;
1874
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1875
  #else
1876
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
1877
  int * x_qs = (int *) x_tile;
1878
  float * x_df = (float *) (x_qs + txs.qs);
1879
- #endif // INT8_MMA_AVAILABLE
1880
 
1881
  const int kqsx = threadIdx.x % (QI2_XXS/2);
1882
 
@@ -1905,36 +1908,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1905
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
1906
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
1907
 
1908
- #ifdef INT8_MMA_AVAILABLE
1909
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
1910
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
1911
  #else
1912
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
1913
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
1914
- #endif // INT8_MMA_AVAILABLE
1915
  }
1916
 
1917
  const int ls = aux32 >> 28;
1918
  const float d = bxi->d;
1919
- #ifdef INT8_MMA_AVAILABLE
1920
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
1921
  #else
1922
  x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
1923
- #endif // INT8_MMA_AVAILABLE
1924
  }
1925
  }
1926
 
1927
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1928
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1929
 
1930
- #ifdef INT8_MMA_AVAILABLE
1931
  int * x_qs = (int *) x_tile;
1932
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1933
  #else
1934
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
1935
  int * x_qs = (int *) x_tile;
1936
  float * x_df = (float *) (x_qs + txs.qs);
1937
- #endif // INT8_MMA_AVAILABLE
1938
 
1939
  const int kqsx = threadIdx.x % (QI2_XS/2);
1940
 
@@ -1959,38 +1962,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1959
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
1960
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
1961
 
1962
- #ifdef INT8_MMA_AVAILABLE
1963
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
1964
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
1965
  #else
1966
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
1967
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
1968
- #endif // INT8_MMA_AVAILABLE
1969
  }
1970
 
1971
  const int ls = bxi->scales[kqsx];
1972
  const float d = bxi->d;
1973
- #ifdef INT8_MMA_AVAILABLE
1974
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
1975
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
1976
  #else
1977
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
1978
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
1979
- #endif // INT8_MMA_AVAILABLE
1980
  }
1981
  }
1982
 
1983
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
1984
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1985
 
1986
- #ifdef INT8_MMA_AVAILABLE
1987
  int * x_qs = (int *) x_tile;
1988
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1989
  #else
1990
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
1991
  int * x_qs = (int *) x_tile;
1992
  float * x_df = (float *) (x_qs + txs.qs);
1993
- #endif // INT8_MMA_AVAILABLE
1994
 
1995
  const int kqsx = threadIdx.x % (QI2_S/2);
1996
 
@@ -2022,38 +2025,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2022
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2023
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2024
 
2025
- #ifdef INT8_MMA_AVAILABLE
2026
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2027
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2028
  #else
2029
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2030
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2031
- #endif // INT8_MMA_AVAILABLE
2032
  }
2033
 
2034
  const int ls = bxi->scales[kqsx];
2035
  const float d = bxi->d;
2036
- #ifdef INT8_MMA_AVAILABLE
2037
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2038
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2039
  #else
2040
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2041
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2042
- #endif // INT8_MMA_AVAILABLE
2043
  }
2044
  }
2045
 
2046
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2047
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2048
 
2049
- #ifdef INT8_MMA_AVAILABLE
2050
  int * x_qs = (int *) x_tile;
2051
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
2052
  #else
2053
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2054
  int * x_qs = (int *) x_tile;
2055
  float * x_df = (float *) (x_qs + txs.qs);
2056
- #endif // INT8_MMA_AVAILABLE
2057
 
2058
  const int kqsx = threadIdx.x % (QI3_XXS/2);
2059
 
@@ -2080,36 +2083,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2080
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2081
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2082
 
2083
- #ifdef INT8_MMA_AVAILABLE
2084
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2085
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2086
  #else
2087
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2088
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2089
- #endif // INT8_MMA_AVAILABLE
2090
  }
2091
 
2092
  const int ls = aux32 >> 28;
2093
  const float d = bxi->d;
2094
- #ifdef INT8_MMA_AVAILABLE
2095
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2096
  #else
2097
  x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2098
- #endif // INT8_MMA_AVAILABLE
2099
  }
2100
  }
2101
 
2102
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2103
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2104
 
2105
- #ifdef INT8_MMA_AVAILABLE
2106
  int * x_qs = (int *) x_tile;
2107
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
2108
  #else
2109
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2110
  int * x_qs = (int *) x_tile;
2111
  float * x_df = (float *) (x_qs + txs.qs);
2112
- #endif // INT8_MMA_AVAILABLE
2113
 
2114
  const int kqsx = threadIdx.x % (QI3_S/2);
2115
 
@@ -2143,36 +2146,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2143
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2144
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2145
 
2146
- #ifdef INT8_MMA_AVAILABLE
2147
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2148
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2149
  #else
2150
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
2151
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
2152
- #endif // INT8_MMA_AVAILABLE
2153
  }
2154
 
2155
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2156
  const float d = bxi->d;
2157
- #ifdef INT8_MMA_AVAILABLE
2158
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2159
  #else
2160
  x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
2161
- #endif // INT8_MMA_AVAILABLE
2162
  }
2163
  }
2164
 
2165
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2166
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2167
 
2168
- #ifdef INT8_MMA_AVAILABLE
2169
  int * x_qs = (int *) x_tile;
2170
  half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
2171
  #else
2172
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2173
  int * x_qs = (int *) x_tile;
2174
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2175
- #endif // INT8_MMA_AVAILABLE
2176
 
2177
  const int kqsx = threadIdx.x % QI1_S;
2178
 
@@ -2198,37 +2201,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2198
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2199
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2200
 
2201
- #ifdef INT8_MMA_AVAILABLE
2202
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2203
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2204
  #else
2205
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
2206
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
2207
- #endif // INT8_MMA_AVAILABLE
2208
  }
2209
 
2210
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2211
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2212
 
2213
- #ifdef INT8_MMA_AVAILABLE
2214
  x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2215
  #else
2216
  x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2217
- #endif // INT8_MMA_AVAILABLE
2218
  }
2219
  }
2220
 
2221
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2222
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2223
 
2224
- #ifdef INT8_MMA_AVAILABLE
2225
  int * x_qs = (int *) x_tile;
2226
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
2227
  #else
2228
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2229
  int * x_qs = (int *) x_tile;
2230
  float * x_df = (float *) (x_qs + txs.qs);
2231
- #endif // INT8_MMA_AVAILABLE
2232
 
2233
  const int kbx = 0; // threadIdx.x / QI4_XS
2234
  const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
@@ -2246,13 +2249,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2246
  const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2247
  const int2 v = get_int_from_table_16(aux_q4);
2248
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2249
- #ifdef INT8_MMA_AVAILABLE
2250
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2251
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2252
  #else
2253
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2254
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2255
- #endif // INT8_MMA_AVAILABLE
2256
  }
2257
 
2258
  #pragma unroll
@@ -2270,11 +2273,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2270
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2271
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2272
 
2273
- #ifdef INT8_MMA_AVAILABLE
2274
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2275
  #else
2276
  x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2277
- #endif // INT8_MMA_AVAILABLE
2278
  }
2279
  }
2280
 
@@ -2307,16 +2310,16 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2307
  static __device__ __forceinline__ void mmq_write_back_mma(
2308
  const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
2309
 
2310
- typedef mma_int_C_I16J8 mma_C;
2311
 
2312
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2313
  constexpr int rows_per_warp = 2 * granularity;
2314
  constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
2315
 
2316
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
2317
- #ifdef INT8_MMA_AVAILABLE
2318
  static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
2319
- #endif // INT8_MMA_AVAILABLE
2320
 
2321
  #pragma unroll
2322
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
@@ -2505,13 +2508,13 @@ static __device__ void mul_mat_q_process_tile(
2505
  int * tile_y = (int *) data_mul_mat_q;
2506
  int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
2507
 
2508
- #ifdef INT8_MMA_AVAILABLE
2509
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
2510
  constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
2511
  #else
2512
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
2513
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
2514
- #endif // INT8_MMA_AVAILABLE
2515
 
2516
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2517
 
@@ -2643,7 +2646,7 @@ static __global__ void mul_mat_q(
2643
  const int jt = kbc / (blocks_per_ne00*nty);
2644
  const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
2645
 
2646
- constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
2647
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2648
  (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
2649
  it, jt, kb0_start, kb0_stop);
@@ -2749,7 +2752,7 @@ template<ggml_type type>
2749
  static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
2750
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
2751
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
2752
- const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
2753
  const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
2754
  return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
2755
  }
 
87
  };
88
 
89
  static constexpr int get_mmq_x_max_host(const int cc) {
90
+ return new_mma_available(cc) ? 128 :
91
  #ifdef GGML_CUDA_FORCE_MMQ
92
  cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
93
  #else
 
96
  }
97
 
98
  static constexpr __device__ int get_mmq_x_max_device() {
99
+ #ifdef NEW_MMA_AVAILABLE
100
  return 128;
101
+ #else // NEW_MMA_AVAILABLE
102
 
103
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
104
  return 128;
 
116
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
117
 
118
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
119
+ #endif // NEW_MMA_AVAILABLE
120
  }
121
 
122
  static constexpr int get_mmq_y_host(const int cc) {
 
209
  #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
210
 
211
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
212
+ return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
213
  }
214
 
215
+ #ifdef NEW_MMA_AVAILABLE
216
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
217
  return mmq_x >= 48 ? 16 : 8;
218
  }
 
220
  static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
221
  return 8;
222
  }
223
+ #endif // NEW_MMA_AVAILABLE
224
 
225
  // ------------------------------------------------------------
226
 
227
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
228
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
229
 
230
+ #ifdef NEW_MMA_AVAILABLE
231
  int * x_qs = (int *) x_tile;
232
  float * x_df = (float *) (x_qs + 2*WARP_SIZE);
233
  #else
234
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
235
  int * x_qs = (int *) x_tile;
236
  float * x_df = (float *) (x_qs + txs.qs);
237
+ #endif // NEW_MMA_AVAILABLE
238
 
239
  const int kbx = threadIdx.x / QI4_0;
240
  const int kqsx = threadIdx.x % QI4_0;
 
250
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
251
  const int qs0 = get_int_b2(bxi->qs, kqsx);
252
 
253
+ #ifdef NEW_MMA_AVAILABLE
254
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
255
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
256
  #else
257
  x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
258
+ #endif // NEW_MMA_AVAILABLE
259
  }
260
 
261
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
 
271
 
272
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
273
 
274
+ #ifdef NEW_MMA_AVAILABLE
275
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
276
  #else
277
  x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
278
+ #endif // NEW_MMA_AVAILABLE
279
  }
280
  }
281
 
 
322
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
323
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
324
 
325
+ #ifdef NEW_MMA_AVAILABLE
326
  int * x_qs = (int *) x_tile;
327
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
328
  #else
329
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
330
  int * x_qs = (int *) x_tile;
331
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
332
+ #endif // NEW_MMA_AVAILABLE
333
 
334
  const int kbx = threadIdx.x / QI4_1;
335
  const int kqsx = threadIdx.x % QI4_1;
 
345
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
346
  const int qs0 = get_int_b4(bxi->qs, kqsx);
347
 
348
+ #ifdef NEW_MMA_AVAILABLE
349
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
350
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
351
  #else
352
  x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
353
+ #endif // NEW_MMA_AVAILABLE
354
  }
355
 
356
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
 
366
 
367
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
368
 
369
+ #ifdef NEW_MMA_AVAILABLE
370
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
371
  #else
372
  x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
373
+ #endif // NEW_MMA_AVAILABLE
374
  }
375
  }
376
 
 
417
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
418
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
419
 
420
+ #ifdef NEW_MMA_AVAILABLE
421
  int * x_qs = (int *) x_tile;
422
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
423
  #else
424
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
425
  int * x_qs = (int *) x_tile;
426
  float * x_df = (float *) (x_qs + txs.qs);
427
+ #endif // NEW_MMA_AVAILABLE
428
 
429
  const int kbx = threadIdx.x / QI5_0;
430
  const int kqsx = threadIdx.x % QI5_0;
 
456
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
457
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
458
 
459
+ #ifdef NEW_MMA_AVAILABLE
460
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
461
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
462
  #else
463
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
464
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
465
+ #endif // NEW_MMA_AVAILABLE
466
  }
467
 
468
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
 
478
 
479
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
480
 
481
+ #ifdef NEW_MMA_AVAILABLE
482
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
483
  #else
484
  x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
485
+ #endif // NEW_MMA_AVAILABLE
486
  }
487
  }
488
 
489
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
490
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
491
 
492
+ #ifdef NEW_MMA_AVAILABLE
493
  int * x_qs = (int *) x_tile;
494
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
495
  #else
496
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
497
  int * x_qs = (int *) x_tile;
498
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
499
+ #endif // NEW_MMA_AVAILABLE
500
 
501
  const int kbx = threadIdx.x / QI5_1;
502
  const int kqsx = threadIdx.x % QI5_1;
 
526
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
527
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
528
 
529
+ #ifdef NEW_MMA_AVAILABLE
530
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
531
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
532
  #else
533
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
534
  x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
535
+ #endif // NEW_MMA_AVAILABLE
536
  }
537
 
538
  const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
 
548
 
549
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
550
 
551
+ #ifdef NEW_MMA_AVAILABLE
552
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
553
  #else
554
  x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
555
+ #endif // NEW_MMA_AVAILABLE
556
  }
557
  }
558
 
559
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
560
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
561
 
562
+ #ifdef NEW_MMA_AVAILABLE
563
  int * x_qs = (int *) x_tile;
564
  float * x_df = (float *) (x_tile + 2*WARP_SIZE);
565
  #else
566
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
567
  int * x_qs = (int *) x_tile;
568
  float * x_df = (float *) (x_qs + txs.qs);
569
+ #endif // NEW_MMA_AVAILABLE
570
 
571
  const int kbx = threadIdx.x / QI8_0;
572
  const int kqsx = threadIdx.x % QI8_0;
 
581
 
582
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
583
 
584
+ #ifdef NEW_MMA_AVAILABLE
585
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
586
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
587
  #else
588
  x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
589
  x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
590
+ #endif // NEW_MMA_AVAILABLE
591
  }
592
 
593
  const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
 
603
 
604
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
605
 
606
+ #ifdef NEW_MMA_AVAILABLE
607
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
608
  #else
609
  x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
610
+ #endif // NEW_MMA_AVAILABLE
611
  }
612
  }
613
 
 
645
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
646
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
647
 
648
+ typedef mma_A_I16K8<int> mma_A;
649
+ typedef mma_B_J8K8<int> mma_B;
650
+ typedef mma_C_I16J8<int> mma_C;
651
 
652
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
653
  constexpr int rows_per_warp = 2 * granularity;
 
672
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
673
  const int k0 = k00 + k01;
674
 
675
+ A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
676
  }
677
 
678
  #pragma unroll
 
695
  mma_B B;
696
  float dB[mma_C::ne/2];
697
 
698
+ B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
699
 
700
  #pragma unroll
701
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
711
  #pragma unroll
712
  for (int n = 0; n < ntx; ++n) {
713
  mma_C C;
714
+ C.mma(A[n][k01/QI8_0], B);
715
 
716
  #pragma unroll
717
  for (int l = 0; l < mma_C::ne; ++l) {
 
756
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
757
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
758
 
759
+ typedef mma_A_I16K8<int> mma_A;
760
+ typedef mma_B_J8K8<int> mma_B;
761
+ typedef mma_C_I16J8<int> mma_C;
762
 
763
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
764
  constexpr int rows_per_warp = 2 * granularity;
 
782
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
783
  const int k0 = k00 + k01;
784
 
785
+ A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
786
  }
787
 
788
  #pragma unroll
 
805
  mma_B B;
806
  float2 dsB[mma_C::ne/2];
807
 
808
+ B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
809
 
810
  #pragma unroll
811
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
817
  #pragma unroll
818
  for (int n = 0; n < ntx; ++n) {
819
  mma_C C;
820
+ C.mma(A[n][k01/QI8_1], B);
821
 
822
  #pragma unroll
823
  for (int l = 0; l < mma_C::ne; ++l) {
 
864
  template <int mmq_x, int mmq_y, int nwarps>
865
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
866
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
867
+ #ifdef NEW_MMA_AVAILABLE
868
 
869
+ typedef mma_A_I16K4<int> mma_A;
870
+ typedef mma_A_I16K8<int> mma_A_K8;
871
+ typedef mma_B_J8K4<int> mma_B;
872
+ typedef mma_C_I16J8<int> mma_C;
873
 
874
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
875
  constexpr int rows_per_warp = 2 * granularity;
 
893
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
894
  const int k0 = k00 + k01;
895
 
896
+ ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
897
  }
898
 
899
  #pragma unroll
 
916
  mma_B B[2];
917
  float dB[mma_C::ne/2];
918
 
919
+ // Here load_generic is faster than load_ldmatrix.
920
+ B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
921
+ B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
922
 
923
  #pragma unroll
924
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
930
  #pragma unroll
931
  for (int n = 0; n < ntx; ++n) {
932
  mma_C C[2];
933
+ C[0].mma(A[n][k01/4 + 0], B[0]);
934
+ C[1].mma(A[n][k01/4 + 1], B[1]);
935
 
936
  #pragma unroll
937
  for (int l = 0; l < mma_C::ne; ++l) {
 
943
  #else
944
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
945
  NO_DEVICE_CODE;
946
+ #endif // NEW_MMA_AVAILABLE
947
  }
948
 
949
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
950
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
951
 
952
+ #ifdef NEW_MMA_AVAILABLE
953
  int * x_qs = (int *) x_tile;
954
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
955
  #else
956
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
957
  int * x_qs = (int *) x_tile;
958
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
959
+ #endif // NEW_MMA_AVAILABLE
960
 
961
  const int kqsx = threadIdx.x % QI2_K;
962
 
 
978
 
979
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
980
 
981
+ #ifdef NEW_MMA_AVAILABLE
982
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
983
  #else
984
  x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
985
+ #endif // NEW_MMA_AVAILABLE
986
  }
987
 
988
  const int sc_m = bxi->scales[kqsx];
 
993
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
994
  #endif // FAST_FP16_AVAILABLE
995
 
996
+ #ifdef NEW_MMA_AVAILABLE
997
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
998
  #else
999
  x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
1000
+ #endif // NEW_MMA_AVAILABLE
1001
  }
1002
  }
1003
 
 
1052
  template <int mmq_x, int mmq_y, int nwarps>
1053
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1054
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1055
+ #ifdef NEW_MMA_AVAILABLE
1056
 
1057
+ typedef mma_A_I16K4<int> mma_A;
1058
+ typedef mma_A_I16K8<int> mma_A_K8;
1059
+ typedef mma_B_J8K4<int> mma_B;
1060
+ typedef mma_C_I16J8<int> mma_C;
1061
 
1062
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1063
  constexpr int rows_per_warp = 2 * granularity;
 
1082
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1083
  const int k0 = k00 + k01;
1084
 
1085
+ ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1086
  }
1087
  }
1088
 
 
1119
  for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
1120
  mma_B B[2];
1121
 
1122
+ // Here load_generic is faster than load_ldmatrix.
1123
+ B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1124
+ B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
1125
 
1126
  mma_C Cm[2];
1127
  if (k01 >= WARP_SIZE * 3/4) {
1128
  mma_A A1;
1129
  A1.x[0] = 0x01010101;
1130
  A1.x[1] = 0x01010101;
1131
+ Cm[0].mma(A1, B[0]);
1132
+ Cm[1].mma(A1, B[1]);
1133
  }
1134
 
1135
  #pragma unroll
1136
  for (int n = 0; n < ntx; ++n) {
1137
  mma_C Cd[2];
1138
 
1139
+ Cd[0].mma(A[n][k01/4 + 0], B[0]);
1140
+ Cd[1].mma(A[n][k01/4 + 1], B[1]);
1141
 
1142
  #pragma unroll
1143
  for (int l = 0; l < mma_C::ne; ++l) {
 
1174
  #else
1175
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
1176
  NO_DEVICE_CODE;
1177
+ #endif // NEW_MMA_AVAILABLE
1178
  }
1179
 
1180
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1181
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1182
 
1183
+ #ifdef NEW_MMA_AVAILABLE
1184
  int * x_qs = (int *) x_tile;
1185
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1186
  #else
 
1188
  int * x_qs = (int *) x_tile;
1189
  float * x_df = (float *) (x_qs + txs.qs);
1190
  int * x_sc = (int *) (x_df + txs.dm);
1191
+ #endif // NEW_MMA_AVAILABLE
1192
 
1193
  const int kqsx = threadIdx.x % QI3_K;
1194
 
 
1214
 
1215
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1216
 
1217
+ #ifdef NEW_MMA_AVAILABLE
1218
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1219
  #else
1220
  x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
1221
+ #endif // NEW_MMA_AVAILABLE
1222
  }
1223
  }
1224
 
 
1244
 
1245
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1246
 
1247
+ #ifdef NEW_MMA_AVAILABLE
1248
  const int8_t * sc8 = (const int8_t *) &sc;
1249
  const float d = bxi->d;
1250
 
 
1254
  }
1255
  #else
1256
  x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
1257
+ #endif // NEW_MMA_AVAILABLE
1258
  }
1259
 
1260
+ #ifndef NEW_MMA_AVAILABLE
1261
  #pragma unroll
1262
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
1263
  int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
 
1270
 
1271
  x_df[i] = bxi->d;
1272
  }
1273
+ #endif // NEW_MMA_AVAILABLE
1274
  }
1275
 
1276
  template <int mmq_x, int mmq_y, int nwarps>
 
1319
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1320
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1321
 
1322
+ #ifdef NEW_MMA_AVAILABLE
1323
  int * x_qs = (int *) x_tile;
1324
  half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
1325
  #else
 
1327
  int * x_qs = (int *) x_tile;
1328
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1329
  int * x_sc = (int *) (x_dm + txs.dm);
1330
+ #endif // NEW_MMA_AVAILABLE
1331
 
1332
  #pragma unroll
1333
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 
1340
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1341
  const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
1342
 
1343
+ #ifdef NEW_MMA_AVAILABLE
1344
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1345
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1346
  #else
1347
  x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
1348
+ #endif // NEW_MMA_AVAILABLE
1349
  }
1350
 
1351
+ #ifdef NEW_MMA_AVAILABLE
1352
 
1353
  #pragma unroll
1354
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
 
1409
 
1410
  x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1411
  }
1412
+ #endif // NEW_MMA_AVAILABLE
1413
  }
1414
 
1415
  template <int mmq_x, int mmq_y, int nwarps>
 
1448
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1449
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1450
 
1451
+ #ifdef NEW_MMA_AVAILABLE
1452
  int * x_qs = (int *) x_tile;
1453
  half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
1454
  #else
 
1456
  int * x_qs = (int *) x_tile;
1457
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1458
  int * x_sc = (int *) (x_dm + txs.dm);
1459
+ #endif // NEW_MMA_AVAILABLE
1460
 
1461
  #pragma unroll
1462
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 
1480
  const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
1481
  const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
1482
 
1483
+ #ifdef NEW_MMA_AVAILABLE
1484
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1485
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1486
  #else
1487
  x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
1488
  x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
1489
+ #endif // NEW_MMA_AVAILABLE
1490
  }
1491
 
1492
+ #ifdef NEW_MMA_AVAILABLE
1493
 
1494
  #pragma unroll
1495
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
 
1550
 
1551
  x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
1552
  }
1553
+ #endif // NEW_MMA_AVAILABLE
1554
  }
1555
 
1556
  template <int mmq_x, int mmq_y, int nwarps>
 
1589
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1590
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1591
 
1592
+ #ifdef NEW_MMA_AVAILABLE
1593
  int * x_qs = (int *) x_tile;
1594
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1595
  int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
 
1598
  int * x_qs = (int *) x_tile;
1599
  float * x_df = (float *) (x_qs + txs.qs);
1600
  int * x_sc = (int *) (x_df + txs.dm);
1601
+ #endif // NEW_MMA_AVAILABLE
1602
 
1603
  #pragma unroll
1604
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
 
1621
  const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
1622
  const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
1623
 
1624
+ #ifdef NEW_MMA_AVAILABLE
1625
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1626
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1627
  #else
1628
  x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
1629
  x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
1630
+ #endif // NEW_MMA_AVAILABLE
1631
  }
1632
 
1633
  const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
 
1643
 
1644
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
1645
 
1646
+ #ifdef NEW_MMA_AVAILABLE
1647
  x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
1648
  #else
1649
  x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
1650
+ #endif // NEW_MMA_AVAILABLE
1651
  }
1652
 
1653
  #pragma unroll
 
1660
 
1661
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
1662
 
1663
+ #ifdef NEW_MMA_AVAILABLE
1664
  x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
1665
  #else
1666
  x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
1667
+ #endif // NEW_MMA_AVAILABLE
1668
  }
1669
  }
1670
 
 
1704
  template <int mmq_x, int mmq_y, int nwarps>
1705
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1706
  const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1707
+ #ifdef NEW_MMA_AVAILABLE
1708
 
1709
+ typedef mma_A_I16K4<int> mma_A;
1710
+ typedef mma_B_J8K4<int> mma_B;
1711
+ typedef mma_C_I16J8<int> mma_C;
1712
 
1713
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
1714
  constexpr int rows_per_warp = 2 * granularity;
 
1734
  for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
1735
  const int k0 = k00 + k01;
1736
 
1737
+ A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
1738
+ A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
1739
  }
1740
 
1741
  #pragma unroll
 
1773
  mma_B B[2];
1774
  float dB[mma_C::ne/2];
1775
 
1776
+ // Here load_generic is faster than load_ldmatrix.
1777
+ B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
1778
+ B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
1779
 
1780
  #pragma unroll
1781
  for (int l = 0; l < mma_C::ne/2; ++l) {
 
1787
  #pragma unroll
1788
  for (int n = 0; n < ntx; ++n) {
1789
  mma_C C[2];
1790
+ C[0].mma(A[n][k01/4 + 0], B[0]);
1791
+ C[1].mma(A[n][k01/4 + 1], B[1]);
1792
 
1793
  #pragma unroll
1794
  for (int l = 0; l < mma_C::ne; ++l) {
 
1808
  #else
1809
  GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
1810
  NO_DEVICE_CODE;
1811
+ #endif // NEW_MMA_AVAILABLE
1812
  }
1813
 
1814
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
1815
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1816
 
1817
+ #ifdef NEW_MMA_AVAILABLE
1818
  int * x_qs = (int *) x_tile;
1819
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1820
  #else
1821
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
1822
  int * x_qs = (int *) x_tile;
1823
  float * x_df = (float *) (x_qs + txs.qs);
1824
+ #endif // NEW_MMA_AVAILABLE
1825
 
1826
  const int kbx = threadIdx.x / QI4_NL;
1827
  const int kqsx = threadIdx.x % QI4_NL;
 
1839
  const int aux_q4 = get_int_b2(bxi->qs, kqsx);
1840
  const int2 v = get_int_from_table_16(aux_q4);
1841
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
1842
+ #ifdef NEW_MMA_AVAILABLE
1843
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
1844
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
1845
  #else
1846
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
1847
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
1848
+ #endif // NEW_MMA_AVAILABLE
1849
  }
1850
 
1851
  const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
 
1861
 
1862
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
1863
 
1864
+ #ifdef NEW_MMA_AVAILABLE
1865
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
1866
  #else
1867
  x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
1868
+ #endif // NEW_MMA_AVAILABLE
1869
  }
1870
  }
1871
 
1872
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1873
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1874
 
1875
+ #ifdef NEW_MMA_AVAILABLE
1876
  int * x_qs = (int *) x_tile;
1877
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1878
  #else
1879
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
1880
  int * x_qs = (int *) x_tile;
1881
  float * x_df = (float *) (x_qs + txs.qs);
1882
+ #endif // NEW_MMA_AVAILABLE
1883
 
1884
  const int kqsx = threadIdx.x % (QI2_XXS/2);
1885
 
 
1908
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
1909
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
1910
 
1911
+ #ifdef NEW_MMA_AVAILABLE
1912
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
1913
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
1914
  #else
1915
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
1916
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
1917
+ #endif // NEW_MMA_AVAILABLE
1918
  }
1919
 
1920
  const int ls = aux32 >> 28;
1921
  const float d = bxi->d;
1922
+ #ifdef NEW_MMA_AVAILABLE
1923
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
1924
  #else
1925
  x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
1926
+ #endif // NEW_MMA_AVAILABLE
1927
  }
1928
  }
1929
 
1930
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1931
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1932
 
1933
+ #ifdef NEW_MMA_AVAILABLE
1934
  int * x_qs = (int *) x_tile;
1935
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1936
  #else
1937
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
1938
  int * x_qs = (int *) x_tile;
1939
  float * x_df = (float *) (x_qs + txs.qs);
1940
+ #endif // NEW_MMA_AVAILABLE
1941
 
1942
  const int kqsx = threadIdx.x % (QI2_XS/2);
1943
 
 
1962
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
1963
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
1964
 
1965
+ #ifdef NEW_MMA_AVAILABLE
1966
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
1967
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
1968
  #else
1969
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
1970
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
1971
+ #endif // NEW_MMA_AVAILABLE
1972
  }
1973
 
1974
  const int ls = bxi->scales[kqsx];
1975
  const float d = bxi->d;
1976
+ #ifdef NEW_MMA_AVAILABLE
1977
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
1978
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
1979
  #else
1980
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
1981
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
1982
+ #endif // NEW_MMA_AVAILABLE
1983
  }
1984
  }
1985
 
1986
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
1987
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1988
 
1989
+ #ifdef NEW_MMA_AVAILABLE
1990
  int * x_qs = (int *) x_tile;
1991
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
1992
  #else
1993
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
1994
  int * x_qs = (int *) x_tile;
1995
  float * x_df = (float *) (x_qs + txs.qs);
1996
+ #endif // NEW_MMA_AVAILABLE
1997
 
1998
  const int kqsx = threadIdx.x % (QI2_S/2);
1999
 
 
2025
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2026
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2027
 
2028
+ #ifdef NEW_MMA_AVAILABLE
2029
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2030
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2031
  #else
2032
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2033
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2034
+ #endif // NEW_MMA_AVAILABLE
2035
  }
2036
 
2037
  const int ls = bxi->scales[kqsx];
2038
  const float d = bxi->d;
2039
+ #ifdef NEW_MMA_AVAILABLE
2040
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2041
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2042
  #else
2043
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2044
  x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2045
+ #endif // NEW_MMA_AVAILABLE
2046
  }
2047
  }
2048
 
2049
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2050
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2051
 
2052
+ #ifdef NEW_MMA_AVAILABLE
2053
  int * x_qs = (int *) x_tile;
2054
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
2055
  #else
2056
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2057
  int * x_qs = (int *) x_tile;
2058
  float * x_df = (float *) (x_qs + txs.qs);
2059
+ #endif // NEW_MMA_AVAILABLE
2060
 
2061
  const int kqsx = threadIdx.x % (QI3_XXS/2);
2062
 
 
2083
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2084
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2085
 
2086
+ #ifdef NEW_MMA_AVAILABLE
2087
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2088
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2089
  #else
2090
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2091
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2092
+ #endif // NEW_MMA_AVAILABLE
2093
  }
2094
 
2095
  const int ls = aux32 >> 28;
2096
  const float d = bxi->d;
2097
+ #ifdef NEW_MMA_AVAILABLE
2098
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2099
  #else
2100
  x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2101
+ #endif // NEW_MMA_AVAILABLE
2102
  }
2103
  }
2104
 
2105
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2106
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2107
 
2108
+ #ifdef NEW_MMA_AVAILABLE
2109
  int * x_qs = (int *) x_tile;
2110
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
2111
  #else
2112
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2113
  int * x_qs = (int *) x_tile;
2114
  float * x_df = (float *) (x_qs + txs.qs);
2115
+ #endif // NEW_MMA_AVAILABLE
2116
 
2117
  const int kqsx = threadIdx.x % (QI3_S/2);
2118
 
 
2146
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2147
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2148
 
2149
+ #ifdef NEW_MMA_AVAILABLE
2150
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2151
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2152
  #else
2153
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
2154
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
2155
+ #endif // NEW_MMA_AVAILABLE
2156
  }
2157
 
2158
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2159
  const float d = bxi->d;
2160
+ #ifdef NEW_MMA_AVAILABLE
2161
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2162
  #else
2163
  x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
2164
+ #endif // NEW_MMA_AVAILABLE
2165
  }
2166
  }
2167
 
2168
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2169
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2170
 
2171
+ #ifdef NEW_MMA_AVAILABLE
2172
  int * x_qs = (int *) x_tile;
2173
  half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
2174
  #else
2175
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2176
  int * x_qs = (int *) x_tile;
2177
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2178
+ #endif // NEW_MMA_AVAILABLE
2179
 
2180
  const int kqsx = threadIdx.x % QI1_S;
2181
 
 
2201
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2202
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2203
 
2204
+ #ifdef NEW_MMA_AVAILABLE
2205
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2206
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2207
  #else
2208
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
2209
  x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
2210
+ #endif // NEW_MMA_AVAILABLE
2211
  }
2212
 
2213
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2214
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2215
 
2216
+ #ifdef NEW_MMA_AVAILABLE
2217
  x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2218
  #else
2219
  x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2220
+ #endif // NEW_MMA_AVAILABLE
2221
  }
2222
  }
2223
 
2224
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2225
  const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2226
 
2227
+ #ifdef NEW_MMA_AVAILABLE
2228
  int * x_qs = (int *) x_tile;
2229
  float * x_df = (float *) (x_qs + WARP_SIZE*2);
2230
  #else
2231
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2232
  int * x_qs = (int *) x_tile;
2233
  float * x_df = (float *) (x_qs + txs.qs);
2234
+ #endif // NEW_MMA_AVAILABLE
2235
 
2236
  const int kbx = 0; // threadIdx.x / QI4_XS
2237
  const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
 
2249
  const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2250
  const int2 v = get_int_from_table_16(aux_q4);
2251
  const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2252
+ #ifdef NEW_MMA_AVAILABLE
2253
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2254
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2255
  #else
2256
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2257
  x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2258
+ #endif // NEW_MMA_AVAILABLE
2259
  }
2260
 
2261
  #pragma unroll
 
2273
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2274
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2275
 
2276
+ #ifdef NEW_MMA_AVAILABLE
2277
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2278
  #else
2279
  x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2280
+ #endif // NEW_MMA_AVAILABLE
2281
  }
2282
  }
2283
 
 
2310
  static __device__ __forceinline__ void mmq_write_back_mma(
2311
  const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
2312
 
2313
+ typedef mma_C_I16J8<int> mma_C;
2314
 
2315
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
2316
  constexpr int rows_per_warp = 2 * granularity;
2317
  constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
2318
 
2319
  const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
2320
+ #ifdef NEW_MMA_AVAILABLE
2321
  static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
2322
+ #endif // NEW_MMA_AVAILABLE
2323
 
2324
  #pragma unroll
2325
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
 
2508
  int * tile_y = (int *) data_mul_mat_q;
2509
  int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
2510
 
2511
+ #ifdef NEW_MMA_AVAILABLE
2512
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
2513
  constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
2514
  #else
2515
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
2516
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
2517
+ #endif // NEW_MMA_AVAILABLE
2518
 
2519
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2520
 
 
2646
  const int jt = kbc / (blocks_per_ne00*nty);
2647
  const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
2648
 
2649
+ constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
2650
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2651
  (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
2652
  it, jt, kb0_start, kb0_stop);
 
2752
  static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
2753
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
2754
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
2755
+ const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
2756
  const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
2757
  return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
2758
  }
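
The mmq.cuh changes above are all keyed off the NEW_MMA_AVAILABLE macro: when it is defined, the tiles use the padded MMA layout and the mma vec_dot/write_back kernels; otherwise the dp4a layout and kernels remain the fallback. A minimal standalone sketch of that compile-time dispatch pattern, for illustration only (do_mma and do_dp4a are placeholders, not the real ggml functions):

#include <cstdio>

// Placeholder implementations standing in for the mma and dp4a code paths.
static void do_mma()  { std::puts("mma path: tensor-core tiles (NEW_MMA_AVAILABLE)"); }
static void do_dp4a() { std::puts("dp4a path: integer dot-product fallback"); }

typedef void (*path_t)();

int main() {
    // Same shape as the vec_dot/write_back selection in mul_mat_q_process_tile:
    // the path is fixed at compile time by the availability macro.
#ifdef NEW_MMA_AVAILABLE
    constexpr path_t path = do_mma;
#else
    constexpr path_t path = do_dp4a;
#endif // NEW_MMA_AVAILABLE
    path();
    return 0;
}
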
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 16);
6
+ DECL_FATTN_MMA_F16_CASE(80, 16);
7
+ DECL_FATTN_MMA_F16_CASE(96, 16);
8
+ DECL_FATTN_MMA_F16_CASE(112, 16);
9
+ DECL_FATTN_MMA_F16_CASE(128, 16);
10
+ DECL_FATTN_MMA_F16_CASE(256, 16);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 32);
6
+ DECL_FATTN_MMA_F16_CASE(80, 32);
7
+ DECL_FATTN_MMA_F16_CASE(96, 32);
8
+ DECL_FATTN_MMA_F16_CASE(112, 32);
9
+ DECL_FATTN_MMA_F16_CASE(128, 32);
10
+ DECL_FATTN_MMA_F16_CASE(256, 32);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 64);
6
+ DECL_FATTN_MMA_F16_CASE(80, 64);
7
+ DECL_FATTN_MMA_F16_CASE(96, 64);
8
+ DECL_FATTN_MMA_F16_CASE(112, 64);
9
+ DECL_FATTN_MMA_F16_CASE(128, 64);
10
+ DECL_FATTN_MMA_F16_CASE(256, 64);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 8);
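
The four instance files above split the new FlashAttention MMA template instantiations across translation units, one file per cols_per_block value, so the heavy template bodies can compile in parallel. A hypothetical sketch of the general pattern such a per-case declaration macro wraps (fattn_mma_f16_case and DECL_FATTN_CASE below are illustrative placeholders, not ggml's actual DECL_FATTN_MMA_F16_CASE definition):

#include <cstdio>

// A templated launcher whose body would normally live in a shared header.
template <int head_size, int cols_per_block>
void fattn_mma_f16_case() {
    std::printf("fattn mma f16: head_size=%d, cols_per_block=%d\n", head_size, cols_per_block);
}

// Each generated .cu file forces one explicit instantiation per (head_size, cols_per_block).
#define DECL_FATTN_CASE(head_size, cols_per_block) \
    template void fattn_mma_f16_case<head_size, cols_per_block>()

DECL_FATTN_CASE(64, 16);
DECL_FATTN_CASE(128, 16);

int main() {
    fattn_mma_f16_case<64, 16>();
    fattn_mma_f16_case<128, 16>();
    return 0;
}
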
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu DELETED
@@ -1,10 +0,0 @@
1
- // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
-
3
- #include "../fattn-wmma-f16.cuh"
4
-
5
- DECL_FATTN_WMMA_F16_CASE(64, 16, float);
6
- DECL_FATTN_WMMA_F16_CASE(80, 16, float);
7
- DECL_FATTN_WMMA_F16_CASE(96, 16, float);
8
- DECL_FATTN_WMMA_F16_CASE(112, 16, float);
9
- DECL_FATTN_WMMA_F16_CASE(128, 16, float);
10
- DECL_FATTN_WMMA_F16_CASE(256, 16, float);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu DELETED
@@ -1,9 +0,0 @@
1
- // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
-
3
- #include "../fattn-wmma-f16.cuh"
4
-
5
- DECL_FATTN_WMMA_F16_CASE(64, 32, float);
6
- DECL_FATTN_WMMA_F16_CASE(80, 32, float);
7
- DECL_FATTN_WMMA_F16_CASE(96, 32, float);
8
- DECL_FATTN_WMMA_F16_CASE(112, 32, float);
9
- DECL_FATTN_WMMA_F16_CASE(128, 32, float);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu DELETED
@@ -1,10 +0,0 @@
1
- // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
-
3
- #include "../fattn-wmma-f16.cuh"
4
-
5
- DECL_FATTN_WMMA_F16_CASE(64, 16, half);
6
- DECL_FATTN_WMMA_F16_CASE(80, 16, half);
7
- DECL_FATTN_WMMA_F16_CASE(96, 16, half);
8
- DECL_FATTN_WMMA_F16_CASE(112, 16, half);
9
- DECL_FATTN_WMMA_F16_CASE(128, 16, half);
10
- DECL_FATTN_WMMA_F16_CASE(256, 16, half);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu DELETED
@@ -1,10 +0,0 @@
1
- // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
-
3
- #include "../fattn-wmma-f16.cuh"
4
-
5
- DECL_FATTN_WMMA_F16_CASE(64, 32, half);
6
- DECL_FATTN_WMMA_F16_CASE(80, 32, half);
7
- DECL_FATTN_WMMA_F16_CASE(96, 32, half);
8
- DECL_FATTN_WMMA_F16_CASE(112, 32, half);
9
- DECL_FATTN_WMMA_F16_CASE(128, 32, half);
10
- DECL_FATTN_WMMA_F16_CASE(256, 32, half);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu DELETED
@@ -1,8 +0,0 @@
1
- // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
-
3
- #include "../fattn-wmma-f16.cuh"
4
-
5
- DECL_FATTN_WMMA_F16_CASE(64, 8, half);
6
- DECL_FATTN_WMMA_F16_CASE(96, 8, half);
7
- DECL_FATTN_WMMA_F16_CASE(128, 8, half);
8
- DECL_FATTN_WMMA_F16_CASE(256, 8, half);
ggml/src/ggml-cuda/template-instances/generate_cu_files.py CHANGED
@@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p
12
  DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
13
  """
14
 
15
- SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
16
 
17
- #include "../fattn-wmma-f16.cuh"
18
 
19
  """
20
 
21
- SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
22
 
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,20 +57,12 @@ for vkq_size in [16, 32]:
57
  with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
58
  f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
59
 
60
- for kq_acc_t in ["half", "float"]:
61
- for cols_per_block in [8, 16, 32]:
62
- if kq_acc_t == "float" and cols_per_block == 8:
63
- continue
64
 
65
- with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f:
66
- f.write(SOURCE_FATTN_WMMA_START)
67
-
68
- for head_size in [64, 80, 96, 112, 128, 256]:
69
- if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
70
- continue
71
- if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
72
- continue
73
- f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
74
 
75
  for type in TYPES_MMQ:
76
  with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
 
12
  DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
13
  """
14
 
15
+ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
16
 
17
+ #include "../fattn-mma-f16.cuh"
18
 
19
  """
20
 
21
+ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
22
 
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
 
57
  with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
58
  f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
59
 
60
+ for cols_per_block in [8, 16, 32, 64]:
61
+ with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
62
+ f.write(SOURCE_FATTN_MMA_START)
 
63
 
64
+ for head_size in [64, 80, 96, 112, 128, 256]:
65
+ f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
 
 
 
 
 
 
 
66
 
67
  for type in TYPES_MMQ:
68
  with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
ggml/src/ggml-cuda/vendors/hip.h CHANGED
@@ -25,6 +25,7 @@
25
  #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
26
  #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
27
  #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 
28
  #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
29
  #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
30
  #define cublasCreate hipblasCreate
 
25
  #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
26
  #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
27
  #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
28
+ #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
29
  #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
30
  #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
31
  #define cublasCreate hipblasCreate
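
The added __shfl_sync mapping complements the existing __shfl_xor_sync one, so warp-shuffle code written against the CUDA *_sync intrinsics also builds under HIP. A minimal CUDA sketch (not ggml code) of the kind of warp-level usage these shims cover, assuming a 32-lane warp:

#include <cstdio>

__global__ void warp_sum(const int * in, int * out) {
    int v = in[threadIdx.x];

    // Butterfly reduction over the warp with __shfl_xor_sync.
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xFFFFFFFF, v, offset, 32);
    }

    // Broadcast lane 0's copy of the result with __shfl_sync
    // (after the butterfly all lanes already agree; shown for illustration).
    out[threadIdx.x] = __shfl_sync(0xFFFFFFFF, v, 0, 32);
}

int main() {
    int h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = i;

    int *d_in, *d_out;
    cudaMalloc(&d_in,  sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

    warp_sum<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

    std::printf("warp sum = %d (expected %d)\n", h_out[0], 31*32/2);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
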
ggml/src/ggml-hip/CMakeLists.txt CHANGED
@@ -50,7 +50,7 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
50
  list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
51
 
52
  file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
53
- file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
54
  list(APPEND GGML_SOURCES_ROCM ${SRCS})
55
  file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
56
  list(APPEND GGML_SOURCES_ROCM ${SRCS})
 
50
  list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
51
 
52
  file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
53
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
54
  list(APPEND GGML_SOURCES_ROCM ${SRCS})
55
  file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
56
  list(APPEND GGML_SOURCES_ROCM ${SRCS})
ggml/src/ggml-musa/CMakeLists.txt CHANGED
@@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND)
29
  list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
30
 
31
  file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
32
- file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
33
  list(APPEND GGML_SOURCES_MUSA ${SRCS})
34
  file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
35
  list(APPEND GGML_SOURCES_MUSA ${SRCS})
 
29
  list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
30
 
31
  file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
32
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
33
  list(APPEND GGML_SOURCES_MUSA ${SRCS})
34
  file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
35
  list(APPEND GGML_SOURCES_MUSA ${SRCS})