uvos committed
Commit: 2adc060
Parent(s): 1e50161

CUDA/HIP: Fix fattn-vec-* when device warp size is not 32 (llama/12315)

When fattn-wmma was ported over to warp64, various bits that also touch fattn-vec were converted to a selectable warp size. However, the fattn-vec kernels don't work with 64-wide warps for now, so we need to avoid launching them with warp64 parameters.
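The fix keeps the warp size selectable but makes 32 the default at every call site, so fattn-vec launches stay 32-wide while the wmma path explicitly passes the physical device warp size (64 on AMD GCN/CDNA). A rough standalone C++ illustration of that calling pattern; the launch and device_warp_size helpers below are hypothetical stand-ins, not the real ggml API:

#include <cstdio>

constexpr int WARP_SIZE = 32;                 // compile-time default, as in ggml-cuda
static int device_warp_size() { return 64; }  // hypothetical stand-in for the device query

// Hypothetical stand-in for launch_fattn: callers that omit warp_size keep 32-wide launches.
static void launch(const char * name, int nwarps, int warp_size = WARP_SIZE) {
    std::printf("%-10s block_dim.x = %3d (nwarps = %d, warp_size = %d)\n",
                name, nwarps * warp_size, nwarps, warp_size);
}

int main() {
    launch("fattn-vec",  8);                      // default: stays 32-wide even on warp64 hardware
    launch("fattn-wmma", 4, device_warp_size());  // wmma opts in to the device warp size
    return 0;
}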
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED

@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
         nullptr;
 }
 
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+    const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -704,8 +699,6 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
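Inside fattn-common.cuh the same defaulting happens at compile time: warp_size moves from a per-function constexpr (previously taken from ggml_cuda_get_physical_warp_size()) to a trailing template parameter, and get_vec_dot_KQ_f16/get_vec_dot_KQ_f32 default it to WARP_SIZE, so existing fattn-vec instantiations keep selecting 32-wide vec-dot functions unless a caller opts in explicitly. A small standalone sketch of that forwarding; vec_dot_stub and get_vec_dot are simplified stand-ins, not the real vec-dot signatures, and the assumption that warp_size only drives compile-time loop bounds is illustrative:

#include <cstdio>

constexpr int WARP_SIZE = 32;

// Stub standing in for vec_dot_fattn_vec_KQ_*: keeping warp_size a template parameter
// keeps loop bounds (and hence unrolling) compile-time constants.
template <typename T, int D, int warp_size>
static T vec_dot_stub() {
    T sum = T(0);
    for (int i = 0; i < D / warp_size; ++i) {  // stride depends on the warp width
        sum += T(1);
    }
    return sum;
}

typedef float (*vec_dot_f32_t)();

// Stub selector mirroring get_vec_dot_KQ_f32<D, warp_size = WARP_SIZE>: omitting the
// second argument keeps the 32-wide instantiation that fattn-vec relies on.
template <int D, int warp_size = WARP_SIZE>
constexpr vec_dot_f32_t get_vec_dot() {
    return vec_dot_stub<float, D, warp_size>;
}

int main() {
    std::printf("32-wide: %.0f iterations\n", get_vec_dot<128>()());     // default warp_size
    std::printf("64-wide: %.0f iterations\n", get_vec_dot<128, 64>()()); // explicit warp64
    return 0;
}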
ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED

@@ -469,6 +469,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -485,7 +486,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
     return;
 }
 if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,7 +501,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
     return;
 }
 constexpr int parallel_blocks = 1;
@@ -514,7 +515,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {