Commit 65ab3e8
Parent: ad83dfd
CUDA: deduplicate FlashAttention code (llama/7352)
Changed files:
- ggml-cuda/common.cuh +11 -0
- ggml-cuda/fattn-common.cuh +115 -0
- ggml-cuda/fattn-tile-f16.cu +26 -109
- ggml-cuda/fattn-tile-f32.cu +25 -109
- ggml-cuda/fattn.cu +48 -129
- ggml-cuda/softmax.cu +2 -11
ggml-cuda/common.cuh CHANGED

@@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -

 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}

 //////////////////////
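The new helper folds the ALiBi slope computation that was previously repeated in fattn.cu, fattn-tile-f16.cu, fattn-tile-f32.cu and softmax.cu into one place: head h gets slope base^exp, with base m0 = 2^(-max_bias/n_head_log2) and exponent h+1 for the first n_head_log2 heads, and base m1 = 2^(-max_bias/2/n_head_log2) with odd exponents for the remaining heads. A minimal host-side sketch of the same formula (a hypothetical standalone test program, not part of this commit; the n_head and max_bias values are made up for illustration):

#include <math.h>
#include <stdint.h>
#include <stdio.h>

// Same formula as get_alibi_slope(), evaluated on the host.
static float alibi_slope(float max_bias, uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: every head gets a neutral slope of 1
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, exph);
}

int main() {
    const uint32_t n_head   = 8;    // illustrative values
    const float    max_bias = 8.0f;

    // Same derivation of n_head_log2, m0 and m1 as in launch_fattn() below.
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t h = 0; h < n_head; ++h) {
        printf("head %u: slope = %.6f\n", h, alibi_slope(max_bias, h, n_head_log2, m0, m1));
    }
    return 0;
}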
ggml-cuda/fattn-common.cuh CHANGED

@@ -1,7 +1,44 @@
+#include "common.cuh"
+
+#include <cstdint>
+
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f            // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

+typedef void (* fattn_kernel_t)(
+    const char * __restrict__ Q,
+    const char * __restrict__ K,
+    const char * __restrict__ V,
+    const char * __restrict__ mask,
+    float * __restrict__ dst,
+    float2 * __restrict__ dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const int ne00,
+    const int ne01,
+    const int ne02,
+    const int ne03,
+    const int ne10,
+    const int ne11,
+    const int ne12,
+    const int ne13,
+    const int ne31,
+    const int nb31,
+    const int nb01,
+    const int nb02,
+    const int nb03,
+    const int nb11,
+    const int nb12,
+    const int nb13,
+    const int ne0,
+    const int ne1,
+    const int ne2,
+    const int ne3);
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)

@@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(

     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
+
+template <int D, int parallel_blocks>
+void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F16);
+    GGML_ASSERT(V->type == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        (const char *) K->data,
+        (const char *) V->data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        K->nb[1], K->nb[2], K->nb[3],
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
ggml-cuda/fattn-tile-f16.cu CHANGED

@@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(

     const int stride_KV2 = nb11 / sizeof(half2);

-    half  slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

@@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int D, int cols_per_block, int parallel_blocks>
-void launch_fattn_tile_f16(...) {
-    [~60 removed lines: the per-file copy of the temporary-buffer allocation, grid setup, op_params
-     unpacking, ALiBi constants, flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
-     launch and flash_attn_combine_results pass that launch_fattn in fattn-common.cuh now provides]
-}
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}

 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];

     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) { ... same two-case head-size switch as above, removed ... }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) { ... same two-case head-size switch as above, removed ... }
+    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
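The dispatch policy of ggml_cuda_flash_attn_ext_tile_f16 (and of its f32 counterpart below) is unchanged by this commit: small query batches get a narrower tile and split each head across four parallel blocks, while large batches use the full 32-column tile with a single block per head. A host-side paraphrase of that policy (hypothetical helper, for illustration only):

#include <stdio.h>

// Paraphrase of the launch policy in ggml_cuda_flash_attn_ext_tile_f16/_f32:
// pick the tile width (cols_per_block) and the number of blocks cooperating on
// one attention head (parallel_blocks) from the number of query columns Q->ne[1].
struct tile_config {
    int cols_per_block;
    int parallel_blocks;
};

static tile_config choose_tile_config(int n_queries) {
    if (n_queries <= 16) {
        return {16, 4}; // few queries: narrow tile, split each head across 4 blocks
    }
    if (n_queries <= 32) {
        return {32, 4};
    }
    return {32, 1};     // large batches already fill the GPU without splitting
}

int main() {
    const int n_queries[] = {1, 16, 24, 512};
    for (int n : n_queries) {
        const tile_config c = choose_tile_config(n);
        printf("n_queries=%3d -> cols_per_block=%2d, parallel_blocks=%d\n", n, c.cols_per_block, c.parallel_blocks);
    }
    return 0;
}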
ggml-cuda/fattn-tile-f32.cu CHANGED

@@ -53,17 +53,7 @@ static __global__ void flash_attn_tile_ext_f32(

     const int stride_KV2 = nb11 / sizeof(half2);

-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);

     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

@@ -270,124 +260,50 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }

-template <int D, int cols_per_block, int parallel_blocks>
-void launch_fattn_tile_f32(...) {
-    [~60 removed lines: the per-file copy of the temporary-buffer allocation, grid setup, op_params
-     unpacking, ALiBi constants, flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
-     launch and flash_attn_combine_results pass that launch_fattn in fattn-common.cuh now provides]
-}
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}

 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];

     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) { ... same two-case head-size switch as above, removed ... }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }

     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) { ... same two-case head-size switch as above, removed ... }
+    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
ggml-cuda/fattn.cu CHANGED

@@ -85,19 +85,9 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);

-    half  slopeh = __float2half(1.0f);
-    half2 slope2 = make_half2(slopeh, slopeh);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-        slope2 = make_half2(slopeh, slopeh);
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);

     frag_b Q_b[D/16][ncols/frag_n];

@@ -439,108 +429,37 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");

-template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t>
-[~60 removed lines: the old per-file launcher for flash_attn_ext_f16 — temporary-buffer allocation,
- frag_m and grid setup, op_params unpacking, ALiBi constants, the launch of
- flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>,
- and the flash_attn_combine_results pass — now provided by launch_fattn in fattn-common.cuh]

-template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
+void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

     if (4*blocks_num_pb1 < 2*nsm) {
-        // [removed: call into the old per-file launcher with parallel_blocks == 4]
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
-        // [removed: call into the old per-file launcher with parallel_blocks == 2]
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
-    // [removed: call into the old per-file launcher with parallel_blocks == 1]
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
 }

 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
-
-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];

     ggml_cuda_set_device(ctx.device);
-
-    const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];

     // On AMD the tile kernels perform poorly, use the vec kernel instead:

@@ -582,22 +501,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
This hunk and the four that follow it (@@ -608,22 +527,22 @@, @@ -643,16 +562,16 @@,
@@ -666,22 +585,22 @@ and @@ -694,22 +613,22 @@) apply the same one-line change to every case of
the per-head-size switches in ggml_cuda_flash_attn_ext: each call of the form

-                launch_fattn_f16< 64, cols_per_block, nwarps, float>(/* tensors, nsm, pool, stream */);
+                launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                 break;

is repeated for the head sizes each switch supports (64, 80, 96, 112, 128, 256; one switch keeps its
256 case commented out), with float KQ accumulators in the first two switches and half accumulators
in the remaining three; the default cases keep their GGML_ASSERT(false).
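In the tensor-core path the split factor is not hard-coded per branch: launch_fattn_f16 derives parallel_blocks from how many blocks a parallel_blocks == 1 launch would produce (blocks_num_pb1) relative to the number of streaming multiprocessors (nsm), preferring 4, then 2, then 1, so that small workloads still spread across the whole GPU. A host-side paraphrase of that heuristic (hypothetical helper; the nsm value is made up):

#include <stdio.h>

// Paraphrase of the parallel_blocks selection in launch_fattn_f16: split each
// attention head across more blocks only when a single-block-per-head launch
// would leave most streaming multiprocessors idle.
static int choose_parallel_blocks(int blocks_num_pb1, int nsm) {
    if (4*blocks_num_pb1 < 2*nsm) {
        return 4;
    }
    if (2*blocks_num_pb1 < 2*nsm) {
        return 2;
    }
    return 1;
}

int main() {
    const int nsm = 108; // SM count of a large data-center GPU; illustrative only
    const int workloads[] = {8, 64, 80, 512};
    for (int blocks_num_pb1 : workloads) {
        printf("blocks_num_pb1=%3d -> parallel_blocks=%d\n",
               blocks_num_pb1, choose_parallel_blocks(blocks_num_pb1, nsm));
    }
    return 0;
}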
ggml-cuda/softmax.cu CHANGED

@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "softmax.cuh"

 template <typename T>

@@ -23,17 +24,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;

-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);

     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication