Commit 47e02a8 · Parent: f6b0b76
CUDA: broadcasting for FlashAttention mask (llama/14500)
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED

@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
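
The two new kernel parameters follow ggml's tensor naming: ne32 is the mask's element count along dimension 2 and nb32 is its byte stride along that dimension. Passing both lets the kernels broadcast a mask whose dimension 2 is smaller than the number of attention heads. A minimal sketch of the indexing rule the kernels below apply (plain C++; names and sizes are illustrative, not part of the ggml API):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Locate the mask row for one attention head ("channel") and one Q row.
    // channel % ne32 implements the broadcast: with ne32 == 1 every head
    // shares the single mask slice, with ne32 == n_head each head has its own.
    static const char * mask_row(const char * mask, int channel, int row,
                                 int ne32, int64_t nb32, int64_t nb31) {
        assert(mask != nullptr && ne32 > 0);
        return mask + nb32*(channel % ne32) + nb31*row;
    }

    int main() {
        static char mask[2048];                    // hypothetical mask buffer
        const char * r = mask_row(mask, /*channel=*/5, /*row=*/3,
                                  /*ne32=*/1, /*nb32=*/2048, /*nb31=*/256);
        printf("offset = %td bytes\n", r - mask);  // 0*2048 + 3*256 = 768
        return 0;
    }
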
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED

@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(

         const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
         const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
         float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);

         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(

         const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
         const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
         float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);

         const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
     GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
     GGML_UNUSED(ne2); GGML_UNUSED(ne3);
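
One detail of the rewritten mask_h2 expression: the byte strides nb32 and nb31 are applied to the char pointer first and the result is then cast to half2, whereas the old form cast to half2 first and divided nb31 by sizeof(half2); the nullptr guard is likewise just the negated form of the old condition. The two addressing styles are equivalent whenever the stride is a multiple of sizeof(half2). A self-contained check (plain C++ stand-in for CUDA's half2, values illustrative):

    #include <cassert>
    #include <cstdint>

    struct half2 { uint16_t x, y; };               // stand-in, 4 bytes like CUDA's

    int main() {
        alignas(4) static char buf[4096] = {};
        const char *  mask = buf;
        const int64_t nb31 = 64;                   // bytes per mask row (illustrative)
        const int     jt = 2, ncols1 = 4;

        const half2 * a = (const half2 *) (mask + nb31*jt*ncols1);
        const half2 * b = (const half2 *)  mask + (nb31/sizeof(half2))*jt*ncols1;
        assert(a == b);   // same element, since nb31 % sizeof(half2) == 0
        return 0;
    }
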
ggml/src/ggml-cuda/fattn-tile-f16.cu
CHANGED

@@ -6,7 +6,7 @@

 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
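
Separate from the mask change, both tile kernels bump the second __launch_bounds__ argument. In CUDA that argument is minBlocksPerMultiprocessor: it asks the compiler to cap per-thread register usage so at least that many blocks can be resident on one SM at a time. A hedged sketch of the attribute's shape (toy CUDA kernel, not the tile kernel itself):

    #define WARP_SIZE 32   // warp size assumed by these kernels on NVIDIA GPUs

    template<int nwarps>
    #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
    __launch_bounds__(nwarps*WARP_SIZE, 2)  // <= nwarps*32 threads, >= 2 resident blocks
    #endif
    static __global__ void example_kernel(float * out) {
        out[blockIdx.x*blockDim.x + threadIdx.x] = 0.0f;
    }

    // launched e.g. as: example_kernel<8><<<n_blocks, 8*WARP_SIZE>>>(out);
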
ggml/src/ggml-cuda/fattn-tile-f32.cu
CHANGED

@@ -6,7 +6,7 @@

 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
+__launch_bounds__(nwarps*WARP_SIZE, 2)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
     const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
     const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const int stride_KV2 = nb11 / sizeof(half2);

ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED

@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio);

-    const half * maskh = (const half *) mask + ne11*ic0;
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
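
In the vector kernels the z grid dimension indexes the attention head, so blockIdx.z % ne32 selects the mask slice that head reads. A worked example of the mapping (hypothetical sizes):

    #include <cstdio>
    #include <initializer_list>

    int main() {
        const int n_head = 8;
        for (int ne32 : {1, 8}) {                  // mask slices along dim 2
            printf("ne32 = %d:", ne32);
            for (int z = 0; z < n_head; ++z) {
                printf(" head %d -> slice %d;", z, z % ne32);
            }
            printf("\n");
        }
        return 0;
    }
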
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED

@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
     GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
     Q += nb02* blockIdx.z + nb01*ic0;
     K += nb12*(blockIdx.z / gqa_ratio);
     V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+
+    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);

     const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);

ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED

@@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
         const int ne12,
         const int ne13,
         const int ne31,
+        const int ne32,
         const int nb31,
+        const int nb32,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float * Q_f   = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
     const half  * K_h   = (const half  *) (K + nb12*(blockIdx.z / gqa_ratio));
     const half  * V_h   = (const half  *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))*ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half2 * mask2 = (const half2 *)  maskh;

     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
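
A closing observation that holds across every kernel in this commit: K/V are shared between grouped-query heads by integer division (blockIdx.z / gqa_ratio, with gqa_ratio = ne02 / ne12), while the mask is broadcast across heads by modulo (blockIdx.z % ne32). A small illustration of the two mappings (sizes hypothetical):

    #include <cstdio>

    int main() {
        const int ne02 = 8;                    // Q heads
        const int ne12 = 2;                    // K/V heads
        const int ne32 = 1;                    // mask slices along dim 2
        const int gqa_ratio = ne02 / ne12;     // > 1 Q matrices per K/V matrix
        for (int head = 0; head < ne02; ++head) {
            printf("head %d -> K/V head %d, mask slice %d\n",
                   head, head / gqa_ratio, head % ne32);
        }
        return 0;
    }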