JohannesGaessler committed
Commit 65ab3e8 · 1 Parent(s): ad83dfd

CUDA: deduplicate FlashAttention code (llama/7352)
ggml-cuda/common.cuh CHANGED
@@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
 
 //////////////////////
 
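Note: get_alibi_slope centralizes the ALiBi slope computation that was previously duplicated in every FlashAttention kernel and in soft_max_f32. As a sketch of the math it encodes (variable names as in the code; m0, m1 and n_head_log2 are derived from max_bias and the head count by the host-side launch code later in this commit):

\[ n\_head\_log2 = 2^{\lfloor \log_2 n\_head \rfloor}, \qquad m_0 = 2^{-max\_bias / n\_head\_log2}, \qquad m_1 = 2^{-max\_bias / (2\,n\_head\_log2)} \]
\[ \mathrm{slope}(h) = \begin{cases} 1 & max\_bias \le 0 \\ m_0^{\,h+1} & h < n\_head\_log2 \\ m_1^{\,2(h - n\_head\_log2) + 1} & \text{otherwise} \end{cases} \]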
ggml-cuda/fattn-common.cuh CHANGED
@@ -1,7 +1,44 @@
+#include "common.cuh"
+
+#include <cstdint>
+
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+typedef void (* fattn_kernel_t)(
+    const char * __restrict__ Q,
+    const char * __restrict__ K,
+    const char * __restrict__ V,
+    const char * __restrict__ mask,
+    float      * __restrict__ dst,
+    float2     * __restrict__ dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const int ne00,
+    const int ne01,
+    const int ne02,
+    const int ne03,
+    const int ne10,
+    const int ne11,
+    const int ne12,
+    const int ne13,
+    const int ne31,
+    const int nb31,
+    const int nb01,
+    const int nb02,
+    const int nb03,
+    const int nb11,
+    const int nb12,
+    const int nb13,
+    const int ne0,
+    const int ne1,
+    const int ne2,
+    const int ne3);
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(
 
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
+
+template <int D, int parallel_blocks>
+void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F16);
+    GGML_ASSERT(V->type == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        (const char *) K->data,
+        (const char *) V->data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        K->nb[1], K->nb[2], K->nb[3],
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
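Note: launch_fattn is the shared host-side launcher that replaces the per-file launch_fattn_* copies: it reads scale and max_bias from the op params, derives the ALiBi parameters, launches whichever kernel it is handed, and for parallel_blocks > 1 reduces the partial results with flash_attn_combine_results. A minimal sketch of how a caller wires a kernel into it, mirroring the pattern the wrappers in the files below use (the concrete template arguments are illustrative, not mandated by this commit):

    // hypothetical call site; D, cols_per_block, nwarps and parallel_blocks are example values
    constexpr int D               = 64;  // head size
    constexpr int cols_per_block  = 16;  // query columns handled per block
    constexpr int nwarps          = 8;   // warps per block
    constexpr int parallel_blocks = 4;   // split along the KV dimension, combined afterwards
    fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);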
ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = 8;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
        return;
    }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
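Note: the dispatch thresholds in ggml_cuda_flash_attn_ext_tile_f16 are unchanged; only the plumbing moved into launch_fattn_tile_f16_64_128 and launch_fattn. For example (values read off the code above, not new behaviour), a batch with Q->ne[1] = 8 queries and head size Q->ne[0] = 128 takes the Q->ne[1] <= 16 branch and resolves to:

    // cols_per_block = 16, parallel_blocks = 4; case 128 -> D = 128, nwarps = 8
    fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<128, 16, 8, 4>;
    launch_fattn<128, 4>(ctx, dst, fattn_kernel, /*nwarps=*/8, /*cols_per_block=*/16);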
ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -53,17 +53,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -270,124 +260,50 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f32(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = 8;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
ggml-cuda/fattn.cu CHANGED
@@ -85,19 +85,9 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    half  slopeh = __float2half(1.0f);
-    half2 slope2 = make_half2(1.0f, 1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-        slope2 = make_half2(slopeh, slopeh);
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
 
     frag_b Q_b[D/16][ncols/frag_n];
 
@@ -439,108 +429,37 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 
-template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if ((parallel_blocks) == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
+void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
 
-template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     if (4*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        return;
    }
-    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
-
-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     ggml_cuda_set_device(ctx.device);
-
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
@@ -582,22 +501,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                 break;
            case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                 break;
            case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
                 break;
            default:
                GGML_ASSERT(false);
@@ -608,22 +527,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                 break;
            case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                 break;
            // case 256:
-            //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            //     launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
            //     break;
            default:
                GGML_ASSERT(false);
@@ -643,16 +562,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
            default:
                GGML_ASSERT(false);
@@ -666,22 +585,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
                 break;
            case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
            case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
            default:
                GGML_ASSERT(false);
@@ -694,22 +613,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
                 break;
            case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
            case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
            default:
                GGML_ASSERT(false);
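Note: the parallel_blocks heuristic in launch_fattn_f16 is carried over unchanged: the KV dimension is only split into extra blocks when a single pass would leave the GPU underoccupied. As a worked example under an assumed SM count (nsm = 108, a hypothetical A100-class device, not a value from this commit), with cols_per_block = 32 and 32 heads: a 512-token prompt gives blocks_num_pb1 = (512/32)*32 = 512, so 4*512 and 2*512 both exceed 2*nsm = 216 and parallel_blocks = 1 is used; single-token decoding (Q->ne[1] = 1) gives blocks_num_pb1 = 32, and since 4*32 = 128 < 216, parallel_blocks = 4 is used and the partial results are reduced by flash_attn_combine_results inside launch_fattn.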
ggml-cuda/softmax.cu CHANGED
@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "softmax.cuh"
 
 template <typename T>
@@ -23,17 +24,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
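Note: in soft_max_f32 the head index is recovered as rowx/nrows_y, so the shared helper yields the same per-head slope as the FlashAttention kernels. The slope enters the usual ALiBi bias term of the fused softmax, i.e. (a sketch of the intended math, assuming the existing soft_max semantics rather than code from this commit):

\[ p_{h,i,j} = \mathrm{softmax}_j\big( scale \cdot x_{h,i,j} + \mathrm{slope}(h) \cdot mask_{i,j} \big) \]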