JohannesGaessler committed
Commit 1d24833 · Parent(s): 67ec576

CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (llama/15131)
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -233,9 +233,13 @@ typedef float2 dfloat2;
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define NEW_MMA_AVAILABLE
+#define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define AMPERE_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) {
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static bool new_mma_available(const int cc) {
+static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool ampere_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static bool cp_async_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
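For readers less familiar with ggml's CUDA feature gating, a minimal sketch (not part of the commit) of how the renamed macro and the new helper are meant to be used together: TURING_MMA_AVAILABLE/AMPERE_MMA_AVAILABLE gate device code at compile time per target architecture, while turing_mma_available(cc)/ampere_mma_available(cc) gate host-side kernel selection at run time. The kernel and launcher names below are hypothetical.

#include "common.cuh"

// Hypothetical kernel: the body only compiles to real device code on Ampere or newer.
static __global__ void example_kernel(float * dst) {
#ifdef AMPERE_MMA_AVAILABLE
    dst[threadIdx.x] = 1.0f; // a real kernel would use tf32/bf16 mma instructions here
#else
    GGML_UNUSED(dst);
    NO_DEVICE_CODE; // stub for architectures without the feature
#endif // AMPERE_MMA_AVAILABLE
}

// Hypothetical launcher: the host checks the device's compute capability before launching.
static void example_launch(float * dst, const int cc, cudaStream_t stream) {
    if (ampere_mma_available(cc)) {
        example_kernel<<<1, WARP_SIZE, 0, stream>>>(dst);
    }
}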
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     float * const __restrict__ KQ_max,
     float * const __restrict__ KQ_rowsum,
     const int kb0) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
 #ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     const int jt,
     const int kb0_start,
     const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16(
     const int32_t nb21, const int32_t nb22, const int64_t nb23,
     const int32_t ne31, const int32_t ne32, const int32_t ne33,
     const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
@@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
+    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
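The renamed call sites above keep the pre-existing policy; for clarity, a condensed restatement of the two conditions they gate (illustration only, assuming common.cuh is included and the surrounding variables are computed as in the hunks above):

#include "common.cuh"

// Batch-size-1 heuristic: prefer the MMA FlashAttention kernel over the vector kernel
// only when a Turing-style mma path exists and the GQA-specific optimizations apply.
static bool prefer_mma_for_batch_size_1(const int cc, const bool gqa_opt_applies,
        const bool mma_needs_data_conversion, const bool mma_faster_for_rtx4000) {
    return turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
}

// Volta fallback: FP16 tensor cores exist, but the Turing-style mma path does not,
// so the old WMMA FlashAttention kernel is used instead.
static bool use_wmma_fallback(const int cc) {
    return fp16_mma_available(cc) && !turing_mma_available(cc);
}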
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -22,8 +22,9 @@
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
@@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
-    bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    bool use_mul_mat_f = !ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             }
 
             const int cc = ggml_cuda_info().devices[id].cc;
+            const int warp_size = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
+        const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }
 
@@ -2048,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
     //TODO update for generic tensor parallelism
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
 
-    if (!split && use_mul_mat_vec) {
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
@@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
@@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         if (ggml_is_quantized(src0->type)) {
             ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
         } else {
-            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+            ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
         }
         return;
     }
@@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc)) {
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
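Taken together, these hunks slot the new float GEMM between the float vector kernel and the quantized paths: for non-split buffers the order of preference is now mmvf, then mmf, then mmvq, then MMQ, with the batched/regular cuBLAS paths as the remaining fallbacks. A toy, host-only restatement of that precedence (illustration only, not the verbatim dispatch code):

#include <cstdio>

// The flags stand in for the use_mul_mat_* booleans computed in ggml_cuda_mul_mat.
static const char * pick_mul_mat_path(bool split, bool use_mul_mat_vec_f, bool use_mul_mat_f,
                                      bool use_mul_mat_vec_q, bool use_mul_mat_q) {
    if (!split && use_mul_mat_vec_f) { return "mmvf: custom float vector kernel";          }
    if (!split && use_mul_mat_f)     { return "mmf: new FP32/FP16/BF16 GEMM (ne11 <= 16)"; }
    if (!split && use_mul_mat_vec_q) { return "mmvq: quantized vector kernel";             }
    if (!split && use_mul_mat_q)     { return "mmq: quantized mul_mat kernel";             }
    return "batched/regular cuBLAS or split-buffer paths";
}

int main() {
    // Example: suppose only the new GEMM's eligibility checks passed for this shape.
    std::printf("%s\n", pick_mul_mat_path(false, false, true, false, false));
    return 0;
}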
ggml/src/ggml-cuda/mma.cuh CHANGED
@@ -23,13 +23,13 @@
23
  static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
24
  int ret = 0;
25
 
26
- #ifdef NEW_MMA_AVAILABLE
27
  asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
28
  : "=r"(ret) : "r"(x));
29
  #else
30
  GGML_UNUSED(x);
31
  NO_DEVICE_CODE;
32
- #endif // defined(NEW_MMA_AVAILABLE)
33
  return ret;
34
  }
35
 
@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
167
  }
168
  };
169
 
170
  template <int I, int J>
171
  static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
172
  tile<I, J/2, half2> ret;
@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
209
  template <typename T>
210
  static __device__ __forceinline__ void load_ldmatrix(
211
  tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
212
- #ifdef NEW_MMA_AVAILABLE
213
  int * xi = (int *) t.x;
214
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
215
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
217
  : "l"(xs));
218
  #else
219
  load_generic(t, xs0, stride);
220
- #endif // NEW_MMA_AVAILABLE
221
  }
222
 
223
  template <typename T>
224
  static __device__ __forceinline__ void load_ldmatrix(
225
  tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
226
- #ifdef NEW_MMA_AVAILABLE
227
  int * xi = (int *) t.x;
228
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
229
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
232
  #else
233
  load_generic(xs0, stride);
234
  GGML_UNUSED(t);
235
- #endif // NEW_MMA_AVAILABLE
236
  }
237
 
238
  template <typename T>
239
  static __device__ __forceinline__ void load_ldmatrix(
240
  tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
241
- #if defined(NEW_MMA_AVAILABLE)
242
  int * xi = (int * ) t.x;
243
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
244
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -246,13 +278,13 @@ namespace ggml_cuda_mma {
246
  : "l"(xs));
247
  #else
248
  load_generic(t, xs0, stride);
249
- #endif // NEW_MMA_AVAILABLE
250
  }
251
 
252
  template <typename T>
253
  static __device__ __forceinline__ void load_ldmatrix_trans(
254
  tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
255
- #ifdef NEW_MMA_AVAILABLE
256
  int * xi = (int * ) t.x;
257
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
258
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
@@ -263,12 +295,12 @@ namespace ggml_cuda_mma {
263
  GGML_UNUSED(xs0);
264
  GGML_UNUSED(stride);
265
  NO_DEVICE_CODE;
266
- #endif // NEW_MMA_AVAILABLE
267
  }
268
 
269
  static __device__ __forceinline__ void mma(
270
  tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
271
- #ifdef NEW_MMA_AVAILABLE
272
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
273
  asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
274
  : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
287
  GGML_UNUSED(A);
288
  GGML_UNUSED(B);
289
  NO_DEVICE_CODE;
290
- #endif // NEW_MMA_AVAILABLE
291
  }
292
 
293
  static __device__ __forceinline__ void mma(
294
  tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
295
- #ifdef NEW_MMA_AVAILABLE
296
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
297
  asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
298
  : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -317,12 +349,12 @@ namespace ggml_cuda_mma {
317
  GGML_UNUSED(A);
318
  GGML_UNUSED(B);
319
  NO_DEVICE_CODE;
320
- #endif // NEW_MMA_AVAILABLE
321
  }
322
 
323
  static __device__ __forceinline__ void mma(
324
  tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
325
- #ifdef NEW_MMA_AVAILABLE
326
  const int * Axi = (const int *) A.x;
327
  const int * Bxi = (const int *) B.x;
328
  int * Dxi = (int *) D.x;
@@ -344,12 +376,12 @@ namespace ggml_cuda_mma {
344
  GGML_UNUSED(A);
345
  GGML_UNUSED(B);
346
  NO_DEVICE_CODE;
347
- #endif // NEW_MMA_AVAILABLE
348
  }
349
 
350
  static __device__ __forceinline__ void mma(
351
  tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
352
- #ifdef NEW_MMA_AVAILABLE
353
  const int * Axi = (const int *) A.x;
354
  const int * Bxi = (const int *) B.x;
355
  int * Dxi = (int *) D.x;
@@ -380,12 +412,29 @@ namespace ggml_cuda_mma {
380
  GGML_UNUSED(A);
381
  GGML_UNUSED(B);
382
  NO_DEVICE_CODE;
383
- #endif // NEW_MMA_AVAILABLE
384
  }
385
 
386
  static __device__ __forceinline__ void mma(
387
  tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
388
- #ifdef NEW_MMA_AVAILABLE
389
  const int * Axi = (const int *) A.x;
390
  const int * Bxi = (const int *) B.x;
391
  int * Dxi = (int *) D.x;
@@ -407,12 +456,29 @@ namespace ggml_cuda_mma {
407
  GGML_UNUSED(A);
408
  GGML_UNUSED(B);
409
  NO_DEVICE_CODE;
410
- #endif // NEW_MMA_AVAILABLE
411
  }
412
 
413
  static __device__ __forceinline__ void mma(
414
  tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
415
- #ifdef NEW_MMA_AVAILABLE
416
  const int * Axi = (const int *) A.x;
417
  const int * Bxi = (const int *) B.x;
418
  int * Dxi = (int *) D.x;
@@ -443,7 +509,7 @@ namespace ggml_cuda_mma {
443
  GGML_UNUSED(A);
444
  GGML_UNUSED(B);
445
  NO_DEVICE_CODE;
446
- #endif // NEW_MMA_AVAILABLE
447
  }
448
 
449
  static __device__ __forceinline__ void mma(
 
23
  static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
24
  int ret = 0;
25
 
26
+ #ifdef TURING_MMA_AVAILABLE
27
  asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
28
  : "=r"(ret) : "r"(x));
29
  #else
30
  GGML_UNUSED(x);
31
  NO_DEVICE_CODE;
32
+ #endif // defined(TURING_MMA_AVAILABLE)
33
  return ret;
34
  }
35
 
 
167
  }
168
  };
169
 
170
+ template <int I_, int J_>
171
+ struct tile<I_, J_, nv_bfloat162> {
172
+ static constexpr int I = I_;
173
+ static constexpr int J = J_;
174
+ static constexpr int ne = I * J / WARP_SIZE;
175
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
176
+
177
+ static __device__ __forceinline__ int get_i(const int l) {
178
+ if constexpr (I == 8 && J == 8) {
179
+ return threadIdx.x / 4;
180
+ } else if constexpr (I == 16 && J == 4) {
181
+ return l * 8 + threadIdx.x / 4;
182
+ } else if constexpr (I == 16 && J == 8) {
183
+ return (l % 2) * 8 + threadIdx.x / 4;
184
+ } else {
185
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
186
+ }
187
+ }
188
+
189
+ static __device__ __forceinline__ int get_j(const int l) {
190
+ if constexpr (I == 8 && J == 8) {
191
+ return l * 4 + threadIdx.x % 4;
192
+ } else if constexpr (I == 16 && J == 4) {
193
+ return threadIdx.x % 4;
194
+ } else if constexpr (I == 16 && J == 8) {
195
+ return (l / 2) * 4 + threadIdx.x % 4;
196
+ } else {
197
+ static_assert(I == -1 && J == -1, "template specialization not implemented");
198
+ }
199
+ }
200
+ };
201
+
202
  template <int I, int J>
203
  static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
204
  tile<I, J/2, half2> ret;
 
241
  template <typename T>
242
  static __device__ __forceinline__ void load_ldmatrix(
243
  tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
244
+ #ifdef TURING_MMA_AVAILABLE
245
  int * xi = (int *) t.x;
246
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
247
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
 
249
  : "l"(xs));
250
  #else
251
  load_generic(t, xs0, stride);
252
+ #endif // TURING_MMA_AVAILABLE
253
  }
254
 
255
  template <typename T>
256
  static __device__ __forceinline__ void load_ldmatrix(
257
  tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
258
+ #ifdef TURING_MMA_AVAILABLE
259
  int * xi = (int *) t.x;
260
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
261
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
 
264
  #else
265
  load_generic(xs0, stride);
266
  GGML_UNUSED(t);
267
+ #endif // TURING_MMA_AVAILABLE
268
  }
269
 
270
  template <typename T>
271
  static __device__ __forceinline__ void load_ldmatrix(
272
  tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
273
+ #if defined(TURING_MMA_AVAILABLE)
274
  int * xi = (int * ) t.x;
275
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
276
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
 
278
  : "l"(xs));
279
  #else
280
  load_generic(t, xs0, stride);
281
+ #endif // TURING_MMA_AVAILABLE
282
  }
283
 
284
  template <typename T>
285
  static __device__ __forceinline__ void load_ldmatrix_trans(
286
  tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
287
+ #ifdef TURING_MMA_AVAILABLE
288
  int * xi = (int * ) t.x;
289
  const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
290
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
 
295
  GGML_UNUSED(xs0);
296
  GGML_UNUSED(stride);
297
  NO_DEVICE_CODE;
298
+ #endif // TURING_MMA_AVAILABLE
299
  }
300
 
301
  static __device__ __forceinline__ void mma(
302
  tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
303
+ #ifdef TURING_MMA_AVAILABLE
304
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
305
  asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
306
  : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
 
319
  GGML_UNUSED(A);
320
  GGML_UNUSED(B);
321
  NO_DEVICE_CODE;
322
+ #endif // TURING_MMA_AVAILABLE
323
  }
324
 
325
  static __device__ __forceinline__ void mma(
326
  tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
327
+ #ifdef TURING_MMA_AVAILABLE
328
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
329
  asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
330
  : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
 
349
  GGML_UNUSED(A);
350
  GGML_UNUSED(B);
351
  NO_DEVICE_CODE;
352
+ #endif // TURING_MMA_AVAILABLE
353
  }
354
 
355
  static __device__ __forceinline__ void mma(
356
  tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
357
+ #ifdef TURING_MMA_AVAILABLE
358
  const int * Axi = (const int *) A.x;
359
  const int * Bxi = (const int *) B.x;
360
  int * Dxi = (int *) D.x;
 
376
  GGML_UNUSED(A);
377
  GGML_UNUSED(B);
378
  NO_DEVICE_CODE;
379
+ #endif // TURING_MMA_AVAILABLE
380
  }
381
 
382
  static __device__ __forceinline__ void mma(
383
  tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
384
+ #ifdef TURING_MMA_AVAILABLE
385
  const int * Axi = (const int *) A.x;
386
  const int * Bxi = (const int *) B.x;
387
  int * Dxi = (int *) D.x;
 
412
  GGML_UNUSED(A);
413
  GGML_UNUSED(B);
414
  NO_DEVICE_CODE;
415
+ #endif // TURING_MMA_AVAILABLE
416
+ }
417
+
418
+ static __device__ __forceinline__ void mma(
419
+ tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
420
+ #ifdef AMPERE_MMA_AVAILABLE
421
+ const int * Axi = (const int *) A.x;
422
+ const int * Bxi = (const int *) B.x;
423
+ int * Dxi = (int *) D.x;
424
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
425
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
426
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
427
+ #else
428
+ GGML_UNUSED(D);
429
+ GGML_UNUSED(A);
430
+ GGML_UNUSED(B);
431
+ NO_DEVICE_CODE;
432
+ #endif // AMPERE_MMA_AVAILABLE
433
  }
434
 
435
  static __device__ __forceinline__ void mma(
436
  tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
437
+ #ifdef TURING_MMA_AVAILABLE
438
  const int * Axi = (const int *) A.x;
439
  const int * Bxi = (const int *) B.x;
440
  int * Dxi = (int *) D.x;
 
456
  GGML_UNUSED(A);
457
  GGML_UNUSED(B);
458
  NO_DEVICE_CODE;
459
+ #endif // TURING_MMA_AVAILABLE
460
+ }
461
+
462
+ static __device__ __forceinline__ void mma(
463
+ tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
464
+ #ifdef AMPERE_MMA_AVAILABLE
465
+ const int * Axi = (const int *) A.x;
466
+ const int * Bxi = (const int *) B.x;
467
+ int * Dxi = (int *) D.x;
468
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
469
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
470
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
471
+ #else
472
+ GGML_UNUSED(D);
473
+ GGML_UNUSED(A);
474
+ GGML_UNUSED(B);
475
+ NO_DEVICE_CODE;
476
+ #endif // AMPERE_MMA_AVAILABLE
477
  }
478
 
479
  static __device__ __forceinline__ void mma(
480
  tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
481
+ #ifdef TURING_MMA_AVAILABLE
482
  const int * Axi = (const int *) A.x;
483
  const int * Bxi = (const int *) B.x;
484
  int * Dxi = (int *) D.x;
 
509
  GGML_UNUSED(A);
510
  GGML_UNUSED(B);
511
  NO_DEVICE_CODE;
512
+ #endif // TURING_MMA_AVAILABLE
513
  }
514
 
515
  static __device__ __forceinline__ void mma(
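The practical effect of the two new overloads added here is that a 16x8 FP32 accumulator tile can now be fed directly from FP32 inputs (via tf32, PTX shape m16n8k8) or BF16 inputs (PTX shape m16n8k16), alongside the existing FP16 path. A minimal device-side sketch of how a caller reaches them through overload resolution (illustration only; assumes mma.cuh is included and the code runs where AMPERE_MMA_AVAILABLE is defined):

#include "mma.cuh"

using namespace ggml_cuda_mma;

// Hypothetical device function: which mma overload runs is decided purely by the tile types.
static __device__ void example_accumulate(
        const tile<16, 8, float>        & A_f32,  const tile<8, 8, float>        & B_f32,
        const tile<16, 8, nv_bfloat162> & A_bf16, const tile<8, 8, nv_bfloat162> & B_bf16,
        tile<16, 8, float> & D) {
    mma(D, A_f32,  B_f32);   // mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
    mma(D, A_bf16, B_bf16);  // mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32
}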
ggml/src/ggml-cuda/mmf.cu ADDED
@@ -0,0 +1,431 @@
1
+ #include "ggml.h"
2
+ #include "common.cuh"
3
+ #include "mma.cuh"
4
+ #include "mmf.cuh"
5
+
6
+ using namespace ggml_cuda_mma;
7
+
8
+ #define MMF_ROWS_PER_BLOCK 32
9
+
10
+ template <typename T, int rows_per_block, int cols_per_block, int nwarps>
11
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
12
+ static __global__ void mul_mat_f(
13
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
14
+ const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
15
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
16
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
17
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
18
+ typedef tile<16, 8, T> tile_A;
19
+ typedef tile< 8, 8, T> tile_B;
20
+ typedef tile<16, 8, float> tile_C;
21
+
22
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
23
+ constexpr int tile_k_padded = warp_size + 4;
24
+ constexpr int ntA = rows_per_block / tile_A::I;
25
+ constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
26
+
27
+ const int row0 = blockIdx.x * rows_per_block;
28
+ const int channel_dst = blockIdx.y;
29
+ const int channel_x = channel_dst / channel_ratio;
30
+ const int channel_y = channel_dst;
31
+ const int sample_dst = blockIdx.z;
32
+ const int sample_x = sample_dst / sample_ratio;
33
+ const int sample_y = sample_dst;
34
+
35
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
36
+ y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
37
+ dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
38
+
39
+ const float2 * y2 = (const float2 *) y;
40
+
41
+ extern __shared__ char data_mmv[];
42
+
43
+ tile_C C[ntA][ntB];
44
+
45
+ T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
46
+
47
+ for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
48
+ tile_A A[ntA][warp_size / tile_A::J];
49
+ #pragma unroll
50
+ for (int itA = 0; itA < ntA; ++itA) {
51
+ #pragma unroll
52
+ for (int i = 0; i < tile_A::I; ++i) {
53
+ tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
54
+ }
55
+ #pragma unroll
56
+ for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
57
+ load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
58
+ }
59
+ }
60
+
61
+ #pragma unroll
62
+ for (int itB = 0; itB < ntB; ++itB) {
63
+ if constexpr (std::is_same_v<T, float>) {
64
+ #pragma unroll
65
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
66
+ const int j = j0 + itB*tile_B::I;
67
+
68
+ tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
69
+ }
70
+ } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
71
+ #pragma unroll
72
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
73
+ const int j = j0 + itB*tile_B::I;
74
+
75
+ const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
76
+ tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
77
+ }
78
+ } else {
79
+ static_assert(std::is_same_v<T, void>, "unsupported type");
80
+ }
81
+ #pragma unroll
82
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
83
+ tile_B B;
84
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
85
+ #pragma unroll
86
+ for (int itA = 0; itA < ntA; ++itA) {
87
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
88
+ }
89
+ }
90
+ }
91
+ }
92
+
93
+ float * buf_iw = (float *) data_mmv;
94
+ constexpr int kiw = nwarps*rows_per_block + 4;
95
+
96
+ if (nwarps > 1) {
97
+ __syncthreads();
98
+ }
99
+ #pragma unroll
100
+ for (int itB = 0; itB < ntB; ++itB) {
101
+ #pragma unroll
102
+ for (int itA = 0; itA < ntA; ++itA) {
103
+ #pragma unroll
104
+ for (int l = 0; l < tile_C::ne; ++l) {
105
+ const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
106
+ const int j = itB*tile_C::J + tile_C::get_j(l);
107
+ buf_iw[j*kiw + i] = C[itA][itB].x[l];
108
+ }
109
+ }
110
+ }
111
+
112
+ if (nwarps > 1) {
113
+ __syncthreads();
114
+ }
115
+
116
+ #pragma unroll
117
+ for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
118
+ const int j = j0 + threadIdx.y;
119
+
120
+ if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
121
+ return;
122
+ }
123
+
124
+ float sum = 0.0f;
125
+ static_assert(rows_per_block == warp_size, "need loop/check");
126
+ #pragma unroll
127
+ for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
128
+ const int i = i0 + threadIdx.x;
129
+
130
+ sum += buf_iw[j*kiw + i];
131
+ }
132
+ dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
133
+ }
134
+ #else
135
+ NO_DEVICE_CODE;
136
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
137
+ GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
138
+ GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
139
+ GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
140
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
141
+ }
142
+
143
+ template <typename T, int cols_per_block>
144
+ static void mul_mat_f_cuda(
145
+ const T * x, const float * y, const int32_t * ids, float * dst,
146
+ const int64_t ncols_x, const int64_t nrows_x,
147
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
148
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
149
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
150
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
151
+ cudaStream_t stream) {
152
+ typedef tile<16, 8, T> tile_A;
153
+ typedef tile< 8, 8, T> tile_B;
154
+ typedef tile<16, 8, float> tile_C;
155
+
156
+ GGML_ASSERT(!ids && "mul_mat_id not implemented");
157
+
158
+ GGML_ASSERT(ncols_x % 2 == 0);
159
+ GGML_ASSERT(stride_row % 2 == 0);
160
+ GGML_ASSERT(stride_col_y % 2 == 0);
161
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
162
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
163
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
164
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
165
+
166
+ const int device = ggml_cuda_get_device();
167
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
168
+
169
+ int64_t nwarps_best = 1;
170
+ int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
171
+ int64_t max_block_size = 256;
172
+ for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
173
+ const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
174
+ if (niter < niter_best) {
175
+ niter_best = niter;
176
+ nwarps_best = nwarps;
177
+ }
178
+ }
179
+
180
+ constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
181
+ const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
182
+ const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
183
+ const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
184
+ const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
185
+ const dim3 block_dims(warp_size, nwarps_best, 1);
186
+ switch (nwarps_best) {
187
+ case 1: {
188
+ mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
189
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
190
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
191
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
192
+ } break;
193
+ case 2: {
194
+ mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
195
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
196
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
197
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
198
+ } break;
199
+ case 3: {
200
+ mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
201
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
202
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
203
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
204
+ } break;
205
+ case 4: {
206
+ mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
207
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
208
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
209
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
210
+ } break;
211
+ case 5: {
212
+ mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
213
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
214
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
215
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
216
+ } break;
217
+ case 6: {
218
+ mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
219
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
220
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
221
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
222
+ } break;
223
+ case 7: {
224
+ mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
225
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
226
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
227
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
228
+ } break;
229
+ case 8: {
230
+ mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
231
+ (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
232
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
233
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
234
+ } break;
235
+ default: {
236
+ GGML_ABORT("fatal error");
237
+ } break;
238
+ }
239
+ }
240
+
241
+ template <typename T>
242
+ static void mul_mat_f_switch_cols_per_block(
243
+ const T * x, const float * y, const int32_t * ids, float * dst,
244
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
245
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
246
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
247
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
248
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
249
+ cudaStream_t stream) {
250
+ switch (ncols_dst) {
251
+ case 1: {
252
+ mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
253
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
254
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
255
+ } break;
256
+ case 2: {
257
+ mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
258
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
259
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
260
+ } break;
261
+ case 3: {
262
+ mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
263
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
264
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
265
+ } break;
266
+ case 4: {
267
+ mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
268
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
269
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
270
+ } break;
271
+ case 5: {
272
+ mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
273
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
274
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
275
+ } break;
276
+ case 6: {
277
+ mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
278
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
279
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
280
+ } break;
281
+ case 7: {
282
+ mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
283
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
284
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
285
+ } break;
286
+ case 8: {
287
+ mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
288
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
289
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
290
+ } break;
291
+ case 9: {
292
+ mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
293
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
294
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
295
+ } break;
296
+ case 10: {
297
+ mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
298
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
299
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
300
+ } break;
301
+ case 11: {
302
+ mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
303
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
304
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
305
+ } break;
306
+ case 12: {
307
+ mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
308
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
309
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
310
+ } break;
311
+ case 13: {
312
+ mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
313
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
314
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
315
+ } break;
316
+ case 14: {
317
+ mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
318
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
319
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
320
+ } break;
321
+ case 15: {
322
+ mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
323
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
324
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
325
+ } break;
326
+ case 16: {
327
+ mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
328
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
329
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
330
+ } break;
331
+ default: {
332
+ GGML_ABORT("fatal error");
333
+ } break;
334
+ }
335
+ }
336
+
337
+ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
338
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
339
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
340
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
341
+
342
+ GGML_TENSOR_BINARY_OP_LOCALS;
343
+
344
+ const size_t ts_src0 = ggml_type_size(src0->type);
345
+ const size_t ts_src1 = ggml_type_size(src1->type);
346
+ const size_t ts_dst = ggml_type_size(dst->type);
347
+
348
+ GGML_ASSERT(ne13 == ne3);
349
+
350
+ GGML_ASSERT( nb00 == ts_src0);
351
+ GGML_ASSERT( nb10 == ts_src1);
352
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
353
+ GGML_ASSERT( nb0 == ts_dst);
354
+
355
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
356
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
357
+
358
+ const float * src1_d = (const float *) src1->data;
359
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
360
+ float * dst_d = (float *) dst->data;
361
+
362
+ const int64_t s01 = src0->nb[1] / ts_src0;
363
+ const int64_t s11 = src1->nb[1] / ts_src1;
364
+ const int64_t s1 = dst->nb[1] / ts_dst;
365
+ const int64_t s02 = src0->nb[2] / ts_src0;
366
+ const int64_t s12 = src1->nb[2] / ts_src1;
367
+ const int64_t s2 = dst->nb[2] / ts_dst;
368
+ const int64_t s03 = src0->nb[3] / ts_src0;
369
+ const int64_t s13 = src1->nb[3] / ts_src1;
370
+ const int64_t s3 = dst->nb[3] / ts_dst;
371
+
372
+ // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
373
+ const int64_t ncols_dst = ids ? ne2 : ne1;
374
+ const int64_t nchannels_y = ids ? ne11 : ne12;
375
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
376
+ const int64_t stride_channel_dst = ids ? s1 : s2;
377
+ const int64_t stride_channel_y = ids ? s11 : s12;
378
+
379
+ GGML_ASSERT(!ids || ncols_dst == 1);
380
+
381
+ switch (src0->type) {
382
+ case GGML_TYPE_F32: {
383
+ const float * src0_d = (const float *) src0->data;
384
+ constexpr int vals_per_T = 1;
385
+ mul_mat_f_switch_cols_per_block(
386
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
387
+ ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
388
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
389
+ } break;
390
+ case GGML_TYPE_F16: {
391
+ const half2 * src0_d = (const half2 *) src0->data;
392
+ constexpr int vals_per_T = 2;
393
+ mul_mat_f_switch_cols_per_block(
394
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
395
+ ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
396
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
397
+ } break;
398
+ case GGML_TYPE_BF16: {
399
+ const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
400
+ constexpr int vals_per_T = 2;
401
+ mul_mat_f_switch_cols_per_block(
402
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
403
+ ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
404
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
405
+ } break;
406
+ default:
407
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
408
+ }
409
+ }
410
+
411
+ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
412
+ if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
413
+ return false;
414
+ }
415
+ if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
416
+ return false;
417
+ }
418
+ if (ne11 > 16) {
419
+ return false;
420
+ }
421
+ switch (type) {
422
+ case GGML_TYPE_F32:
423
+ return ampere_mma_available(cc);
424
+ case GGML_TYPE_F16:
425
+ return turing_mma_available(cc);
426
+ case GGML_TYPE_BF16:
427
+ return ampere_mma_available(cc);
428
+ default:
429
+ return false;
430
+ }
431
+ }
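The eligibility check at the end of mmf.cu is worth spelling out with concrete numbers: ggml_type_size is 4 for F32 and 2 for F16/BF16, so with warp_size 32 the src0_ne[0] requirement becomes "multiple of 32" for F32 and "multiple of 64" for F16/BF16; src0_ne[1] must be a multiple of MMF_ROWS_PER_BLOCK (32); and ne11 (the number of src1 columns) must not exceed 16, matching the commit title. A standalone restatement of just the shape constraints (illustration only; the architecture gates turing_mma_available/ampere_mma_available are omitted):

#include <cstdint>

// Mirrors the shape checks in ggml_cuda_should_use_mmf, without the compute-capability checks.
static bool mmf_shape_ok(int64_t ne00, int64_t ne01, int64_t ne11, int64_t type_size, int warp_size) {
    const int64_t k_multiple = warp_size * (4 / type_size); // 32 for F32, 64 for F16/BF16
    return ne00 % k_multiple == 0   // reduction dimension of src0
        && ne01 % 32 == 0           // MMF_ROWS_PER_BLOCK rows handled per CUDA block
        && ne11 <= 16;              // at most 16 src1 columns, as in the commit title
}

// Example: a 4096 x 4096 F16 weight matrix with a batch of 8 tokens qualifies:
// 4096 % 64 == 0, 4096 % 32 == 0 and 8 <= 16, so mmf_shape_ok(4096, 4096, 8, 2, 32) is true.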
ggml/src/ggml-cuda/mmf.cuh ADDED
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
ggml/src/ggml-cuda/mmq.cu CHANGED
@@ -310,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc)) {
+    if (turing_mma_available(cc)) {
         return true;
     }
 
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -92,7 +92,7 @@ struct tile_x_sizes {
92
  };
93
 
94
  static int get_mmq_x_max_host(const int cc) {
95
- return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
96
  GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
97
  #ifdef GGML_CUDA_FORCE_MMQ
98
  128 : 64;
@@ -102,9 +102,9 @@ static int get_mmq_x_max_host(const int cc) {
102
  }
103
 
104
  static constexpr __device__ int get_mmq_x_max_device() {
105
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
106
  return 128;
107
- #else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
108
 
109
  #if defined(GGML_USE_HIP)
110
  return 64;
@@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
121
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
122
 
123
  #endif // defined(GGML_USE_HIP)
124
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
125
  }
126
 
127
  static int get_mmq_y_host(const int cc) {
@@ -233,7 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
233
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
234
  if (amd_mfma_available(cc)) {
235
  return mmq_x >= 128 ? 32 : 16;
236
- } else if (new_mma_available(cc) && mmq_x >= 48) {
237
  return 16;
238
  } else {
239
  return 8;
@@ -244,7 +244,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
244
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
245
  return mmq_x >= 128 ? 32 : 16;
246
  }
247
- #elif defined(NEW_MMA_AVAILABLE)
248
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
249
  return mmq_x >= 48 ? 16 : 8;
250
  }
@@ -279,14 +279,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
279
  constexpr int nwarps = mmq_get_nwarps_device();
280
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
281
 
282
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
283
  int * x_qs = (int *) x_tile;
284
  float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
285
  #else
286
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
287
  int * x_qs = (int *) x_tile;
288
  float * x_df = (float *) (x_qs + txs.qs);
289
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
290
 
291
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
292
  constexpr int nrows = warp_size / threads_per_row;
@@ -305,12 +305,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
305
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
306
  const int qs0 = get_int_b2(bxi->qs, kqsx);
307
 
308
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
309
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
310
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
311
  #else
312
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
313
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
314
  }
315
 
316
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
@@ -327,11 +327,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
327
 
328
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
329
 
330
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
331
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
332
  #else
333
  x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
334
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
335
  }
336
  }
337
 
@@ -382,14 +382,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
382
  constexpr int nwarps = mmq_get_nwarps_device();
383
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
384
 
385
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
386
  int * x_qs = (int *) x_tile;
387
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
388
  #else
389
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
390
  int * x_qs = (int *) x_tile;
391
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
392
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
393
 
394
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
395
  constexpr int nrows = warp_size / threads_per_row;
@@ -408,12 +408,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
408
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
409
  const int qs0 = get_int_b4(bxi->qs, kqsx);
410
 
411
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
412
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
413
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
414
  #else
415
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
416
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
417
  }
418
 
419
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
@@ -430,11 +430,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
430
 
431
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
432
 
433
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
434
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
435
  #else
436
  x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
437
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
438
  }
439
  }
440
 
@@ -485,14 +485,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
485
  constexpr int nwarps = mmq_get_nwarps_device();
486
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
487
 
488
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
489
  int * x_qs = (int *) x_tile;
490
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
491
  #else
492
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
493
  int * x_qs = (int *) x_tile;
494
  float * x_df = (float *) (x_qs + txs.qs);
495
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
496
 
497
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
498
  constexpr int nrows = warp_size / threads_per_row;
@@ -527,13 +527,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
527
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
528
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
529
 
530
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
531
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
532
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
533
  #else
534
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
535
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
536
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
537
  }
538
 
539
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
@@ -550,11 +550,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
550
 
551
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
552
 
553
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
554
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
555
  #else
556
  x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
557
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
558
  }
559
  }
560
 
@@ -563,14 +563,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
563
  constexpr int nwarps = mmq_get_nwarps_device();
564
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
565
 
566
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
567
  int * x_qs = (int *) x_tile;
568
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
569
  #else
570
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
571
  int * x_qs = (int *) x_tile;
572
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
573
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
574
 
575
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
576
  constexpr int nrows = warp_size / threads_per_row;
@@ -603,13 +603,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
603
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
604
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
605
 
606
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
607
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
608
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
609
  #else
610
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
611
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
612
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
613
  }
614
 
615
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
@@ -626,11 +626,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
626
 
627
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
628
 
629
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
630
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
631
  #else
632
  x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
633
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
634
  }
635
  }
636
 
@@ -639,14 +639,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
639
  constexpr int nwarps = mmq_get_nwarps_device();
640
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
641
 
642
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
643
  int * x_qs = (int *) x_tile;
644
  float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
645
  #else
646
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
647
  int * x_qs = (int *) x_tile;
648
  float * x_df = (float *) (x_qs + txs.qs);
649
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
650
 
651
  // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
652
  constexpr int threads_per_row = 32;
@@ -665,13 +665,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
665
 
666
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
667
 
668
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
669
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
670
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
671
  #else
672
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
673
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
674
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
675
  }
676
 
677
  constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
@@ -688,11 +688,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
688
 
689
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
690
 
691
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
692
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
693
  #else
694
  x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
695
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
696
  }
697
  }
698
 
@@ -701,14 +701,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
701
  constexpr int nwarps = mmq_get_nwarps_device();
702
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
703
 
704
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
705
  int * x_qs = (int *) x_tile;
706
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
707
  #else
708
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
709
  int * x_qs = (int *) x_tile;
710
  float * x_df = (float *) (x_qs + txs.qs);
711
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
712
 
713
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
714
  constexpr int nrows = warp_size / threads_per_row;
@@ -730,13 +730,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
730
  const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
731
  const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
732
 
733
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
734
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
735
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
736
  #else
737
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
738
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
739
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
740
  }
741
 
742
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
@@ -753,11 +753,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
753
 
754
  const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
755
 
756
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
757
  x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
758
  #else
759
  x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
760
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
761
  }
762
  }
763
 
@@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1178
  }
1179
  }
1180
  }
1181
- #elif defined(NEW_MMA_AVAILABLE)
1182
 
1183
  typedef tile<16, 4, int> tile_A;
1184
  typedef tile<16, 8, int> tile_A_8;
@@ -1264,14 +1264,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1264
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1265
  constexpr int nwarps = mmq_get_nwarps_device();
1266
 
1267
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1268
  int * x_qs = (int *) x_tile;
1269
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1270
  #else
1271
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1272
  int * x_qs = (int *) x_tile;
1273
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1274
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1275
 
1276
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1277
  constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
@@ -1295,11 +1295,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1295
 
1296
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
1297
 
1298
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1299
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1300
  #else
1301
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1302
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1303
  }
1304
 
1305
  const int sc_m = bxi->scales[kqsx];
@@ -1310,11 +1310,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1310
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1311
  #endif // FAST_FP16_AVAILABLE
1312
 
1313
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1314
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1315
  #else
1316
  x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1317
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1318
  }
1319
  }
1320
 
@@ -1452,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1452
  }
1453
  }
1454
  }
1455
- #elif defined(NEW_MMA_AVAILABLE)
1456
 
1457
  typedef tile<16, 4, int> tile_A;
1458
  typedef tile<16, 8, int> tile_A_8;
@@ -1582,7 +1582,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1582
  constexpr int nwarps = mmq_get_nwarps_device();
1583
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1584
 
1585
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1586
  int * x_qs = (int *) x_tile;
1587
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1588
  #else
@@ -1590,7 +1590,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1590
  int * x_qs = (int *) x_tile;
1591
  float * x_df = (float *) (x_qs + txs.qs);
1592
  int * x_sc = (int *) (x_df + txs.dm);
1593
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1594
 
1595
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
1596
  constexpr int nrows = warp_size / threads_per_row;
@@ -1618,11 +1618,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1618
 
1619
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1620
 
1621
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1622
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1623
  #else
1624
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1625
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1626
  }
1627
  }
1628
 
@@ -1649,7 +1649,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1649
 
1650
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1651
 
1652
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1653
  const int8_t * sc8 = (const int8_t *) &sc;
1654
  const float d = bxi->d;
1655
 
@@ -1659,10 +1659,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1659
  }
1660
  #else
1661
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1662
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1663
  }
1664
 
1665
- #if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
1666
  #pragma unroll
1667
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1668
  int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
@@ -1675,7 +1675,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1675
 
1676
  x_df[i] = bxi->d;
1677
  }
1678
- #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
1679
  }
1680
 
1681
  template <int mmq_x, int mmq_y>
@@ -1728,7 +1728,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1728
  constexpr int nwarps = mmq_get_nwarps_device();
1729
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1730
 
1731
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1732
  int * x_qs = (int *) x_tile;
1733
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1734
  #else
@@ -1736,7 +1736,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1736
  int * x_qs = (int *) x_tile;
1737
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1738
  int * x_sc = (int *) (x_dm + txs.dm);
1739
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1740
 
1741
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
1742
  constexpr int nrows = warp_size / threads_per_row;
@@ -1753,15 +1753,15 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1753
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1754
  const int qs0 = get_int_b4(bxi->qs, txi);
1755
 
1756
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1757
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1758
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1759
  #else
1760
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
1761
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1762
  }
1763
 
1764
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1765
  constexpr int rows_per_warp = warp_size / 2;
1766
  #pragma unroll
1767
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@@ -1829,7 +1829,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1829
 
1830
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1831
  }
1832
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1833
  }
1834
 
1835
  template <int mmq_x, int mmq_y>
@@ -1872,7 +1872,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1872
  constexpr int nwarps = mmq_get_nwarps_device();
1873
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1874
 
1875
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1876
  int * x_qs = (int *) x_tile;
1877
  half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1878
  #else
@@ -1880,7 +1880,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1880
  int * x_qs = (int *) x_tile;
1881
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1882
  int * x_sc = (int *) (x_dm + txs.dm);
1883
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1884
 
1885
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
1886
  constexpr int nrows = warp_size / threads_per_row;
@@ -1908,16 +1908,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1908
  const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
1909
  const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1910
 
1911
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1912
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1913
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1914
  #else
1915
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
1916
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
1917
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1918
  }
1919
 
1920
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1921
  constexpr int rows_per_warp = warp_size / 2;
1922
  #pragma unroll
1923
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
@@ -1986,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
1986
 
1987
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1988
  }
1989
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1990
  }
1991
 
1992
  template <int mmq_x, int mmq_y>
@@ -2029,7 +2029,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2029
  constexpr int nwarps = mmq_get_nwarps_device();
2030
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2031
 
2032
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2033
  int * x_qs = (int *) x_tile;
2034
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2035
  int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
@@ -2038,7 +2038,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2038
  int * x_qs = (int *) x_tile;
2039
  float * x_df = (float *) (x_qs + txs.qs);
2040
  int * x_sc = (int *) (x_df + txs.dm);
2041
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2042
 
2043
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2044
  constexpr int nrows = warp_size / threads_per_row;
@@ -2065,13 +2065,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2065
  const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2066
  const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
2067
 
2068
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2069
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2070
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2071
  #else
2072
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2073
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2074
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2075
  }
2076
 
2077
  #pragma unroll
@@ -2084,11 +2084,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2084
 
2085
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2086
 
2087
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2088
  x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
2089
  #else
2090
  x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2091
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2092
  }
2093
 
2094
  constexpr int rows_per_warp = warp_size / 4;
@@ -2102,11 +2102,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2102
 
2103
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
2104
 
2105
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2106
  x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
2107
  #else
2108
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2109
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2110
  }
2111
  }
2112
 
@@ -2199,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2199
  }
2200
  }
2201
  }
2202
- #elif defined(NEW_MMA_AVAILABLE)
2203
 
2204
  typedef tile<16, 4, int> tile_A;
2205
  typedef tile< 8, 4, int> tile_B;
@@ -2311,14 +2311,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2311
  constexpr int nwarps = mmq_get_nwarps_device();
2312
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2313
 
2314
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2315
  int * x_qs = (int *) x_tile;
2316
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2317
  #else
2318
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
2319
  int * x_qs = (int *) x_tile;
2320
  float * x_df = (float *) (x_qs + txs.qs);
2321
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2322
 
2323
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2324
  constexpr int nrows = warp_size / threads_per_row;
@@ -2340,13 +2340,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2340
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2341
  const int k0 = kbx * (2 * QI4_NL) + kqsx;
2342
 
2343
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2344
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2345
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
2346
  #else
2347
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2348
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2349
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2350
  }
2351
 
2352
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
@@ -2363,11 +2363,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2363
 
2364
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2365
 
2366
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2367
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2368
  #else
2369
  x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2370
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2371
  }
2372
  }
2373
 
@@ -2376,14 +2376,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2376
  constexpr int nwarps = mmq_get_nwarps_device();
2377
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2378
 
2379
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2380
  int * x_qs = (int *) x_tile;
2381
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2382
  #else
2383
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
2384
  int * x_qs = (int *) x_tile;
2385
  float * x_df = (float *) (x_qs + txs.qs);
2386
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2387
 
2388
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2389
  constexpr int nrows = warp_size / threads_per_row;
@@ -2414,22 +2414,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2414
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
2415
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
2416
 
2417
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2418
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
2419
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
2420
  #else
2421
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2422
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2423
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2424
  }
2425
 
2426
  const int ls = aux32 >> 28;
2427
  const float d = bxi->d;
2428
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2429
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2430
  #else
2431
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2432
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2433
  }
2434
  }
2435
 
@@ -2438,14 +2438,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2438
  constexpr int nwarps = mmq_get_nwarps_device();
2439
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2440
 
2441
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2442
  int * x_qs = (int *) x_tile;
2443
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2444
  #else
2445
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
2446
  int * x_qs = (int *) x_tile;
2447
  float * x_df = (float *) (x_qs + txs.qs);
2448
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2449
 
2450
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2451
  constexpr int nrows = warp_size / threads_per_row;
@@ -2472,24 +2472,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2472
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
2473
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
2474
 
2475
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2476
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2477
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2478
  #else
2479
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2480
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2481
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2482
  }
2483
 
2484
  const int ls = bxi->scales[kqsx];
2485
  const float d = bxi->d;
2486
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2487
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2488
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2489
  #else
2490
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2491
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2492
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2493
  }
2494
  }
2495
 
@@ -2498,14 +2498,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2498
  constexpr int nwarps = mmq_get_nwarps_device();
2499
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2500
 
2501
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2502
  int * x_qs = (int *) x_tile;
2503
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2504
  #else
2505
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2506
  int * x_qs = (int *) x_tile;
2507
  float * x_df = (float *) (x_qs + txs.qs);
2508
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2509
 
2510
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2511
  constexpr int nrows = warp_size / threads_per_row;
@@ -2539,24 +2539,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2539
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2540
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2541
 
2542
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2543
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2544
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2545
  #else
2546
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2547
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2548
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2549
  }
2550
 
2551
  const int ls = bxi->scales[kqsx];
2552
  const float d = bxi->d;
2553
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2554
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2555
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2556
  #else
2557
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2558
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2559
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2560
  }
2561
  }
2562
 
@@ -2565,14 +2565,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2565
  constexpr int nwarps = mmq_get_nwarps_device();
2566
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2567
 
2568
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2569
  int * x_qs = (int *) x_tile;
2570
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2571
  #else
2572
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2573
  int * x_qs = (int *) x_tile;
2574
  float * x_df = (float *) (x_qs + txs.qs);
2575
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2576
 
2577
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2578
  constexpr int nrows = warp_size / threads_per_row;
@@ -2601,22 +2601,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2601
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2602
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2603
 
2604
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2605
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2606
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2607
  #else
2608
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2609
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2610
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2611
  }
2612
 
2613
  const int ls = aux32 >> 28;
2614
  const float d = bxi->d;
2615
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2616
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2617
  #else
2618
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2619
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2620
  }
2621
  }
2622
 
@@ -2625,14 +2625,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2625
  constexpr int nwarps = mmq_get_nwarps_device();
2626
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2627
 
2628
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2629
  int * x_qs = (int *) x_tile;
2630
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2631
  #else
2632
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2633
  int * x_qs = (int *) x_tile;
2634
  float * x_df = (float *) (x_qs + txs.qs);
2635
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2636
 
2637
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2638
  constexpr int nrows = warp_size / threads_per_row;
@@ -2668,22 +2668,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2668
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2669
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2670
 
2671
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2672
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2673
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2674
  #else
2675
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2676
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2677
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2678
  }
2679
 
2680
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2681
  const float d = bxi->d;
2682
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2683
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2684
  #else
2685
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2686
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2687
  }
2688
  }
2689
 
@@ -2692,14 +2692,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2692
  constexpr int nwarps = mmq_get_nwarps_device();
2693
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2694
 
2695
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2696
  int * x_qs = (int *) x_tile;
2697
  half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2698
  #else
2699
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2700
  int * x_qs = (int *) x_tile;
2701
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2702
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2703
 
2704
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
2705
  constexpr int nrows = warp_size / threads_per_row;
@@ -2727,23 +2727,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2727
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2728
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2729
 
2730
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2731
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2732
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2733
  #else
2734
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
2735
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
2736
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2737
  }
2738
 
2739
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2740
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2741
 
2742
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2743
  x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2744
  #else
2745
  x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2746
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2747
  }
2748
  }
2749
 
@@ -2752,14 +2752,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2752
  constexpr int nwarps = mmq_get_nwarps_device();
2753
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2754
 
2755
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2756
  int * x_qs = (int *) x_tile;
2757
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2758
  #else
2759
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2760
  int * x_qs = (int *) x_tile;
2761
  float * x_df = (float *) (x_qs + txs.qs);
2762
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2763
 
2764
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
2765
  constexpr int nrows = warp_size / threads_per_row;
@@ -2779,13 +2779,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2779
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2780
  const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2781
 
2782
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2783
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2784
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2785
  #else
2786
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2787
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
2788
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2789
  }
2790
 
2791
  constexpr int rows_per_warp = warp_size / 8;
@@ -2804,11 +2804,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
2804
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2805
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2806
 
2807
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2808
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2809
  #else
2810
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2811
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2812
  }
2813
  }
2814
 
@@ -2859,9 +2859,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(
2859
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2860
 
2861
  const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2862
- #if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
2863
  static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2864
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
2865
 
2866
  #pragma unroll
2867
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -3061,13 +3061,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
3061
  int * tile_y = data_mul_mat_q + mmq_x;
3062
  int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
3063
 
3064
- #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
3065
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
3066
  constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
3067
  #else
3068
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
3069
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
3070
- #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
3071
 
3072
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3073
 
@@ -3534,7 +3534,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
3534
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
3535
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
3536
  const size_t nbs_ids = mmq_x*sizeof(int);
3537
- const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3538
  const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
3539
  return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
3540
  }
 
92
  };
93
 
94
  static int get_mmq_x_max_host(const int cc) {
95
+ return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
96
  GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
97
  #ifdef GGML_CUDA_FORCE_MMQ
98
  128 : 64;
 
102
  }
103
 
104
  static constexpr __device__ int get_mmq_x_max_device() {
105
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
106
  return 128;
107
+ #else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
108
 
109
  #if defined(GGML_USE_HIP)
110
  return 64;
 
121
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
122
 
123
  #endif // defined(GGML_USE_HIP)
124
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
125
  }
126
 
127
  static int get_mmq_y_host(const int cc) {
 
233
  static int mmq_get_granularity_host(const int mmq_x, const int cc) {
234
  if (amd_mfma_available(cc)) {
235
  return mmq_x >= 128 ? 32 : 16;
236
+ } else if (turing_mma_available(cc) && mmq_x >= 48) {
237
  return 16;
238
  } else {
239
  return 8;
 
244
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
245
  return mmq_x >= 128 ? 32 : 16;
246
  }
247
+ #elif defined(TURING_MMA_AVAILABLE)
248
  static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
249
  return mmq_x >= 48 ? 16 : 8;
250
  }
 
279
  constexpr int nwarps = mmq_get_nwarps_device();
280
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
281
 
282
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
283
  int * x_qs = (int *) x_tile;
284
  float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
285
  #else
286
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
287
  int * x_qs = (int *) x_tile;
288
  float * x_df = (float *) (x_qs + txs.qs);
289
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
290
 
291
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
292
  constexpr int nrows = warp_size / threads_per_row;
 
305
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
306
  const int qs0 = get_int_b2(bxi->qs, kqsx);
307
 
308
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
309
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
310
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
311
  #else
312
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
313
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
314
  }
315
 
316
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
 
327
 
328
  const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
329
 
330
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
331
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
332
  #else
333
  x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
334
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
335
  }
336
  }
337
 
 
382
  constexpr int nwarps = mmq_get_nwarps_device();
383
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
384
 
385
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
386
  int * x_qs = (int *) x_tile;
387
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
388
  #else
389
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
390
  int * x_qs = (int *) x_tile;
391
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
392
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
393
 
394
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
395
  constexpr int nrows = warp_size / threads_per_row;
 
408
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
409
  const int qs0 = get_int_b4(bxi->qs, kqsx);
410
 
411
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
412
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
413
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
414
  #else
415
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
416
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
417
  }
418
 
419
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
 
430
 
431
  const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
432
 
433
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
434
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
435
  #else
436
  x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
437
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
438
  }
439
  }
440
 
 
485
  constexpr int nwarps = mmq_get_nwarps_device();
486
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
487
 
488
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
489
  int * x_qs = (int *) x_tile;
490
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
491
  #else
492
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
493
  int * x_qs = (int *) x_tile;
494
  float * x_df = (float *) (x_qs + txs.qs);
495
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
496
 
497
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
498
  constexpr int nrows = warp_size / threads_per_row;
 
527
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
528
  qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
529
 
530
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
531
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
532
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
533
  #else
534
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
535
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
536
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
537
  }
538
 
539
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
 
550
 
551
  const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
552
 
553
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
554
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
555
  #else
556
  x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
557
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
558
  }
559
  }
560
 
 
563
  constexpr int nwarps = mmq_get_nwarps_device();
564
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
565
 
566
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
567
  int * x_qs = (int *) x_tile;
568
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
569
  #else
570
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
571
  int * x_qs = (int *) x_tile;
572
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
573
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
574
 
575
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
576
  constexpr int nrows = warp_size / threads_per_row;
 
603
  qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
604
  qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
605
 
606
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
607
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
608
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
609
  #else
610
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
611
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
612
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
613
  }
614
 
615
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
 
626
 
627
  const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
628
 
629
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
630
  x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
631
  #else
632
  x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
633
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
634
  }
635
  }
636
 
 
639
  constexpr int nwarps = mmq_get_nwarps_device();
640
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
641
 
642
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
643
  int * x_qs = (int *) x_tile;
644
  float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
645
  #else
646
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
647
  int * x_qs = (int *) x_tile;
648
  float * x_df = (float *) (x_qs + txs.qs);
649
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
650
 
651
  // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
652
  constexpr int threads_per_row = 32;
 
665
 
666
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
667
 
668
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
669
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
670
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
671
  #else
672
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
673
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
674
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
675
  }
676
 
677
  constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
 
688
 
689
  const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
690
 
691
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
692
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
693
  #else
694
  x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
695
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
696
  }
697
  }
698
 
 
701
  constexpr int nwarps = mmq_get_nwarps_device();
702
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
703
 
704
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
705
  int * x_qs = (int *) x_tile;
706
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
707
  #else
708
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
709
  int * x_qs = (int *) x_tile;
710
  float * x_df = (float *) (x_qs + txs.qs);
711
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
712
 
713
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
714
  constexpr int nrows = warp_size / threads_per_row;
 
730
  const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
731
  const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
732
 
733
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
734
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
735
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
736
  #else
737
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
738
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
739
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
740
  }
741
 
742
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
 
753
 
754
  const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
755
 
756
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
757
  x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
758
  #else
759
  x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
760
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
761
  }
762
  }
763
 
 
1178
  }
1179
  }
1180
  }
1181
+ #elif defined(TURING_MMA_AVAILABLE)
1182
 
1183
  typedef tile<16, 4, int> tile_A;
1184
  typedef tile<16, 8, int> tile_A_8;
 
1264
  const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1265
  constexpr int nwarps = mmq_get_nwarps_device();
1266
 
1267
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1268
  int * x_qs = (int *) x_tile;
1269
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1270
  #else
1271
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1272
  int * x_qs = (int *) x_tile;
1273
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1274
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1275
 
1276
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1277
  constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
 
1295
 
1296
  const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
1297
 
1298
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1299
  x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1300
  #else
1301
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1302
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1303
  }
1304
 
1305
  const int sc_m = bxi->scales[kqsx];
 
1310
  const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1311
  #endif // FAST_FP16_AVAILABLE
1312
 
1313
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1314
  x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1315
  #else
1316
  x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1317
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1318
  }
1319
  }
1320
 
 
1452
  }
1453
  }
1454
  }
1455
+ #elif defined(TURING_MMA_AVAILABLE)
1456
 
1457
  typedef tile<16, 4, int> tile_A;
1458
  typedef tile<16, 8, int> tile_A_8;
 
1582
  constexpr int nwarps = mmq_get_nwarps_device();
1583
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1584
 
1585
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1586
  int * x_qs = (int *) x_tile;
1587
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1588
  #else
 
1590
  int * x_qs = (int *) x_tile;
1591
  float * x_df = (float *) (x_qs + txs.qs);
1592
  int * x_sc = (int *) (x_df + txs.dm);
1593
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1594
 
1595
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
1596
  constexpr int nrows = warp_size / threads_per_row;
 
1618
 
1619
  const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1620
 
1621
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1622
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1623
  #else
1624
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1625
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1626
  }
1627
  }
1628
 
 
1649
 
1650
  const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1651
 
1652
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1653
  const int8_t * sc8 = (const int8_t *) &sc;
1654
  const float d = bxi->d;
1655
 
 
1659
  }
1660
  #else
1661
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1662
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1663
  }
1664
 
1665
+ #if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1666
  #pragma unroll
1667
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1668
  int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
1675
 
1676
  x_df[i] = bxi->d;
1677
  }
1678
+ #endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1679
  }
1680
 
1681
  template <int mmq_x, int mmq_y>
 
1728
  constexpr int nwarps = mmq_get_nwarps_device();
1729
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1730
 
1731
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1732
  int * x_qs = (int *) x_tile;
1733
  half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1734
  #else
 
1736
  int * x_qs = (int *) x_tile;
1737
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1738
  int * x_sc = (int *) (x_dm + txs.dm);
1739
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1740
 
1741
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
1742
  constexpr int nrows = warp_size / threads_per_row;
 
1753
  const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1754
  const int qs0 = get_int_b4(bxi->qs, txi);
1755
 
1756
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1757
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1758
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1759
  #else
1760
  x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
1761
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1762
  }
1763
 
1764
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1765
  constexpr int rows_per_warp = warp_size / 2;
1766
  #pragma unroll
1767
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
 
1829
 
1830
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1831
  }
1832
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1833
  }
1834
 
1835
  template <int mmq_x, int mmq_y>
 
1872
  constexpr int nwarps = mmq_get_nwarps_device();
1873
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1874
 
1875
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1876
  int * x_qs = (int *) x_tile;
1877
  half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1878
  #else
 
1880
  int * x_qs = (int *) x_tile;
1881
  half2 * x_dm = (half2 *) (x_qs + txs.qs);
1882
  int * x_sc = (int *) (x_dm + txs.dm);
1883
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1884
 
1885
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
1886
  constexpr int nrows = warp_size / threads_per_row;
 
1908
  const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
1909
  const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1910
 
1911
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1912
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1913
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1914
  #else
1915
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
1916
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
1917
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1918
  }
1919
 
1920
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1921
  constexpr int rows_per_warp = warp_size / 2;
1922
  #pragma unroll
1923
  for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
 
1986
 
1987
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1988
  }
1989
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1990
  }
1991
 
1992
  template <int mmq_x, int mmq_y>
 
2029
  constexpr int nwarps = mmq_get_nwarps_device();
2030
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2031
 
2032
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2033
  int * x_qs = (int *) x_tile;
2034
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2035
  int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
 
2038
  int * x_qs = (int *) x_tile;
2039
  float * x_df = (float *) (x_qs + txs.qs);
2040
  int * x_sc = (int *) (x_df + txs.dm);
2041
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2042
 
2043
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2044
  constexpr int nrows = warp_size / threads_per_row;
 
2065
  const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2066
  const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
2067
 
2068
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2069
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2070
  x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2071
  #else
2072
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2073
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2074
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2075
  }
2076
 
2077
  #pragma unroll
 
2084
 
2085
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2086
 
2087
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2088
  x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
2089
  #else
2090
  x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2091
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2092
  }
2093
 
2094
  constexpr int rows_per_warp = warp_size / 4;
 
2102
 
2103
  const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
2104
 
2105
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2106
  x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
2107
  #else
2108
  x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2109
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2110
  }
2111
  }
2112
 
 
2199
  }
2200
  }
2201
  }
2202
+ #elif defined(TURING_MMA_AVAILABLE)
2203
 
2204
  typedef tile<16, 4, int> tile_A;
2205
  typedef tile< 8, 4, int> tile_B;
 
2311
  constexpr int nwarps = mmq_get_nwarps_device();
2312
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2313
 
2314
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2315
  int * x_qs = (int *) x_tile;
2316
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2317
  #else
2318
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
2319
  int * x_qs = (int *) x_tile;
2320
  float * x_df = (float *) (x_qs + txs.qs);
2321
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2322
 
2323
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2324
  constexpr int nrows = warp_size / threads_per_row;
 
2340
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2341
  const int k0 = kbx * (2 * QI4_NL) + kqsx;
2342
 
2343
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2344
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2345
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
2346
  #else
2347
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2348
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2349
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2350
  }
2351
 
2352
  constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
 
2363
 
2364
  const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2365
 
2366
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2367
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2368
  #else
2369
  x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2370
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2371
  }
2372
  }
2373
 
 
2376
  constexpr int nwarps = mmq_get_nwarps_device();
2377
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2378
 
2379
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2380
  int * x_qs = (int *) x_tile;
2381
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2382
  #else
2383
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
2384
  int * x_qs = (int *) x_tile;
2385
  float * x_df = (float *) (x_qs + txs.qs);
2386
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2387
 
2388
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2389
  constexpr int nrows = warp_size / threads_per_row;
 
2414
  const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
2415
  const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
2416
 
2417
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2418
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
2419
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
2420
  #else
2421
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2422
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2423
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2424
  }
2425
 
2426
  const int ls = aux32 >> 28;
2427
  const float d = bxi->d;
2428
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2429
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2430
  #else
2431
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2432
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2433
  }
2434
  }
2435
 
 
2438
  constexpr int nwarps = mmq_get_nwarps_device();
2439
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2440
 
2441
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2442
  int * x_qs = (int *) x_tile;
2443
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2444
  #else
2445
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
2446
  int * x_qs = (int *) x_tile;
2447
  float * x_df = (float *) (x_qs + txs.qs);
2448
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2449
 
2450
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2451
  constexpr int nrows = warp_size / threads_per_row;
 
2472
  const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
2473
  const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
2474
 
2475
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2476
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2477
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2478
  #else
2479
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2480
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2481
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2482
  }
2483
 
2484
  const int ls = bxi->scales[kqsx];
2485
  const float d = bxi->d;
2486
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2487
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2488
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2489
  #else
2490
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2491
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2492
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2493
  }
2494
  }
2495
 
 
2498
  constexpr int nwarps = mmq_get_nwarps_device();
2499
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2500
 
2501
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2502
  int * x_qs = (int *) x_tile;
2503
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2504
  #else
2505
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2506
  int * x_qs = (int *) x_tile;
2507
  float * x_df = (float *) (x_qs + txs.qs);
2508
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2509
 
2510
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2511
  constexpr int nrows = warp_size / threads_per_row;
 
2539
  const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
2540
  const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2541
 
2542
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2543
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2544
  x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2545
  #else
2546
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2547
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2548
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2549
  }
2550
 
2551
  const int ls = bxi->scales[kqsx];
2552
  const float d = bxi->d;
2553
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2554
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2555
  x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2556
  #else
2557
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2558
  x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2559
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2560
  }
2561
  }
2562
 
 
2565
  constexpr int nwarps = mmq_get_nwarps_device();
2566
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2567
 
2568
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2569
  int * x_qs = (int *) x_tile;
2570
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2571
  #else
2572
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2573
  int * x_qs = (int *) x_tile;
2574
  float * x_df = (float *) (x_qs + txs.qs);
2575
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2576
 
2577
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2578
  constexpr int nrows = warp_size / threads_per_row;
 
2601
  const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
2602
  const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2603
 
2604
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2605
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2606
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2607
  #else
2608
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2609
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2610
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2611
  }
2612
 
2613
  const int ls = aux32 >> 28;
2614
  const float d = bxi->d;
2615
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2616
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2617
  #else
2618
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2619
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2620
  }
2621
  }
2622
 
 
2625
  constexpr int nwarps = mmq_get_nwarps_device();
2626
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2627
 
2628
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2629
  int * x_qs = (int *) x_tile;
2630
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2631
  #else
2632
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2633
  int * x_qs = (int *) x_tile;
2634
  float * x_df = (float *) (x_qs + txs.qs);
2635
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2636
 
2637
  constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2638
  constexpr int nrows = warp_size / threads_per_row;
 
2668
  const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
2669
  const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2670
 
2671
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2672
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2673
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2674
  #else
2675
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2676
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2677
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2678
  }
2679
 
2680
  const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2681
  const float d = bxi->d;
2682
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2683
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2684
  #else
2685
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2686
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2687
  }
2688
  }
2689
 
 
2692
  constexpr int nwarps = mmq_get_nwarps_device();
2693
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2694
 
2695
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2696
  int * x_qs = (int *) x_tile;
2697
  half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2698
  #else
2699
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2700
  int * x_qs = (int *) x_tile;
2701
  half2 * x_ds = (half2 *) (x_qs + txs.qs);
2702
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2703
 
2704
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
2705
  constexpr int nrows = warp_size / threads_per_row;
 
2727
  const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2728
  const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2729
 
2730
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2731
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2732
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2733
  #else
2734
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
2735
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
2736
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2737
  }
2738
 
2739
  const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2740
  const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2741
 
2742
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2743
  x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2744
  #else
2745
  x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2746
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2747
  }
2748
  }
2749
 
 
2752
  constexpr int nwarps = mmq_get_nwarps_device();
2753
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2754
 
2755
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2756
  int * x_qs = (int *) x_tile;
2757
  float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2758
  #else
2759
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2760
  int * x_qs = (int *) x_tile;
2761
  float * x_df = (float *) (x_qs + txs.qs);
2762
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2763
 
2764
  constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
2765
  constexpr int nrows = warp_size / threads_per_row;
 
2779
  const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2780
  const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2781
 
2782
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2783
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2784
  x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2785
  #else
2786
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2787
  x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
2788
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2789
  }
2790
 
2791
  constexpr int rows_per_warp = warp_size / 8;
 
2804
  const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2805
  | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2806
 
2807
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2808
  x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2809
  #else
2810
  x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2811
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2812
  }
2813
  }
2814
 
 
2859
  constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2860
 
2861
  const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2862
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
2863
  static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2864
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2865
 
2866
  #pragma unroll
2867
  for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 
3061
  int * tile_y = data_mul_mat_q + mmq_x;
3062
  int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
3063
 
3064
+ #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3065
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
3066
  constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
3067
  #else
3068
  constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
3069
  constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
3070
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3071
 
3072
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3073
 
 
3534
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
3535
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
3536
  const size_t nbs_ids = mmq_x*sizeof(int);
3537
+ const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3538
  const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
3539
  return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
3540
  }
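
The byte count above picks between two x-tile layouts: a single fixed-stride array for the mma kernels, or separate qs/dm/sc regions for the dp4a fallback. Below is a minimal standalone sketch of that sizing choice; tile_x_nbytes and all numbers in main are illustrative placeholders, not the real ggml constants.

// Hypothetical illustration of the x-tile shared-memory sizing; placeholder values only.
#include <cstdio>
#include <cstddef>

struct tile_x_sizes { size_t qs, dm, sc; };   // element counts per region (dp4a layout)

static size_t tile_x_nbytes(bool mma_layout, size_t mmq_y, size_t mma_tile_x_k, tile_x_sizes txs) {
    if (mma_layout) {
        return mmq_y*mma_tile_x_k*sizeof(int);                 // one flat array with the stride the mma tiles expect
    }
    return txs.qs*sizeof(int) + txs.dm*4 + txs.sc*sizeof(int); // qs + dm (half2 = 4 bytes) + sc regions
}

int main() {
    const tile_x_sizes txs = {64*66, 64*9, 64*9};              // placeholder element counts
    printf("mma : %zu bytes\n", tile_x_nbytes(true,  64, 72, txs));
    printf("dp4a: %zu bytes\n", tile_x_nbytes(false, 64, 72, txs));
    return 0;
}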
ggml/src/ggml-cuda/mmvf.cu ADDED
@@ -0,0 +1,510 @@
1
+ #include "ggml.h"
2
+ #include "common.cuh"
3
+ #include "mmvf.cuh"
4
+
5
+ template <typename T, typename type_acc, int ncols_dst, int block_size>
6
+ static __global__ void mul_mat_vec_f(
7
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
8
+ const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
9
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
10
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
11
+ const int row = blockIdx.x;
12
+ const int channel_dst = blockIdx.y;
13
+ const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
14
+ const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
15
+ const int sample_dst = blockIdx.z;
16
+ const int sample_x = sample_dst / sample_ratio;
17
+ const int sample_y = sample_dst;
18
+ const int tid = threadIdx.x;
19
+
20
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
21
+
22
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
23
+ y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
24
+ dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
25
+
26
+ const float2 * y2 = (const float2 *) y;
27
+
28
+ extern __shared__ char data_mmv[];
29
+ float * buf_iw = (float *) data_mmv;
30
+
31
+ if (block_size > warp_size) {
32
+ if (tid < warp_size) {
33
+ buf_iw[tid] = 0.0f;
34
+ }
35
+ __syncthreads();
36
+ }
37
+
38
+ float sumf[ncols_dst] = {0.0f};
39
+
40
+ if constexpr (std::is_same_v<T, float>) {
41
+ const float2 * x2 = (const float2 *) x;
42
+
43
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
44
+ const float2 tmpx = x2[col2];
45
+
46
+ #pragma unroll
47
+ for (int j = 0; j < ncols_dst; ++j) {
48
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
49
+ sumf[j] += tmpx.x*tmpy.x;
50
+ sumf[j] += tmpx.y*tmpy.y;
51
+ }
52
+ }
53
+ } else if constexpr (std::is_same_v<T, half>) {
54
+ const half2 * x2 = (const half2 *) x;
55
+
56
+ if (std::is_same_v<type_acc, float>) {
57
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
58
+ const float2 tmpx = __half22float2(x2[col2]);
59
+
60
+ #pragma unroll
61
+ for (int j = 0; j < ncols_dst; ++j) {
62
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
63
+ sumf[j] += tmpx.x * tmpy.x;
64
+ sumf[j] += tmpx.y * tmpy.y;
65
+ }
66
+ }
67
+ } else {
68
+ #ifdef FP16_AVAILABLE
69
+ half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
70
+
71
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
72
+ const half2 tmpx = x2[col2];
73
+
74
+ #pragma unroll
75
+ for (int j = 0; j < ncols_dst; ++j) {
76
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
77
+ sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
78
+ }
79
+ }
80
+
81
+ #pragma unroll
82
+ for (int j = 0; j < ncols_dst; ++j) {
83
+ sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
84
+ }
85
+ #else
86
+ NO_DEVICE_CODE;
87
+ #endif // FP16_AVAILABLE
88
+ }
89
+ } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
90
+ const int * x2 = (const int *) x;
91
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
92
+ const int tmpx = x2[col2];
93
+ #pragma unroll
94
+ for (int j = 0; j < ncols_dst; ++j) {
95
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
96
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
97
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
98
+ }
99
+ }
100
+ } else {
101
+ static_assert(std::is_same_v<T, void>, "unsupported type");
102
+ }
103
+
104
+ #pragma unroll
105
+ for (int j = 0; j < ncols_dst; ++j) {
106
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
107
+
108
+ if (block_size > warp_size) {
109
+ buf_iw[tid/warp_size] = sumf[j];
110
+ __syncthreads();
111
+ if (tid < warp_size) {
112
+ sumf[j] = buf_iw[tid];
113
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
114
+ }
115
+ if (j < ncols_dst) {
116
+ __syncthreads();
117
+ }
118
+ }
119
+ }
120
+
121
+ if (tid >= ncols_dst) {
122
+ return;
123
+ }
124
+
125
+ dst[tid*stride_col_dst + row] = sumf[tid];
126
+ }
127
+
128
+ template <typename T, typename type_acc, int ncols_dst>
129
+ static void launch_mul_mat_vec_f_cuda(
130
+ const T * x, const float * y, const int32_t * ids, float * dst,
131
+ const int64_t ncols, const int64_t nrows,
132
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
133
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
134
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
135
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
136
+ cudaStream_t stream) {
137
+ GGML_ASSERT(ncols % 2 == 0);
138
+ GGML_ASSERT(stride_row % 2 == 0);
139
+ GGML_ASSERT(stride_col_y % 2 == 0);
140
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
141
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
142
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
143
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
144
+
145
+ const int device = ggml_cuda_get_device();
146
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
147
+
148
+ int64_t block_size_best = warp_size;
149
+ int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
150
+ int64_t max_block_size = 256;
151
+ if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
152
+ max_block_size = 128;
153
+ }
154
+ for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
155
+ const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
156
+ if (niter < niter_best) {
157
+ niter_best = niter;
158
+ block_size_best = block_size;
159
+ }
160
+ }
161
+
162
+ const int nbytes_shared = warp_size*sizeof(float);
163
+ const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
164
+ const dim3 block_dims(block_size_best, 1, 1);
165
+ switch (block_size_best) {
166
+ case 32: {
167
+ mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
168
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
169
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
170
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
171
+ } break;
172
+ case 64: {
173
+ mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
174
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
175
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
176
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
177
+ } break;
178
+ case 96: {
179
+ mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
180
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
181
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
182
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
183
+ } break;
184
+ case 128: {
185
+ mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
186
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
187
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
188
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
189
+ } break;
190
+ case 160: {
191
+ mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
192
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
193
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
194
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
195
+ } break;
196
+ case 192: {
197
+ mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
198
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
199
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
200
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
201
+ } break;
202
+ case 224: {
203
+ mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
204
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
205
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
206
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
207
+ } break;
208
+ case 256: {
209
+ mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
210
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
211
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
212
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
213
+ } break;
214
+ default: {
215
+ GGML_ABORT("fatal error");
216
+ } break;
217
+ }
218
+ }
219
+
220
+ template <typename T, typename type_acc>
221
+ static void mul_mat_vec_f_cuda_switch_ncols_dst(
222
+ const T * x, const float * y, const int32_t * ids, float * dst,
223
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
224
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
225
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
226
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
227
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
228
+ cudaStream_t stream) {
229
+ switch (ncols_dst) {
230
+ case 1:
231
+ launch_mul_mat_vec_f_cuda<T, type_acc, 1>
232
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
233
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
234
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
235
+ break;
236
+ case 2:
237
+ launch_mul_mat_vec_f_cuda<T, type_acc, 2>
238
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
239
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
240
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
241
+ break;
242
+ case 3:
243
+ launch_mul_mat_vec_f_cuda<T, type_acc, 3>
244
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
245
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
246
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
247
+ break;
248
+ case 4:
249
+ launch_mul_mat_vec_f_cuda<T, type_acc, 4>
250
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
251
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
252
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
253
+ break;
254
+ case 5:
255
+ launch_mul_mat_vec_f_cuda<T, type_acc, 5>
256
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
257
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
258
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
259
+ break;
260
+ case 6:
261
+ launch_mul_mat_vec_f_cuda<T, type_acc, 6>
262
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
263
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
264
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
265
+ break;
266
+ case 7:
267
+ launch_mul_mat_vec_f_cuda<T, type_acc, 7>
268
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
269
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
270
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
271
+ break;
272
+ case 8:
273
+ launch_mul_mat_vec_f_cuda<T, type_acc, 8>
274
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
275
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
276
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
277
+ break;
278
+ default:
279
+ GGML_ABORT("fatal error");
280
+ break;
281
+ }
282
+ }
283
+
284
+ template<typename T>
285
+ static void mul_mat_vec_f_cuda(
286
+ const T * x, const float * y, const int32_t * ids, float * dst,
287
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
288
+ const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
289
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
290
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
291
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
292
+ enum ggml_prec prec, cudaStream_t stream) {
293
+ if constexpr(std::is_same_v<T, half>) {
294
+ if (prec == GGML_PREC_DEFAULT) {
295
+ mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
296
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
297
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
298
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
299
+ return;
300
+ }
301
+ }
302
+ mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
303
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
304
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
305
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
306
+ }
307
+
308
+ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
309
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
310
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
311
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
312
+
313
+ GGML_TENSOR_BINARY_OP_LOCALS;
314
+
315
+ const size_t ts_src0 = ggml_type_size(src0->type);
316
+ const size_t ts_src1 = ggml_type_size(src1->type);
317
+ const size_t ts_dst = ggml_type_size(dst->type);
318
+
319
+ GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
320
+ GGML_ASSERT(ne13 == ne3);
321
+
322
+ GGML_ASSERT( nb00 == ts_src0);
323
+ GGML_ASSERT( nb10 == ts_src1);
324
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
325
+ GGML_ASSERT( nb0 == ts_dst);
326
+
327
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
328
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
329
+
330
+ const float * src1_d = (const float *) src1->data;
331
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
332
+ float * dst_d = (float *) dst->data;
333
+
334
+ const int64_t s01 = src0->nb[1] / ts_src0;
335
+ const int64_t s11 = src1->nb[1] / ts_src1;
336
+ const int64_t s1 = dst->nb[1] / ts_dst;
337
+ const int64_t s02 = src0->nb[2] / ts_src0;
338
+ const int64_t s12 = src1->nb[2] / ts_src1;
339
+ const int64_t s2 = dst->nb[2] / ts_dst;
340
+ const int64_t s03 = src0->nb[3] / ts_src0;
341
+ const int64_t s13 = src1->nb[3] / ts_src1;
342
+ const int64_t s3 = dst->nb[3] / ts_dst;
343
+
344
+ // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
345
+ const int64_t ncols_dst = ids ? ne2 : ne1;
346
+ const int64_t nchannels_y = ids ? ne11 : ne12;
347
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
348
+ const int64_t stride_channel_dst = ids ? s1 : s2;
349
+ const int64_t stride_channel_y = ids ? s11 : s12;
350
+
351
+ GGML_ASSERT(!ids || ncols_dst == 1);
352
+
353
+ switch (src0->type) {
354
+ case GGML_TYPE_F32: {
355
+ const float * src0_d = (const float *) src0->data;
356
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
357
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
358
+ ne03, ne3, s03, s13, s3, prec, ctx.stream());
359
+ } break;
360
+ case GGML_TYPE_F16: {
361
+ const half * src0_d = (const half *) src0->data;
362
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
363
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
364
+ ne03, ne3, s03, s13, s3, prec, ctx.stream());
365
+ } break;
366
+ case GGML_TYPE_BF16: {
367
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
368
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
369
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
370
+ ne03, ne3, s03, s13, s3, prec, ctx.stream());
371
+ } break;
372
+ default:
373
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
374
+ }
375
+ }
376
+
377
+ void ggml_cuda_op_mul_mat_vec_f(
378
+ ggml_backend_cuda_context & ctx,
379
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
380
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
381
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
382
+
383
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
384
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
385
+
386
+ const int64_t ne00 = src0->ne[0];
387
+ const int64_t ne10 = src1->ne[0];
388
+ const int64_t ne0 = dst->ne[0];
389
+ const int64_t row_diff = row_high - row_low;
390
+
391
+ const int id = ggml_cuda_get_device();
392
+ const int cc = ggml_cuda_info().devices[id].cc;
393
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
394
+
395
+
396
+ // ggml_cuda_op provides single, contiguous matrices
397
+ const int64_t stride_row = ne00;
398
+ const int64_t stride_col_y = ne10;
399
+ const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
400
+ const int64_t nchannels_x = 1;
401
+ const int64_t nchannels_y = 1;
402
+ const int64_t nchannels_dst = 1;
403
+ const int64_t stride_channel_x = 0;
404
+ const int64_t stride_channel_y = 0;
405
+ const int64_t stride_channel_dst = 0;
406
+ const int64_t nsamples_x = 1;
407
+ const int64_t nsamples_dst = 1;
408
+ const int64_t stride_sample_x = 0;
409
+ const int64_t stride_sample_y = 0;
410
+ const int64_t stride_sample_dst = 0;
411
+
412
+ switch (src0->type) {
413
+ case GGML_TYPE_F32: {
414
+ const float * src0_d = (const float *) src0_dd_i;
415
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
416
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
417
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
418
+ } break;
419
+ case GGML_TYPE_F16: {
420
+ const half * src0_d = (const half *) src0_dd_i;
421
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
422
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
423
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
424
+ } break;
425
+ case GGML_TYPE_BF16: {
426
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
427
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
428
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
429
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
430
+ } break;
431
+ default:
432
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
433
+ }
434
+
435
+ GGML_UNUSED(ctx);
436
+ GGML_UNUSED(src1);
437
+ GGML_UNUSED(dst);
438
+ GGML_UNUSED(src1_ddq_i);
439
+ GGML_UNUSED(src1_ncols);
440
+ GGML_UNUSED(src1_padded_row_size);
441
+ }
442
+
443
+ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
444
+ if (src0_ne[0] % 2 != 0) {
445
+ return false;
446
+ }
447
+ switch (type) {
448
+ case GGML_TYPE_F32:
449
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
450
+ if (ampere_mma_available(cc)) {
451
+ return ne11 <= 3;
452
+ }
453
+ if (cc >= GGML_CUDA_CC_TURING) {
454
+ return ne11 <= 4;
455
+ }
456
+ return ne11 <= 3;
457
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
458
+ if (fp32_mma_hardware_available(cc)) {
459
+ return ne11 <= 3;
460
+ }
461
+ return ne11 <= 8;
462
+ }
463
+ return ne11 <= 8;
464
+ case GGML_TYPE_F16:
465
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
466
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
467
+ if (ampere_mma_available(cc)) {
468
+ return src0_small && ne11 == 1;
469
+ }
470
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
471
+ return src0_small && ne11 <= 4;
472
+ }
473
+ if (fp16_mma_hardware_available(cc)) {
474
+ return src0_small && ne11 <= 3;
475
+ }
476
+ return ne11 <= 8;
477
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
478
+ if (fp16_mma_hardware_available(cc)) {
479
+ if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
480
+ return ne11 <= 5;
481
+ }
482
+ return ne11 <= 2;
483
+ }
484
+ return ne11 <= 8;
485
+ }
486
+ return ne11 <= 8;
487
+ case GGML_TYPE_BF16:
488
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
489
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
490
+ if (ampere_mma_available(cc)) {
491
+ return src0_small && ne11 == 1;
492
+ }
493
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
494
+ return src0_small && ne11 <= 4;
495
+ }
496
+ if (bf16_mma_hardware_available(cc)) {
497
+ return src0_small && ne11 <= 3;
498
+ }
499
+ return ne11 <= 8;
500
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
501
+ if (bf16_mma_hardware_available(cc)) {
502
+ return ne11 <= 3;
503
+ }
504
+ return ne11 <= 8;
505
+ }
506
+ return ne11 <= 8;
507
+ default:
508
+ return false;
509
+ }
510
+ }
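
mul_mat_vec_f above reduces each thread's partial dot products first with warp shuffles and then, when the block is wider than a warp, through the buf_iw scratch buffer in shared memory. A standalone sketch of that two-level pattern follows; it is a hypothetical kernel, not code from this commit, and it assumes the block size is a multiple of the 32-lane warp size.

#include <cuda_runtime.h>

// Hypothetical row-sum kernel using the same two-level reduction as mul_mat_vec_f:
// intra-warp shuffle reduction, then one partial sum per warp through shared memory.
__global__ void row_sum(const float * x, float * out, const int ncols) {
    constexpr int warp_size = 32;
    extern __shared__ float buf_iw[];                      // one float per warp

    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += x[blockIdx.x*ncols + col];
    }

    for (int offset = warp_size/2; offset > 0; offset >>= 1) {
        sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset);   // intra-warp reduction
    }

    if (blockDim.x > warp_size) {
        if (threadIdx.x % warp_size == 0) {
            buf_iw[threadIdx.x / warp_size] = sum;         // one partial per warp
        }
        __syncthreads();
        if (threadIdx.x < warp_size) {
            sum = threadIdx.x < blockDim.x/warp_size ? buf_iw[threadIdx.x] : 0.0f;
            for (int offset = warp_size/2; offset > 0; offset >>= 1) {
                sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset); // inter-warp reduction
            }
        }
    }

    if (threadIdx.x == 0) {
        out[blockIdx.x] = sum;
    }
}
// Launch sketch: row_sum<<<nrows, 256, (256/32)*sizeof(float)>>>(x, out, ncols);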
ggml/src/ggml-cuda/mmvf.cuh ADDED
@@ -0,0 +1,11 @@
1
+ #include "common.cuh"
2
+
3
+ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
4
+
5
+ void ggml_cuda_op_mul_mat_vec_f(
6
+ ggml_backend_cuda_context & ctx,
7
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
8
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
9
+ const int64_t src1_padded_row_size, cudaStream_t stream);
10
+
11
+ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
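
A minimal sketch of how a caller might use this interface at a dispatch site outside this file; pick_mul_mat_kernel and run_quantized_or_gemm_path are hypothetical names, not functions added by this commit.

// Hypothetical dispatch sketch: gate the float mat-vec kernel on the heuristic,
// otherwise fall back to some other matrix-multiplication path.
static void pick_mul_mat_kernel(ggml_backend_cuda_context & ctx,
        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    if (ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1])) {
        // src1 is skinny enough that the mat-vec kernel is expected to win.
        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, /*ids=*/nullptr, dst);
    } else {
        run_quantized_or_gemm_path(ctx, src0, src1, dst); // hypothetical fallback
    }
}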
ggml/src/ggml-cuda/vendors/hip.h CHANGED
@@ -200,6 +200,7 @@
200
  #endif
201
 
202
  typedef hip_bfloat16 nv_bfloat16;
203
+ typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
204
 
205
  typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
206
  typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
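
The short2 stand-in above is there to keep otherwise-unused BF16 code paths compiling; if anything does touch its layout, it should at least match two packed BF16 values in size. A hedged local sanity check one could add (not part of the commit):

// Hypothetical compile-time check: the ad-hoc nv_bfloat162 stand-in must be
// 4 bytes, i.e. the size of two hip_bfloat16 values.
static_assert(sizeof(nv_bfloat162) == 2*sizeof(nv_bfloat16), "nv_bfloat162 stand-in has the wrong size");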
ggml/src/ggml-cuda/vendors/musa.h CHANGED
@@ -137,4 +137,5 @@
137
  #define cudaStreamEndCapture musaStreamEndCapture
138
  #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
139
 
140
- typedef mt_bfloat16 nv_bfloat16;
140
+ typedef __mt_bfloat16 nv_bfloat16;
141
+ typedef __mt_bfloat162 nv_bfloat162;