JohannesGaessler commited on
Commit
47e02a8
·
1 Parent(s): f6b0b76

CUDA: broadcasting for FlashAttention mask (llama/14500)

Browse files
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
32
  const int ne12,
33
  const int ne13,
34
  const int ne31,
 
35
  const int nb31,
 
36
  const int nb01,
37
  const int nb02,
38
  const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
851
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
852
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
853
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
854
- mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
 
855
  Q->nb[1], Q->nb[2], Q->nb[3],
856
  nb11, nb12, nb13,
857
  nb21, nb22, nb23,
 
32
  const int ne12,
33
  const int ne13,
34
  const int ne31,
35
+ const int ne32,
36
  const int nb31,
37
+ const int nb32,
38
  const int nb01,
39
  const int nb02,
40
  const int nb03,
 
853
  scale, max_bias, m0, m1, n_head_log2, logit_softcap,
854
  Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
855
  K->ne[0], K->ne[1], K->ne[2], K->ne[3],
856
+ mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
857
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
858
  Q->nb[1], Q->nb[2], Q->nb[3],
859
  nb11, nb12, nb13,
860
  nb21, nb22, nb23,
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
1223
  const int ne12,
1224
  const int ne13,
1225
  const int ne31,
 
1226
  const int nb31,
 
1227
  const int nb01,
1228
  const int nb02,
1229
  const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
1288
 
1289
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1290
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1291
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
 
1292
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1293
 
1294
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
1327
 
1328
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1329
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1330
- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
 
1331
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1332
 
1333
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
1348
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1349
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
1350
  GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
1351
- GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
1352
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1353
  GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
1354
  GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
1355
  GGML_UNUSED(ne2); GGML_UNUSED(ne3);
 
1223
  const int ne12,
1224
  const int ne13,
1225
  const int ne31,
1226
+ const int ne32,
1227
  const int nb31,
1228
+ const int nb32,
1229
  const int nb01,
1230
  const int nb02,
1231
  const int nb03,
 
1290
 
1291
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1292
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1293
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1294
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1295
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1296
 
1297
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
 
1330
 
1331
  const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1332
  const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1333
+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1334
+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1335
  float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
1336
 
1337
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
 
1352
  GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
1353
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
1354
  GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
1355
+ GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
1356
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1357
  GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
1358
  GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
1359
  GGML_UNUSED(ne2); GGML_UNUSED(ne3);
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -6,7 +6,7 @@
6
 
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
- __launch_bounds__(nwarps*WARP_SIZE, 1)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f16(
12
  const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
30
  const int ne12,
31
  const int ne13,
32
  const int ne31,
 
33
  const int nb31,
 
34
  const int nb01,
35
  const int nb02,
36
  const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
64
  const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
65
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
66
  const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
67
- const half * maskh = (const half *) mask + ne11*ic0;
68
 
69
  const int stride_KV2 = nb11 / sizeof(half2);
70
 
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
288
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
289
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
290
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
291
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
292
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
293
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
294
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
295
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
 
6
 
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
+ __launch_bounds__(nwarps*WARP_SIZE, 2)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f16(
12
  const char * __restrict__ Q,
 
30
  const int ne12,
31
  const int ne13,
32
  const int ne31,
33
+ const int ne32,
34
  const int nb31,
35
+ const int nb32,
36
  const int nb01,
37
  const int nb02,
38
  const int nb03,
 
66
  const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
67
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
68
  const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
69
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
70
 
71
  const int stride_KV2 = nb11 / sizeof(half2);
72
 
 
290
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
291
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
292
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
293
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
294
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
295
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
296
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
297
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -6,7 +6,7 @@
6
 
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
- __launch_bounds__(nwarps*WARP_SIZE, 1)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f32(
12
  const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
30
  const int ne12,
31
  const int ne13,
32
  const int ne31,
 
33
  const int nb31,
 
34
  const int nb01,
35
  const int nb02,
36
  const int nb03,
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
58
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
59
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
60
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
61
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
62
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
63
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
64
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
65
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
76
  const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
77
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
78
  const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
79
- const half * maskh = (const half *) mask + ne11*ic0;
80
 
81
  const int stride_KV2 = nb11 / sizeof(half2);
82
 
 
6
 
7
  template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
8
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
+ __launch_bounds__(nwarps*WARP_SIZE, 2)
10
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f32(
12
  const char * __restrict__ Q,
 
30
  const int ne12,
31
  const int ne13,
32
  const int ne31,
33
+ const int ne32,
34
  const int nb31,
35
+ const int nb32,
36
  const int nb01,
37
  const int nb02,
38
  const int nb03,
 
60
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
61
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
62
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
63
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
64
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
65
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
66
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
67
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
 
78
  const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
79
  const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
80
  const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
81
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
82
 
83
  const int stride_KV2 = nb11 / sizeof(half2);
84
 
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
27
  const int ne12,
28
  const int ne13,
29
  const int ne31,
 
30
  const int nb31,
 
31
  const int nb01,
32
  const int nb02,
33
  const int nb03,
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
68
  K += nb12*(blockIdx.z / gqa_ratio);
69
  V += nb22*(blockIdx.z / gqa_ratio);
70
 
71
- const half * maskh = (const half *) mask + ne11*ic0;
72
 
73
  const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
74
  const half slopeh = __float2half(slopef);
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
342
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
343
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
344
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
345
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
346
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
347
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
348
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
349
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
 
27
  const int ne12,
28
  const int ne13,
29
  const int ne31,
30
+ const int ne32,
31
  const int nb31,
32
+ const int nb32,
33
  const int nb01,
34
  const int nb02,
35
  const int nb03,
 
70
  K += nb12*(blockIdx.z / gqa_ratio);
71
  V += nb22*(blockIdx.z / gqa_ratio);
72
 
73
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
74
 
75
  const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
76
  const half slopeh = __float2half(slopef);
 
344
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
345
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
346
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
347
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
348
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
349
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
350
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
351
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
27
  const int ne12,
28
  const int ne13,
29
  const int ne31,
 
30
  const int nb31,
 
31
  const int nb01,
32
  const int nb02,
33
  const int nb03,
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
51
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
52
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
53
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
54
- GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
55
- GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
56
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
57
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
58
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
79
  Q += nb02* blockIdx.z + nb01*ic0;
80
  K += nb12*(blockIdx.z / gqa_ratio);
81
  V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
82
- const half * maskh = (const half *) mask + ne11*ic0;
 
83
 
84
  const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
85
 
 
27
  const int ne12,
28
  const int ne13,
29
  const int ne31,
30
+ const int ne32,
31
  const int nb31,
32
+ const int nb32,
33
  const int nb01,
34
  const int nb02,
35
  const int nb03,
 
53
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
54
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
55
  GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
56
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
57
+ GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
58
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
59
  GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
60
  GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
 
81
  Q += nb02* blockIdx.z + nb01*ic0;
82
  K += nb12*(blockIdx.z / gqa_ratio);
83
  V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
84
+
85
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
86
 
87
  const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
88
 
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
46
  const int ne12,
47
  const int ne13,
48
  const int ne31,
 
49
  const int nb31,
 
50
  const int nb01,
51
  const int nb02,
52
  const int nb03,
@@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
94
  constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
95
 
96
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
97
- const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
98
- const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
99
- const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
100
- const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
101
- const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
102
 
103
  const int stride_Q = nb01 / sizeof(float);
104
  const int stride_KV = nb11 / sizeof(half);
@@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
440
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
441
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
442
  GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
443
- GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
444
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
445
  GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
446
  GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
 
46
  const int ne12,
47
  const int ne13,
48
  const int ne31,
49
+ const int ne32,
50
  const int nb31,
51
+ const int nb32,
52
  const int nb01,
53
  const int nb02,
54
  const int nb03,
 
96
  constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
97
 
98
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
99
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
100
+ const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
101
+ const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
102
+ const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
103
+ const half2 * mask2 = (const half2 *) maskh;
104
 
105
  const int stride_Q = nb01 / sizeof(float);
106
  const int stride_KV = nb11 / sizeof(half);
 
442
  GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
443
  GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
444
  GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
445
+ GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
446
  GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
447
  GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
448
  GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);