JohannesGaessler committed on
Commit d4c0faf · 1 Parent(s): 315df8c

CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (llama/7681)

ggml-cuda/fattn-common.cuh CHANGED
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "common.cuh"
+#include "convert.cuh"
 #include "vecdotq.cuh"
 
 #include <cstdint>
@@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)(
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
@@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 template <int D, int parallel_blocks>
-void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+void launch_fattn(
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
 
+    ggml_cuda_pool_alloc<half>   K_f16(pool);
+    ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
@@ -667,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
 
     fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
         (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
+        K_data,
+        V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2,
@@ -676,8 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
        Q->nb[1], Q->nb[2], Q->nb[3],
-       K->nb[1], K->nb[2], K->nb[3],
-       V->nb[1], V->nb[2], V->nb[3],
+       nb11, nb12, nb13,
+       nb21, nb22, nb23,
        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
     );
     CUDA_CHECK(cudaGetLastError());
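
Note on the stride rescaling introduced above: after K or V is dequantized with ggml_get_to_fp16_cuda, the byte strides nb11/nb12/nb13 (and nb21/nb22/nb23) are converted from units of the quantized type to units of the new FP16 buffer via nb = nb*bs*sizeof(half)/ts, with bs = ggml_blck_size and ts = ggml_type_size. The self-contained sketch below is an editor's illustration, not part of the commit: the file name, the TypeInfo struct, and the hard-coded q8_0/q4_0 block and type sizes (32 values per block; 34 and 18 bytes per block, respectively) are stand-ins for what ggml_blck_size/ggml_type_size return.

// stride_rescale_demo.cpp -- illustrates the nb*bs*sizeof(half)/ts rescaling done in
// launch_fattn when a quantized K/V cache is dequantized to FP16 (hypothetical demo file).
#include <cassert>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t sizeof_half = 2;

    // q8_0: blocks of 32 values, each stored as 32 int8 quants + 1 FP16 scale = 34 bytes.
    // q4_0: blocks of 32 values, each stored as 16 bytes of nibbles + 1 FP16 scale = 18 bytes.
    struct TypeInfo { const char * name; size_t bs; size_t ts; };
    const TypeInfo types[] = { {"q8_0", 32, 34}, {"q4_0", 32, 18} };

    const size_t ne0 = 128; // head size, i.e. elements per K/V row

    for (const TypeInfo & t : types) {
        const size_t nb1_quant = ne0/t.bs * t.ts;                  // row stride of the quantized tensor
        const size_t nb1_f16   = nb1_quant*t.bs*sizeof_half/t.ts;  // rescaled stride for the FP16 copy
        printf("%s: nb1 = %zu bytes -> %zu bytes\n", t.name, nb1_quant, nb1_f16);
        assert(nb1_f16 == ne0*sizeof_half); // same row, now densely packed FP16
    }
    return 0;
}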
ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -278,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -275,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -290,7 +290,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -271,7 +271,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
ggml-cuda/fattn-wmma-f16.cuh CHANGED
@@ -438,18 +438,18 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
 
 #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
ggml-cuda/fattn.cu CHANGED
@@ -298,17 +298,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
-    const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
-
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD || quantized_KV) {
+    if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {