ggerganov and JohannesGaessler committed
Commit b2d73a2 · 1 parent: 7073590

llama : add high-throughput mode (llama/14363)

* kv-cache : prepare K/V buffers for separation

ggml-ci

* batched-bench : fix oob write

ggml-ci

* llama : add "virtual sequences"

ggml-ci

* llama : use "stream" vs "virtual sequence"

ggml-ci

* graph : fix stream splitting when KV cache is not used

ggml-ci

* kv-cache : add multi-stream save/load support

ggml-ci

* llama : add "--attn-streams" flag

ggml-ci

* kv-cache : fix handling when find_slot fails

ggml-ci

* kv-cache : restore find_slot impl

ggml-ci

* kv-cache : add comments

* kv-cache : add bounds checks for sequence id

ggml-ci

* cont : add n_seq_max to batch allocr

ggml-ci

* kv-cache : perform stream copies lazily after llama_synchronize

ggml-ci

* kv-cache : avoid throwing exceptions across the C boundary

ggml-ci

* CUDA: 4D FlashAttention support (llama/14628)

* CUDA: 4D FlashAttention support

* CUDA: fix WMMA FA kernel

* llama : rename attn_streams -> kv_unified

ggml-ci

* common : rename kv_split -> kv_unified

ggml-ci

---------

Co-authored-by: Johannes Gäßler <[email protected]>
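
All of the kernel diffs below implement the same indexing change: Q, K, V and the KQ mask gain a fourth (sequence/stream) dimension, so blockIdx.z no longer equals the head index and the old Q->ne[3] == 1 assumption is dropped. Each kernel splits blockIdx.z into a (sequence, head) pair and addresses its output through one unrolled row index per (sequence, column, head, parallel block). A minimal sketch of that arithmetic, using the formulas from the diff (the helper name fattn_dst_row is illustrative and does not appear in the commit):

    // Sketch only: ne01 = number of Q columns, ne02 = number of heads,
    // n_parallel = gridDim.y (1 when no cross-block combination is needed).
    __device__ __forceinline__ int fattn_dst_row(
            const int blockIdx_z, const int col, const int parallel,
            const int ne01, const int ne02, const int n_parallel) {
        const int sequence = blockIdx_z / ne02;          // which stream/sequence
        const int head     = blockIdx_z - sequence*ne02; // which attention head
        return ((sequence*ne01 + col)*ne02 + head)*n_parallel + parallel;
    }

Each output row then holds D contiguous floats, which is why, for example, the vector kernels below write dst[row*D + tid] and dst_meta[row].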

ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;

     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }

-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

     if (jt*ncols1 + j >= ne01) {
         return;
     }

-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       += D                 * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled * D;

     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }

     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;

-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }

-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }

 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(

     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(

             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

         flash_attn_combine_results<DV>
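
In flash_attn_stream_k_fixup and the MMA kernel that follows, the stream-k counter kbc now flattens a 4D work space, kbc = ((sequence*(ne02/ncols2) + head)*iter_j + jt)*iter_k + k_block, and the three divisions in the diff invert that flattening. A small self-contained check of this inverse, for illustration only (nothing here is part of the commit; the variable names mirror the kernel code):

    // C++: verify that the decode used by flash_attn_stream_k_fixup recovers
    // (sequence, head, jt) from a flattened stream-k counter kbc.
    #include <cassert>

    static void decode_kbc(int kbc, int iter_k, int iter_j, int n_head_groups,
                           int & sequence, int & head, int & jt) {
        sequence = kbc / (iter_k*iter_j*n_head_groups);
        head     = (kbc - iter_k*iter_j*n_head_groups*sequence) / (iter_k*iter_j);
        jt       = (kbc - iter_k*iter_j*n_head_groups*sequence - iter_k*iter_j*head) / iter_k;
    }

    int main() {
        const int iter_k = 4, iter_j = 3, n_head_groups = 2 /* ne02/ncols2 */, ne03 = 2;
        for (int s = 0; s < ne03; ++s)
        for (int h = 0; h < n_head_groups; ++h)
        for (int j = 0; j < iter_j; ++j)
        for (int k = 0; k < iter_k; ++k) {
            const int kbc = ((s*n_head_groups + h)*iter_j + j)*iter_k + k;
            int s2, h2, j2;
            decode_kbc(kbc, iter_k, iter_j, n_head_groups, s2, h2, j2);
            assert(s2 == s && h2 == h && j2 == j);
        }
        return 0;
    }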
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.

     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int channel = kbc / (iter_k*iter_j);
-        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+        const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

-        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
         return;
     }

-    const int channel = kbc / (iter_k*iter_j);
-    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

-    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel  = kb0_stop  * kb_niter;
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(

     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }

+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum((float)kqsum_j);

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;

-            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
         }

         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
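
The tile kernels also switch the output path from two scalar float stores per iteration to a single float2 store against the unrolled row index, i.e. the inner loop now runs over D/2 with stride WARP_SIZE instead of D with stride 2*WARP_SIZE. The byte layout of one output row is unchanged; a tiny illustrative check (not part of the commit):

    // C++: dst2[row*(D/2) + i] as a float2 covers the same bytes as the former
    // scalar pair dst[row*D + 2*i + 0] and dst[row*D + 2*i + 1].
    #include <cassert>
    #include <cstddef>

    int main() {
        const int D = 128, row = 3, i = 17;
        const std::size_t off_float2 = (std::size_t)(row*(D/2) + i) * (2*sizeof(float));
        const std::size_t off_float  = (std::size_t)(row*D + 2*i)   *    sizeof(float);
        assert(off_float2 == off_float);
        return 0;
    }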
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(

     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }

+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;

-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
         }

         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(

     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);

-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }

     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(

     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);

-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }

     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float * Q_f   = (const float *) (Q + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
     const half2 * mask2 = (const half2 *) maskh;

     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);

-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);

@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;

         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
             if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+            dst[j_dst_unrolled*D + i] = dst_val;
         }

         if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
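
With the kernels now handling Q->ne[3] > 1, the old early-out for 4D inputs in ggml_backend_cuda_device_supports_op is removed; the only new restriction sits on the fp16 WMMA fallback path, which rejects a KQ mask whose dim-2 extent is not 1 (the updated kernels address the mask only by column, nb31, and by sequence, nb33). A condensed, paraphrased view of that check (not a literal excerpt from the diff):

    // op->src[3] is the optional KQ mask; it must be broadcast along dim 2.
    if (op->src[3] && op->src[3]->ne[2] != 1) {
        return false; // masks that vary along dim 2 are not supported on this path
    }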