llama : add high-throughput mode (llama/14363)
* kv-cache : prepare K/V buffers for separation
ggml-ci
* batched-bench : fix oob write
ggml-ci
* llama : add "virtual sequences"
ggml-ci
* llama : use "stream" vs "virtual sequence"
ggml-ci
* graph : fix stream splitting when KV cache is not used
ggml-ci
* kv-cache : add multi-stream save/load support
ggml-ci
* llama : add "--attn-streams" flag
ggml-ci
* kv-cache : fix handling when find_slot fails
ggml-ci
* kv-cache : restore find_slot impl
ggml-ci
* kv-cache : add comments
* kv-cache : add bounds checks for sequence id
ggml-ci
* cont : add n_seq_max to batch allocr
ggml-ci
* kv-cache : perform stream copies lazily after llama_synchronize
ggml-ci
* kv-cache : avoid throwing exceptions across the C boundary
ggml-ci
* CUDA: 4D FlashAttention support (llama/14628)
* CUDA: 4D FlashAttention support
* CUDA: fix WMMA FA kernel
* llama : rename attn_streams -> kv_unified
ggml-ci
* common : rename kv_split -> kv_unified
ggml-ci
---------
Co-authored-by: Johannes Gäßler <[email protected]>
- ggml/src/ggml-cuda/fattn-common.cuh +35 -19
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +22 -18
- ggml/src/ggml-cuda/fattn-tile-f16.cu +20 -14
- ggml/src/ggml-cuda/fattn-tile-f32.cu +18 -12
- ggml/src/ggml-cuda/fattn-vec-f16.cuh +13 -10
- ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -12
- ggml/src/ggml-cuda/fattn-wmma-f16.cu +15 -9
- ggml/src/ggml-cuda/ggml-cuda.cu +3 -6
ggml/src/ggml-cuda/fattn-common.cuh

@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;

     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }

+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

     if (jt*ncols1 + j >= ne01) {
         return;
     }

+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled * D;

     const int tid = threadIdx.x;
     __builtin_assume(tid < D);

     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }

     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;

+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }

+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }

 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(

     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(

             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

         flash_attn_combine_results<DV>
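The hunks above flatten (sequence, head group, j tile, k block) into the single counter kbc that is split across CUDA blocks; the new ne03 factor is what makes that split 4D-aware. Below is a minimal standalone host-side sketch of the index arithmetic, not part of the patch; n_head_groups stands in for the kernel's ne02/ncols2 and the concrete sizes are made up for illustration:

    #include <cassert>

    int main() {
        const int iter_k        = 4; // ne11 / FATTN_KQ_STRIDE
        const int iter_j        = 3; // ceil(ne01 / ncols1)
        const int n_head_groups = 8; // ne02 / ncols2
        const int ne03          = 2; // number of sequences (the new 4th dimension)

        for (int kbc = 0; kbc < iter_k*iter_j*n_head_groups*ne03; ++kbc) {
            // Same divisions as in flash_attn_stream_k_fixup above:
            const int sequence = kbc / (iter_k*iter_j*n_head_groups);
            const int head     = (kbc - iter_k*iter_j*n_head_groups*sequence) / (iter_k*iter_j);
            const int jt       = (kbc - iter_k*iter_j*n_head_groups*sequence - iter_k*iter_j*head) / iter_k;
            const int k        = kbc % iter_k;

            // Recomposing the coordinates must give the flat counter back.
            assert(((sequence*n_head_groups + head)*iter_j + jt)*iter_k + k == kbc);
        }
        return 0;
    }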
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
    constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.

    // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
    int kb0_start = kbc % iter_k;
    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
    while (kbc < kbc_stop && kb0_stop == iter_k) {
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
        return;
    }

+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);

+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));

+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;

    const int kb0_start_kernel = kb0_start * kb_niter;
    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
ggml/src/ggml-cuda/fattn-tile-f16.cu

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);

    const int stride_KV2 = nb11 / sizeof(half2);

+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half  slopeh = __float2half(slopef);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
        __syncthreads();
    }

+    float2 * dst2 = (float2 *) dst;
+
#pragma unroll
    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
        const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
        half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
        kqsum_j = warp_reduce_sum((float)kqsum_j);

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
#pragma unroll
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;

+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
            if (gridDim.y == 1) {
                dst_val /= __half2half2(kqsum_j);
            }
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
        }

        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
        }
    }
#else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
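The tile kernels (and the vec/wmma kernels below) now write each partial result at a flattened index of the form ((sequence*ne01 + column)*ne02 + head)*gridDim.y + blockIdx.y, which flash_attn_combine_results later reduces over the parallel-block dimension. A standalone sketch of that offset computation, with a hypothetical helper name and made-up sizes chosen only for illustration:

    #include <cstdio>

    // Offset (in units of one D-sized output row) of a partial result.
    static int j_dst_unrolled(int sequence, int column, int head,
                              int ne01, int ne02, int n_parallel, int parallel_block) {
        return ((sequence*ne01 + column)*ne02 + head)*n_parallel + parallel_block;
    }

    int main() {
        const int ne01 = 5, ne02 = 8, n_parallel = 4, D = 128;
        const int j = j_dst_unrolled(/*sequence=*/1, /*column=*/2, /*head=*/3,
                                     ne01, ne02, n_parallel, /*parallel_block=*/0);
        // The partial rows for one (sequence, column, head) triple are contiguous,
        // so the combine kernel can reach all of them with a single pointer offset.
        std::printf("partial result starts at float offset %d\n", j*D);
        return 0;
    }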
ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);

    const int stride_KV2 = nb11 / sizeof(half2);

+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
        __syncthreads();
    }

+    float2 * dst2 = (float2 *) dst;
+
#pragma unroll
    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
        const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
        float kqsum_j = kqsum[j_VKQ_0/nwarps];
        kqsum_j = warp_reduce_sum(kqsum_j);

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
#pragma unroll
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;

+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
            if (gridDim.y == 1) {
                dst_val.x /= kqsum_j;
                dst_val.y /= kqsum_j;
            }
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
        }

        if (gridDim.y != 1 && threadIdx.x == 0) {
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
        }
    }
#else
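Every kernel in this patch addresses the mask with nb33*(sequence % ne33), so a mask with a single slice along its 4th dimension appears to be shared by all sequences, while a per-sequence mask indexes its own slice. A standalone sketch of that addressing rule, with a hypothetical helper name and made-up sizes:

    #include <cstdint>
    #include <cstdio>

    // Returns the byte offset of the mask slice a given sequence should read.
    static const char * mask_slice(const char * mask, int sequence,
                                   int64_t ne33, size_t nb33) {
        return mask + nb33*(sequence % ne33);
    }

    int main() {
        static char mask_data[4*64] = {0};
        // ne33 == 1: one slice broadcast to all sequences (offset stays 0).
        std::printf("%td\n", mask_slice(mask_data, 3, 1, 64) - mask_data);
        // ne33 == 4: each sequence reads its own slice (offset 192 for sequence 3).
        std::printf("%td\n", mask_slice(mask_data, 3, 4, 64) - mask_data);
        return 0;
    }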
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);

+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half  slopeh = __float2half(slopef);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
        if (gridDim.y == 1) {
            dst_val /= kqsum[j_VKQ];
        }
-
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
    }

    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
    }
#else
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 28 |
const int ne13,
|
| 29 |
const int ne31,
|
| 30 |
const int ne32,
|
|
|
|
| 31 |
const int nb31,
|
| 32 |
const int nb32,
|
|
|
|
| 33 |
const int nb01,
|
| 34 |
const int nb02,
|
| 35 |
const int nb03,
|
|
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 53 |
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
| 54 |
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
| 55 |
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
| 56 |
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
| 57 |
-
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
| 58 |
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
| 59 |
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
| 60 |
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 77 |
|
| 78 |
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
| 79 |
|
|
|
|
|
|
|
| 80 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 81 |
-
Q += nb02*
|
| 82 |
-
K += nb12*(
|
| 83 |
-
V += nb22*(
|
| 84 |
|
| 85 |
-
const half * maskh = (const half *) (mask +
|
| 86 |
|
| 87 |
-
const float slope = get_alibi_slope(max_bias,
|
| 88 |
|
| 89 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 90 |
constexpr int nwarps = D / WARP_SIZE;
|
|
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 326 |
if (gridDim.y == 1) {
|
| 327 |
dst_val /= kqsum[j_VKQ];
|
| 328 |
}
|
| 329 |
-
|
| 330 |
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
| 331 |
}
|
| 332 |
|
| 333 |
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
| 334 |
-
dst_meta[((ic0 + tid)*
|
| 335 |
}
|
| 336 |
#else
|
| 337 |
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 340 |
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
| 341 |
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
| 342 |
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
| 343 |
-
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
| 344 |
-
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
| 345 |
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
| 346 |
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
| 347 |
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
|
|
| 28 |
const int ne13,
|
| 29 |
const int ne31,
|
| 30 |
const int ne32,
|
| 31 |
+
const int ne33,
|
| 32 |
const int nb31,
|
| 33 |
const int nb32,
|
| 34 |
+
const int nb33,
|
| 35 |
const int nb01,
|
| 36 |
const int nb02,
|
| 37 |
const int nb03,
|
|
|
|
| 55 |
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
| 56 |
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
| 57 |
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
| 58 |
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
| 59 |
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
| 60 |
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
| 61 |
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
| 62 |
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
|
|
| 79 |
|
| 80 |
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
| 81 |
|
| 82 |
+
const int sequence = blockIdx.z / ne02;
|
| 83 |
+
const int head = blockIdx.z - sequence*ne02;
|
| 84 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 85 |
+
Q += nb03*sequence + nb02* head + nb01*ic0;
|
| 86 |
+
K += nb13*sequence + nb12*(head / gqa_ratio);
|
| 87 |
+
V += nb23*sequence + nb22*(head / gqa_ratio);
|
| 88 |
|
| 89 |
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
| 90 |
|
| 91 |
+
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
| 92 |
|
| 93 |
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 94 |
constexpr int nwarps = D / WARP_SIZE;
|
|
|
|
| 330 |
if (gridDim.y == 1) {
|
| 331 |
dst_val /= kqsum[j_VKQ];
|
| 332 |
}
|
| 333 |
+
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
|
|
|
| 334 |
}
|
| 335 |
|
| 336 |
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
| 337 |
+
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
| 338 |
}
|
| 339 |
#else
|
| 340 |
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
|
|
| 343 |
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
| 344 |
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
| 345 |
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
| 346 |
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
| 347 |
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
| 348 |
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
| 349 |
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
| 350 |
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
        const int ne13,
        const int ne31,
        const int ne32,
+        const int ne33,
        const int nb31,
        const int nb32,
+        const int nb33,
        const int nb01,
        const int nb02,
        const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);

+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb03* sequence + nb02* head              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
    const half2 * mask2 = (const half2 *) maskh;

    const int stride_Q  = nb01 / sizeof(float);
    const int stride_KV = nb11 / sizeof(half);

+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half  slopeh = __float2half(slopef);
    const half2 slope2 = make_half2(slopef, slopef);

@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
        if (ic0 + j_VKQ >= ne01) {
            return;
        }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;

        float KQ_rowsum_j;
        if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
        }

+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
#pragma unroll
        for (int i0 = 0; i0 < D; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
            if (gridDim.y == 1) {
                dst_val /= KQ_rowsum_j;
            }
+            dst[j_dst_unrolled*D + i] = dst_val;
        }

        if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
        }
        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
    }
#else
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
            if (op->src[0]->ne[0] == 192) {
                return false;
            }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                return false;
            }
@@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                return true;
            }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
        }