Commit d4c0faf · 1 parent: 315df8c
CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (llama/7681)

- ggml-cuda/fattn-common.cuh  +59 -15
- ggml-cuda/fattn-tile-f16.cu  +2  -2
- ggml-cuda/fattn-tile-f32.cu  +2  -2
- ggml-cuda/fattn-vec-f16.cuh  +3  -1
- ggml-cuda/fattn-vec-f32.cuh  +3  -1
- ggml-cuda/fattn-wmma-f16.cuh +3  -3
- ggml-cuda/fattn.cu           +1  -5
ggml-cuda/fattn-common.cuh CHANGED

@@ -1,6 +1,7 @@
 #pragma once
 
 #include "common.cuh"
+#include "convert.cuh"
 #include "vecdotq.cuh"
 
 #include <cstdint>
@@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)(
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
@@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 template <int D, int parallel_blocks>
-void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const int cols_per_block) {
+void launch_fattn(
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
 
+    ggml_cuda_pool_alloc<half>   K_f16(pool);
+    ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
@@ -667,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
 
     fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
         (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
+        K_data,
+        V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2,
@@ -676,8 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        V->nb[1], V->nb[2], V->nb[3],
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
     );
     CUDA_CHECK(cudaGetLastError());
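A note on the stride rescaling added to launch_fattn above: after a quantized K or V tensor is dequantized into an FP16 pool buffer, its byte strides are converted with nb = nb*bs*sizeof(half)/ts, where bs is the number of values per quantized block (ggml_blck_size) and ts the block's byte size (ggml_type_size). The standalone sketch below only illustrates that arithmetic; the Q8_0 and Q4_0 block parameters used here (32 values per block, 34 and 18 bytes per block) are the usual ggml layouts and are assumed for the example, not taken from this diff.

// Standalone sketch (not part of the commit): how a quantized row stride maps to
// the FP16 stride used after dequantization, following nb_f16 = nb * bs * sizeof(half) / ts.
#include <cassert>
#include <cstddef>
#include <cstdio>

// Assumed ggml block parameters: values per block (bs) and bytes per block (ts).
struct BlockInfo { const char * name; size_t bs; size_t ts; };

int main() {
    const BlockInfo types[] = {
        {"Q8_0", 32, 34},  // 32 int8 weights + fp16 scale   (assumed layout)
        {"Q4_0", 32, 18},  // 32 4-bit weights + fp16 scale  (assumed layout)
    };
    const size_t head_size   = 128; // example K/V head size
    const size_t sizeof_half = 2;   // sizeof(half)

    for (const BlockInfo & t : types) {
        const size_t nb_quant = head_size / t.bs * t.ts;              // bytes per quantized row
        const size_t nb_f16   = nb_quant * t.bs * sizeof_half / t.ts; // rescaled stride
        assert(nb_f16 == head_size * sizeof_half);                    // contiguous FP16 row
        printf("%s: quantized row = %zu bytes -> FP16 row = %zu bytes\n", t.name, nb_quant, nb_f16);
    }
    return 0;
}

The same conversion is applied to nb12/nb13 and nb21..nb23, so the kernel indexes K and V with plain byte strides whether or not a conversion happened.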
ggml-cuda/fattn-tile-f16.cu CHANGED

@@ -278,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml-cuda/fattn-tile-f32.cu CHANGED

@@ -275,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml-cuda/fattn-vec-f16.cuh CHANGED

@@ -290,7 +290,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
ggml-cuda/fattn-vec-f32.cuh CHANGED

@@ -271,7 +271,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
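Both vec kernel wrappers above compute the new flags from the head size alone: as the expressions encode, FP16 K is only skipped for D == 128 and FP16 V only for D == 64 or D == 128, so every other head size asks launch_fattn to dequantize first (the tile and WMMA paths simply pass true, true). A small standalone sketch of that decision, with the head sizes picked purely for illustration:

// Standalone sketch (illustrative, not from the commit): which head sizes make the
// vec kernel wrappers request an FP16 copy of K and/or V in launch_fattn.
#include <cstdio>

static void report(int D) {
    const bool need_f16_K = D != 128;             // same expression as in the wrappers
    const bool need_f16_V = D != 128 && D != 64;  // same expression as in the wrappers
    printf("D = %3d: need_f16_K = %d, need_f16_V = %d\n", D, need_f16_K, need_f16_V);
}

int main() {
    const int head_sizes[] = {64, 128, 256};      // example head sizes only
    for (int D : head_sizes) {
        report(D);
    }
    return 0;
}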
ggml-cuda/fattn-wmma-f16.cuh CHANGED

@@ -438,18 +438,18 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
 
 #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
ggml-cuda/fattn.cu CHANGED

@@ -298,17 +298,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
-    const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
-
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD || quantized_KV) {
+    if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {