JohannesGaessler committed
Commit 315df8c · 1 Parent(s): c1442f3

CUDA: quantized KV support for FA vec (llama/7527)

* CUDA: quantized KV support for FA vec

* try CI fix

* fix commented-out kernel variants

* add q8_0 q4_0 tests

* fix nwarps > batch size

* split fattn compile via extern templates

* fix flake8

* fix metal tests

* fix cmake

* make generate_cu_files.py executable

* add autogenerated .cu files

* fix AMD

* error if type_v != FP16 and not flash_attn

* remove obsolete code
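
The gist of the change, visible in the fattn-common.cuh and fattn-vec-*.cuh diffs below: when the K cache is quantized (q4_0, q4_1, q5_0, q5_1 or q8_0), each Q row is first quantized to q8_1 in shared memory and the K·Q dot products are computed with integer __dp4a instructions; when the V cache is quantized, V is dequantized one value at a time while accumulating V·softmax(KQ). A minimal standalone sketch of the integer dot-product idea (simplified to a single 4-value group; dot_q8_block4 is an illustrative name, not a helper from this commit):

// Sketch: how a quantized K row contributes to a KQ dot product.
// Four int8 K values are packed in one int, four int8 Q values in another;
// __dp4a (SM 6.1+) does the 4-way multiply-accumulate in one instruction and
// the two per-block scales restore the floating-point magnitude. The real
// kernels do this per 32-value block via vec_dot_q8_0_q8_1_impl and friends.
static __device__ __forceinline__ float dot_q8_block4(int K_packed, float K_scale,
                                                      int Q_packed, float Q_scale) {
    const int sumi = __dp4a(K_packed, Q_packed, 0); // sum of four int8*int8 products
    return K_scale * Q_scale * (float) sumi;
}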

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.

Files changed (50):
  1. ggml-cuda.cu +6 -2
  2. ggml-cuda/fattn-common.cuh +537 -2
  3. ggml-cuda/fattn-tile-f16.cu +3 -0
  4. ggml-cuda/fattn-tile-f32.cu +3 -0
  5. ggml-cuda/fattn-vec-f16.cuh +392 -2
  6. ggml-cuda/fattn-vec-f32.cuh +374 -1
  7. ggml-cuda/fattn-wmma-f16.cuh +490 -0
  8. ggml-cuda/fattn.cu +232 -521
  9. ggml-cuda/mmq.cu +2 -2
  10. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  11. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  12. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  13. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  14. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  15. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  16. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  17. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  18. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  19. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  20. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  21. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  22. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  23. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  24. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  25. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  26. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  27. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  28. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  29. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  30. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  31. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  32. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  33. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  34. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  35. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  36. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  37. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  38. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  39. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  40. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  41. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  42. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  43. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  44. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  45. ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  46. ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  47. ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  48. ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  49. ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  50. ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
ggml-cuda.cu CHANGED
@@ -2905,10 +2905,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
 #else
-            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+            if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
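
In words: the non-HIP support check for FLASH_ATTN_EXT now accepts any K/V type for head size 128 (op->src[0]->ne[0]), still requires an f16 K cache for head size 64, and otherwise requires Volta or newer together with f16 K and V. A standalone paraphrase of that predicate for readability (fattn_vec_supported is an illustrative name; the real check lives inline in ggml_backend_cuda_supports_op):

// Paraphrase of the updated GGML_OP_FLASH_ATTN_EXT support check (non-HIP path).
static bool fattn_vec_supported(int64_t head_size, ggml_type type_K, ggml_type type_V, int cc) {
    if (head_size == 128) {
        return true;                                   // quantized and f16 K/V both have kernels
    }
    if (head_size == 64 && type_K == GGML_TYPE_F16) {
        return true;                                   // head size 64 keeps the f16 K requirement
    }
    return cc >= CC_VOLTA && type_K == GGML_TYPE_F16 && type_V == GGML_TYPE_F16;
}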
ggml-cuda/fattn-common.cuh CHANGED
@@ -1,4 +1,7 @@
+#pragma once
+
 #include "common.cuh"
+#include "vecdotq.cuh"
 
 #include <cstdint>
 
@@ -34,11 +37,523 @@ typedef void (* fattn_kernel_t)(
 
 
37
  const int nb11,
38
  const int nb12,
39
  const int nb13,
40
+ const int nb21,
41
+ const int nb22,
42
+ const int nb23,
43
  const int ne0,
44
  const int ne1,
45
  const int ne2,
46
  const int ne3);
47
 
48
+ typedef half (*vec_dot_KQ_f16_t)(
49
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
50
+ typedef float (*vec_dot_KQ_f32_t)(
51
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
52
+
53
+ template<typename T, int D>
54
+ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
55
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
56
+ #if __CUDA_ARCH__ > MIN_CC_DP4A
57
+
58
+ const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
59
+ GGML_UNUSED(Q_v);
60
+
61
+ half sum = 0.0f;
62
+
63
+ #pragma unroll
64
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
65
+ const int k_KQ = k_KQ_0 + threadIdx.x;
66
+
67
+ const int ib = k_KQ / QI8_1;
68
+ const int iqs4 = k_KQ % QI4_0;
69
+ const int shift = k_KQ & (QI8_1/2);
70
+
71
+ const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
72
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
73
+
74
+ const int sumi = __dp4a(v, u, 0);
75
+
76
+ #if FP16_AVAILABLE
77
+ if (std::is_same<T, half>::value) {
78
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
79
+
80
+ const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
81
+ sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
82
+ } else
83
+ #endif // FP16_AVAILABLE
84
+ {
85
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
86
+
87
+ sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
88
+ }
89
+ }
90
+
91
+ return sum;
92
+ #else
93
+ GGML_UNUSED(K_c);
94
+ GGML_UNUSED(Q_v);
95
+ GGML_UNUSED(Q_q8);
96
+ GGML_UNUSED(Q_ds_v);
97
+ NO_DEVICE_CODE;
98
+ #endif // __CUDA_ARCH__ > MIN_CC_DP4A
99
+ }
100
+
101
+ template<typename T, int D>
102
+ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
103
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
104
+ #if __CUDA_ARCH__ > MIN_CC_DP4A
105
+
106
+ const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
107
+ GGML_UNUSED(Q_v);
108
+
109
+ T sum = 0.0f;
110
+
111
+ #pragma unroll
112
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
113
+ const int k_KQ = k_KQ_0 + threadIdx.x;
114
+
115
+ const int ib = k_KQ / QI8_1;
116
+ const int iqs4 = k_KQ % QI4_1;
117
+ const int shift = k_KQ & (QI8_1/2);
118
+
119
+ const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
120
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
121
+
122
+ const int sumi = __dp4a(v, u, 0);
123
+
124
+ #if FP16_AVAILABLE
125
+ if (std::is_same<T, half>::value) {
126
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
127
+
128
+ const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
129
+ const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
130
+ sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
131
+ } else
132
+ #endif // FP16_AVAILABLE
133
+ {
134
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
135
+
136
+ const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
137
+ const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
138
+
139
+ sum += (T) (sumid4d8 + m4s8scaled);
140
+ }
141
+ }
142
+
143
+ return sum;
144
+ #else
145
+ GGML_UNUSED(K_c);
146
+ GGML_UNUSED(Q_v);
147
+ GGML_UNUSED(Q_q8);
148
+ GGML_UNUSED(Q_ds_v);
149
+ NO_DEVICE_CODE;
150
+ #endif // __CUDA_ARCH__ > MIN_CC_DP4A
151
+ }
152
+
153
+ template<typename T, int D>
154
+ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
155
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
156
+ #if __CUDA_ARCH__ > MIN_CC_DP4A
157
+
158
+ const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
159
+ GGML_UNUSED(Q_v);
160
+
161
+ T sum = 0.0f;
162
+
163
+ #pragma unroll
164
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
165
+ const int k_KQ = k_KQ_0 + threadIdx.x;
166
+
167
+ const int ib = k_KQ / QI8_1;
168
+ const int iqs4 = k_KQ % QI5_0;
169
+ const int iqs8 = k_KQ % QI8_1;
170
+ const int shift = k_KQ & (QI8_1/2);
171
+
172
+ int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
173
+ const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
174
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
175
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
176
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
177
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
178
+
179
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
180
+
181
+ const int sumi = __dp4a(v, u, 0);
182
+
183
+ #if FP16_AVAILABLE
184
+ if (std::is_same<T, half>::value) {
185
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
186
+
187
+ const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
188
+ sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
189
+ } else
190
+ #endif // FP16_AVAILABLE
191
+ {
192
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
193
+
194
+ sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
195
+ }
196
+ }
197
+
198
+ return sum;
199
+ #else
200
+ GGML_UNUSED(K_c);
201
+ GGML_UNUSED(Q_v);
202
+ GGML_UNUSED(Q_q8);
203
+ GGML_UNUSED(Q_ds_v);
204
+ NO_DEVICE_CODE;
205
+ #endif // __CUDA_ARCH__ > MIN_CC_DP4A
206
+ }
207
+
208
+ template<typename T, int D>
209
+ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
210
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
211
+ #if __CUDA_ARCH__ > MIN_CC_DP4A
212
+
213
+ const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
214
+ GGML_UNUSED(Q_v);
215
+
216
+ T sum = 0.0f;
217
+
218
+ #pragma unroll
219
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
220
+ const int k_KQ = k_KQ_0 + threadIdx.x;
221
+
222
+ const int ib = k_KQ / QI8_1;
223
+ const int iqs4 = k_KQ % QI5_1;
224
+ const int iqs8 = k_KQ % QI8_1;
225
+ const int shift = k_KQ & (QI8_1/2);
226
+
227
+ int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
228
+ const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
229
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
230
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
231
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
232
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
233
+
234
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
235
+
236
+ const int sumi = __dp4a(v, u, 0);
237
+
238
+ #if FP16_AVAILABLE
239
+ if (std::is_same<T, half>::value) {
240
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
241
+
242
+ const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
243
+ const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
244
+ sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
245
+ } else
246
+ #endif // FP16_AVAILABLE
247
+ {
248
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
249
+
250
+ const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
251
+ const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
252
+
253
+ sum += (T) (sumid5d8 + m5s8scaled);
254
+ }
255
+ }
256
+
257
+ return sum;
258
+ #else
259
+ GGML_UNUSED(K_c);
260
+ GGML_UNUSED(Q_v);
261
+ GGML_UNUSED(Q_q8);
262
+ GGML_UNUSED(Q_ds_v);
263
+ NO_DEVICE_CODE;
264
+ #endif // __CUDA_ARCH__ > MIN_CC_DP4A
265
+ }
266
+
267
+ template <typename T, int D>
268
+ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
269
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
270
+ #if __CUDA_ARCH__ > MIN_CC_DP4A
271
+
272
+ const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
273
+ GGML_UNUSED(Q_v);
274
+
275
+ T sum = 0.0f;
276
+
277
+ #pragma unroll
278
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
279
+ const int k_KQ = k_KQ_0 + threadIdx.x;
280
+
281
+ const int ib = k_KQ / QI8_0;
282
+ const int iqs = k_KQ % QI8_0;
283
+
284
+ const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
285
+
286
+ T Q_d;
287
+ if (std::is_same<T, half>::value) {
288
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
289
+ Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
290
+ } else {
291
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
292
+ Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
293
+ }
294
+
295
+ sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
296
+ }
297
+
298
+ return sum;
299
+ #else
300
+ GGML_UNUSED(K_c);
301
+ GGML_UNUSED(Q_v);
302
+ GGML_UNUSED(Q_q8);
303
+ GGML_UNUSED(Q_ds_v);
304
+ NO_DEVICE_CODE;
305
+ #endif // __CUDA_ARCH__ > MIN_CC_DP4A
306
+ }
307
+
308
+ template <typename T, int D>
309
+ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
310
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
311
+
312
+ const half2 * K_h2 = (const half2 *) K_c;
313
+ GGML_UNUSED(Q_q8);
314
+ GGML_UNUSED(Q_ds_v);
315
+
316
+ #if FP16_AVAILABLE
317
+ if (std::is_same<T, half>::value) {
318
+ const half2 * Q_h2 = (const half2 *) Q_v;
319
+
320
+ half2 sum2 = make_half2(0.0f, 0.0f);
321
+
322
+ #pragma unroll
323
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
324
+ const int k_KQ = k_KQ_0 + threadIdx.x;
325
+
326
+ const half2 K_ik = K_h2[k_KQ];
327
+ sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
328
+ }
329
+
330
+ return __low2half(sum2) + __high2half(sum2);
331
+ }
332
+ #endif // FP16_AVAILABLE
333
+
334
+ const float2 * Q_f2 = (const float2 *) Q_v;
335
+
336
+ float sum = 0.0f;
337
+
338
+ #pragma unroll
339
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
340
+ const int k_KQ = k_KQ_0 + threadIdx.x;
341
+
342
+ const half2 K_ik = K_h2[k_KQ];
343
+ sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
344
+ sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
345
+ }
346
+
347
+ return sum;
348
+ }
349
+
350
+ template <typename Tds>
351
+ static __device__ __forceinline__ void quantize_q8_1_to_shared(
352
+ const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
353
+
354
+ float vals[sizeof(int)] = {0.0f};
355
+ #pragma unroll
356
+ for (int l = 0; l < sizeof(int); ++l) {
357
+ vals[l] = scale * x[4*threadIdx.x + l];
358
+ }
359
+
360
+ float amax = fabsf(vals[0]);
361
+ float sum = vals[0];
362
+ #pragma unroll
363
+ for (int l = 1; l < sizeof(int); ++l) {
364
+ amax = fmaxf(amax, fabsf(vals[l]));
365
+ sum += vals[l];
366
+ }
367
+ #pragma unroll
368
+ for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
369
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
370
+ sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
371
+ }
372
+
373
+ const float d = amax / 127;
374
+ int q32 = 0;
375
+ int8_t * q8 = (int8_t *) &q32;
376
+
377
+ if (d != 0.0f) {
378
+ #pragma unroll
379
+ for (int l = 0; l < sizeof(int); ++l) {
380
+ q8[l] = roundf(vals[l] / d);
381
+ }
382
+ }
383
+
384
+ yq32[threadIdx.x] = q32;
385
+ if (threadIdx.x % QI8_1 == 0) {
386
+ if (std::is_same<Tds, half2>::value) {
387
+ ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
388
+ } else {
389
+ ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
390
+ }
391
+ }
392
+ }
393
+
394
+ typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
395
+ typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
396
+
397
+ template <typename T>
398
+ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
399
+ const block_q4_0 * x = (const block_q4_0 *) vx;
400
+
401
+ const int64_t ib = i / QK4_0;
402
+ const int iqs = i % (QK4_0/2);
403
+ const int shift = (i % QK4_0) / (QK4_0/2);
404
+
405
+ const T d = x[ib].d;
406
+ const int q0 = x[ib].qs[iqs];
407
+ const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
408
+
409
+ #if FP16_AVAILABLE
410
+ if (std::is_same<T, half>::value) {
411
+ return ((half) d)*((half) q);
412
+ }
413
+ #endif // FP16_AVAILABLE
414
+
415
+ return ((float) d)*((float) q);
416
+ }
417
+
418
+ template <typename T>
419
+ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
420
+ const block_q4_1 * x = (const block_q4_1 *) vx;
421
+
422
+ const int64_t ib = i / QK4_1;
423
+ const int iqs = i % (QK4_1/2);
424
+ const int shift = (i % QK4_1) / (QK4_1/2);
425
+
426
+ const half2 dm = x[ib].dm;
427
+ const int q0 = x[ib].qs[iqs];
428
+ const int q = ((q0 >> (4*shift)) & 0x0F);
429
+
430
+ #if FP16_AVAILABLE
431
+ if (std::is_same<T, half>::value) {
432
+ return __low2half(dm)*((half) q) + __high2half(dm);
433
+ }
434
+ #endif // FP16_AVAILABLE
435
+
436
+ return __low2float(dm)*((float) q) + __high2float(dm);
437
+ }
438
+
439
+ template <typename T>
440
+ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
441
+ const block_q5_0 * x = (const block_q5_0 *) vx;
442
+
443
+ const int64_t ib = i / QK5_0;
444
+ const int idq = i % QK5_0;
445
+ const int iqs = i % (QK5_0/2);
446
+ const int shift = (i % QK5_0) / (QK5_0/2);
447
+
448
+ const T d = x[ib].d;
449
+ const int ql0 = x[ib].qs[iqs];
450
+ const int qh0 = get_int_from_uint8(x[ib].qh, 0);
451
+ const int ql = ((ql0 >> (4*shift)) & 0x0F);
452
+ const int qh = ((qh0 >> idq) << 4) & 0x10;
453
+ const int q = (ql | qh) - 16;
454
+
455
+ #if FP16_AVAILABLE
456
+ if (std::is_same<T, half>::value) {
457
+ return ((half) d)*((half) q);
458
+ }
459
+ #endif // FP16_AVAILABLE
460
+
461
+ return ((float) d)*((float) q);
462
+ }
463
+
464
+ template <typename T>
465
+ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
466
+ const block_q5_1 * x = (const block_q5_1 *) vx;
467
+
468
+ const int64_t ib = i / QK5_1;
469
+ const int idq = i % QK5_1;
470
+ const int iqs = i % (QK5_1/2);
471
+ const int shift = (i % QK5_1) / (QK5_1/2);
472
+
473
+ const half2 dm = x[ib].dm;
474
+ const int ql0 = x[ib].qs[iqs];
475
+ const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
476
+ const int ql = ((ql0 >> (4*shift)) & 0x0F);
477
+ const int qh = ((qh0 >> idq) << 4) & 0x10;
478
+ const int q = (ql | qh);
479
+
480
+ #if FP16_AVAILABLE
481
+ if (std::is_same<T, half>::value) {
482
+ return __low2half(dm)*((half) q) + __high2half(dm);
483
+ }
484
+ #endif // FP16_AVAILABLE
485
+
486
+ return __low2float(dm)*((float) q) + __high2float(dm);
487
+ }
488
+
489
+ template <typename T>
490
+ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
491
+ const block_q8_0 * x = (const block_q8_0 *) vx;
492
+
493
+ const int64_t ib = i / QK8_0;
494
+ const int iqs = i % QK8_0;
495
+
496
+ const T d = x[ib].d;
497
+ const int q = x[ib].qs[iqs];
498
+
499
+ #if FP16_AVAILABLE
500
+ if (std::is_same<T, half>::value) {
501
+ return ((half) d)*((half) q);
502
+ }
503
+ #endif // FP16_AVAILABLE
504
+
505
+ return ((float) d)*((float) q);
506
+ }
507
+
508
+ template <typename T>
509
+ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
510
+ const half * x = (const half *) vx;
511
+
512
+ return x[i];
513
+ }
514
+
515
+ template <int D>
516
+ constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
517
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
518
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
519
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
520
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
521
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
522
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
523
+ nullptr;
524
+ }
525
+
526
+ template <int D>
527
+ constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
528
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
529
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
530
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
531
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
532
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
533
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
534
+ nullptr;
535
+ }
536
+
537
+ constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
538
+ return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
539
+ type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
540
+ type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
541
+ type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
542
+ type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
543
+ type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
544
+ nullptr;
545
+ }
546
+
547
+ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
548
+ return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
549
+ type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
550
+ type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
551
+ type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
552
+ type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
553
+ type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
554
+ nullptr;
555
+ }
556
+
557
  template<int D, int parallel_blocks> // D == head size
558
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
559
  __launch_bounds__(D, 1)
 
@@ -83,6 +598,27 @@ static __global__ void flash_attn_combine_results(
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
+static void on_no_fattn_vec_case(const int D) {
+    if (D == 64) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
+        fprintf(stderr, "By default only f16 KV cache is supported.\n");
+        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
+        GGML_ASSERT(false);
+    } else if (D == 128) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
+        fprintf(stderr, "Supported combinations:\n");
+        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
+        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
+        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
+        fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        GGML_ASSERT(false);
+    } else {
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Only f16 is supported.\n");
+        GGML_ASSERT(false);
+    }
+}
+
 template <int D, int parallel_blocks>
 void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
     const ggml_tensor * Q = dst->src[0];
@@ -94,8 +630,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
@@ -143,6 +677,7 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         K->nb[1], K->nb[2], K->nb[3],
+        V->nb[1], V->nb[2], V->nb[3],
         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
     );
     CUDA_CHECK(cudaGetLastError());
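
The BPV (bits per value) numbers quoted by on_no_fattn_vec_case follow directly from the ggml block layouts: a q4_0 block packs 32 values into 16 bytes of 4-bit quants plus one f16 scale, a q8_0 block packs 32 values into 32 bytes of 8-bit quants plus one f16 scale, and f16 simply stores 2 bytes per value:

q4_0: (16 + 2) bytes * 8 bits / 32 values =  4.50 BPV
q8_0: (32 + 2) bytes * 8 bits / 32 values =  8.50 BPV
f16 :       2  bytes * 8 bits /  1 value  = 16.00 BPV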
ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f16(
     const int nb11,
     const int nb12,
     const int nb13,
+    const int nb21,
+    const int nb22,
+    const int nb23,
     const int ne0,
     const int ne1,
     const int ne2,
ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f32(
     const int nb11,
     const int nb12,
     const int nb13,
+    const int nb21,
+    const int nb22,
+    const int nb23,
     const int ne0,
     const int ne1,
     const int ne2,
ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -1,5 +1,395 @@
-void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-
-void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
1
  #include "common.cuh"
2
+ #include "fattn-common.cuh"
3
 
4
+ template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
5
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
6
+ __launch_bounds__(D, 1)
7
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
8
+ static __global__ void flash_attn_vec_ext_f16(
9
+ const char * __restrict__ Q,
10
+ const char * __restrict__ K,
11
+ const char * __restrict__ V,
12
+ const char * __restrict__ mask,
13
+ float * __restrict__ dst,
14
+ float2 * __restrict__ dst_meta,
15
+ const float scale,
16
+ const float max_bias,
17
+ const float m0,
18
+ const float m1,
19
+ const uint32_t n_head_log2,
20
+ const int ne00,
21
+ const int ne01,
22
+ const int ne02,
23
+ const int ne03,
24
+ const int ne10,
25
+ const int ne11,
26
+ const int ne12,
27
+ const int ne13,
28
+ const int ne31,
29
+ const int nb31,
30
+ const int nb01,
31
+ const int nb02,
32
+ const int nb03,
33
+ const int nb11,
34
+ const int nb12,
35
+ const int nb13,
36
+ const int nb21,
37
+ const int nb22,
38
+ const int nb23,
39
+ const int ne0,
40
+ const int ne1,
41
+ const int ne2,
42
+ const int ne3) {
43
+ #if FP16_AVAILABLE
44
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
45
 
46
+ constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
47
+ constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
48
+ constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
49
+
50
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
51
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
52
+
53
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
54
+ Q += nb02* blockIdx.y + nb01*ic0;
55
+ K += nb12*(blockIdx.y / gqa_ratio);
56
+ V += nb22*(blockIdx.y / gqa_ratio);
57
+
58
+ const half * maskh = (const half *) mask + ne11*ic0;
59
+
60
+ const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
61
+ const half slopeh = __float2half(slopef);
62
+
63
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
64
+ constexpr int nwarps = D / WARP_SIZE;
65
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
66
+ __builtin_assume(tid < D);
67
+
68
+ __shared__ half KQ[ncols*D];
69
+ half2 * KQ2 = (half2 *) KQ;
70
+
71
+ half kqmax[ncols];
72
+ #pragma unroll
73
+ for (int j = 0; j < ncols; ++j) {
74
+ kqmax[j] = -HALF_MAX_HALF;
75
+ }
76
+ half kqsum[ncols] = {0.0f};
77
+
78
+ __shared__ half kqmax_shared[ncols][WARP_SIZE];
79
+ __shared__ half kqsum_shared[ncols][WARP_SIZE];
80
+ #pragma unroll
81
+ for (int j = 0; j < ncols; ++j) {
82
+ if (threadIdx.y == 0) {
83
+ kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
84
+ kqsum_shared[j][threadIdx.x] = 0.0f;
85
+ }
86
+ }
87
+ __syncthreads();
88
+
89
+ // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
90
+ half2 Q_h2[ncols][D/(2*WARP_SIZE)];
91
+ int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
92
+ half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
93
+ if (Q_q8_1) {
94
+ #pragma unroll
95
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
96
+ const int j = j0 + threadIdx.y;
97
+
98
+ if (j0 + nwarps > ncols && j >= ncols) {
99
+ break;
100
+ }
101
+
102
+ // Reuse KQ as temporary storage for converting Q to q8_1:
103
+ int * tmp_q_i32 = (int *) &KQ[j*D];
104
+ half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
105
+
106
+ // Set memory to zero if out of bounds:
107
+ if (ncols > 2 && ic0 + j >= ne01) {
108
+ #pragma unroll
109
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
110
+ const int i = i0 + threadIdx.x;
111
+
112
+ tmp_q_i32[i] = 0;
113
+ }
114
+ if (threadIdx.x < D/QK8_1) {
115
+ tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
116
+ }
117
+ continue;
118
+ }
119
+
120
+ const float * Q_f = (const float *) (Q + j*nb01);
121
+ #pragma unroll
122
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
123
+ quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
124
+ }
125
+ }
126
+
127
+ __syncthreads();
128
+
129
+ #pragma unroll
130
+ for (int j = 0; j < ncols; ++j) {
131
+ int * tmp_q_i32 = (int *) &KQ[j*D];
132
+ half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
133
+
134
+ #pragma unroll
135
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
136
+ const int i = i0 + threadIdx.x;
137
+
138
+ Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
139
+ Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
140
+ }
141
+ }
142
+
143
+ __syncthreads();
144
+ } else {
145
+ #pragma unroll
146
+ for (int j = 0; j < ncols; ++j) {
147
+ const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
148
+
149
+ #pragma unroll
150
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
151
+ const int i = i0 + threadIdx.x;
152
+
153
+ const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
154
+ Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
155
+ }
156
+ }
157
+ }
158
+
159
+
160
+ #pragma unroll
161
+ for (int j = 0; j < ncols; ++j) {
162
+ KQ[j*D + tid] = -HALF_MAX_HALF;
163
+ }
164
+
165
+ half2 VKQ[ncols] = {{0.0f, 0.0f}};
166
+
167
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
168
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
169
+ // Calculate KQ tile and keep track of new maximum KQ values:
170
+
171
+ // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
172
+ // see https://github.com/ggerganov/llama.cpp/pull/7061 .
173
+ // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
174
+ half kqmax_new = kqmax[0];
175
+ half kqmax_new_arr[ncols];
176
+ #pragma unroll
177
+ for (int j = 0; j < ncols; ++j) {
178
+ kqmax_new_arr[j] = kqmax[j];
179
+ }
180
+
181
+ #pragma unroll
182
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
183
+ const int i_KQ = i_KQ_0 + threadIdx.y;
184
+
185
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
186
+ break;
187
+ }
188
+
189
+ #pragma unroll
190
+ for (int j = 0; j < ncols; ++j) {
191
+ half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
192
+ sum = warp_reduce_sum(sum);
193
+ sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
194
+
195
+ if (ncols == 1) {
196
+ kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
197
+ } else {
198
+ kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
199
+ }
200
+
201
+ if (threadIdx.x == 0) {
202
+ KQ[j*D + i_KQ] = sum;
203
+ }
204
+ }
205
+ }
206
+
207
+ #pragma unroll
208
+ for (int j = 0; j < ncols; ++j) {
209
+ half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
210
+
211
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
212
+ if (threadIdx.x == 0) {
213
+ kqmax_shared[j][threadIdx.y] = kqmax_new_j;
214
+ }
215
+ }
216
+
217
+ __syncthreads();
218
+
219
+ #pragma unroll
220
+ for (int j = 0; j < ncols; ++j) {
221
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
222
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
223
+
224
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
225
+ kqmax[j] = kqmax_new_j;
226
+
227
+ const half val = hexp(KQ[j*D + tid] - kqmax[j]);
228
+ kqsum[j] = kqsum[j]*KQ_max_scale + val;
229
+ KQ[j*D + tid] = val;
230
+
231
+ VKQ[j] *= __half2half2(KQ_max_scale);
232
+ }
233
+
234
+ __syncthreads();
235
+
236
+ #pragma unroll
237
+ for (int k0 = 0; k0 < D; k0 += 2) {
238
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
239
+ break;
240
+ }
241
+
242
+ half2 V_k;
243
+ reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
244
+ reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
245
+ #pragma unroll
246
+ for (int j = 0; j < ncols; ++j) {
247
+ VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
248
+ }
249
+ }
250
+
251
+ __syncthreads();
252
+ }
253
+
254
+ #pragma unroll
255
+ for (int j = 0; j < ncols; ++j) {
256
+ kqsum[j] = warp_reduce_sum(kqsum[j]);
257
+ if (threadIdx.x == 0) {
258
+ kqsum_shared[j][threadIdx.y] = kqsum[j];
259
+ }
260
+ }
261
+
262
+ __syncthreads();
263
+
264
+ #pragma unroll
265
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
266
+ if (ncols > 2 && ic0 + j_VKQ >= ne01) {
267
+ break;
268
+ }
269
+
270
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
271
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
272
+
273
+ half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
274
+ if (parallel_blocks == 1) {
275
+ dst_val /= kqsum[j_VKQ];
276
+ }
277
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
278
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
279
+ }
280
+
281
+ if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
282
+ dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
283
+ }
284
+ #else
285
+ NO_DEVICE_CODE;
286
+ #endif // FP16_AVAILABLE
287
+ }
288
+
289
+ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
290
+ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
291
+ constexpr int nwarps = D/WARP_SIZE;
292
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
293
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
294
+ }
295
+
296
+ template <int D, ggml_type type_K, ggml_type type_V>
297
+ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
298
+ ggml_tensor * KQV = dst;
299
+ ggml_tensor * Q = dst->src[0];
300
+ ggml_tensor * K = dst->src[1];
301
+ ggml_tensor * V = dst->src[2];
302
+
303
+ const int32_t precision = KQV->op_params[2];
304
+ GGML_ASSERT(precision == GGML_PREC_DEFAULT);
305
+
306
+ GGML_ASSERT(K->type == type_K);
307
+ GGML_ASSERT(V->type == type_V);
308
+
309
+ if (Q->ne[1] == 1) {
310
+ constexpr int cols_per_block = 1;
311
+ constexpr int parallel_blocks = 4;
312
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
313
+ return;
314
+ }
315
+
316
+ if (Q->ne[1] == 2) {
317
+ constexpr int cols_per_block = 2;
318
+ constexpr int parallel_blocks = 4;
319
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
320
+ return;
321
+ }
322
+
323
+ if (Q->ne[1] <= 4) {
324
+ constexpr int cols_per_block = 4;
325
+ constexpr int parallel_blocks = 4;
326
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
327
+ return;
328
+ }
329
+
330
+ if (Q->ne[1] <= 8) {
331
+ constexpr int cols_per_block = 8;
332
+ constexpr int parallel_blocks = 4;
333
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
334
+ return;
335
+ }
336
+
337
+ constexpr int cols_per_block = 8;
338
+ constexpr int parallel_blocks = 1;
339
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
340
+ }
341
+
342
+ #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
343
+ template void ggml_cuda_flash_attn_ext_vec_f16_case \
344
+ <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
345
+
346
+ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
347
+ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
348
+ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
349
+ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
350
+ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
351
+ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
352
+
353
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
354
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
355
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
356
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
357
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
358
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
359
+
360
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
361
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
362
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
363
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
364
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
365
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
366
+
367
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
368
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
369
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
370
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
371
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
372
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
373
+
374
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
375
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
376
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
377
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
378
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
379
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
380
+
381
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
382
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
383
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
384
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
385
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
386
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
387
+
388
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
389
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
390
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
391
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
392
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
393
+ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
394
+
395
+ extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
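
The per-combination .cu files listed at the top (ggml-cuda/template-instances/fattn-vec-f16-instance-hs<head size>-<type_K>-<type_V>.cu, +5 lines each) are produced by generate_cu_files.py and, given the DECL_FATTN_VEC_F16_CASE macro above, presumably contain nothing more than an explicit instantiation of one case so that the kernels compile in separate translation units. A plausible sketch of one such file, inferred from the macro rather than copied from the commit:

// ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu (sketch)
// Autogenerated by generate_cu_files.py; the generated contents are not shown
// in this truncated diff, so treat this as an inferred example.

#include "../fattn-vec-f16.cuh"

DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);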
ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -1,3 +1,376 @@
-void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
1
  #include "common.cuh"
2
+ #include "fattn-common.cuh"
3
 
4
+ template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
5
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
6
+ __launch_bounds__(D, 1)
7
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
8
+ static __global__ void flash_attn_vec_ext_f32(
9
+ const char * __restrict__ Q,
10
+ const char * __restrict__ K,
11
+ const char * __restrict__ V,
12
+ const char * __restrict__ mask,
13
+ float * __restrict__ dst,
14
+ float2 * __restrict__ dst_meta,
15
+ const float scale,
16
+ const float max_bias,
17
+ const float m0,
18
+ const float m1,
19
+ const uint32_t n_head_log2,
20
+ const int ne00,
21
+ const int ne01,
22
+ const int ne02,
23
+ const int ne03,
24
+ const int ne10,
25
+ const int ne11,
26
+ const int ne12,
27
+ const int ne13,
28
+ const int ne31,
29
+ const int nb31,
30
+ const int nb01,
31
+ const int nb02,
32
+ const int nb03,
33
+ const int nb11,
34
+ const int nb12,
35
+ const int nb13,
36
+ const int nb21,
37
+ const int nb22,
38
+ const int nb23,
39
+ const int ne0,
40
+ const int ne1,
41
+ const int ne2,
42
+ const int ne3) {
43
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
44
+
45
+ constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
46
+ constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
47
+ constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
48
+
49
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
50
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
51
+
52
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
53
+ Q += nb02* blockIdx.y + nb01*ic0;
54
+ K += nb12*(blockIdx.y / gqa_ratio);
55
+ V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
56
+ const half * maskh = (const half *) mask + ne11*ic0;
57
+
58
+ const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
59
+
60
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
61
+ constexpr int nwarps = D / WARP_SIZE;
62
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
63
+ __builtin_assume(tid < D);
64
+
65
+ __shared__ float KQ[ncols*D];
66
+ #pragma unroll
67
+ for (int j = 0; j < ncols; ++j) {
68
+ KQ[j*D + tid] = -FLT_MAX/2.0f;
69
+ }
70
+
71
+ float kqmax[ncols];
72
+ #pragma unroll
73
+ for (int j = 0; j < ncols; ++j) {
74
+ kqmax[j] = -FLT_MAX/2.0f;
75
+ }
76
+ float kqsum[ncols] = {0.0f};
77
+
78
+ __shared__ float kqmax_shared[ncols][WARP_SIZE];
79
+ __shared__ float kqsum_shared[ncols][WARP_SIZE];
80
+ #pragma unroll
81
+ for (int j = 0; j < ncols; ++j) {
82
+ if (threadIdx.y == 0) {
83
+ kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
84
+ kqsum_shared[j][threadIdx.x] = 0.0f;
85
+ }
86
+ }
87
+ __syncthreads();
88
+
89
+ // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
90
+ float2 Q_f2[ncols][D/(2*WARP_SIZE)];
91
+ int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
92
+ float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
93
+ if (Q_q8_1) {
94
+ #pragma unroll
95
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
96
+ const int j = j0 + threadIdx.y;
97
+
98
+ if (j0 + nwarps > ncols && j >= ncols) {
99
+ break;
100
+ }
101
+
102
+ // Reuse KQ as temporary storage for converting Q to q8_1:
103
+ int * tmp_q_i32 = (int *) &KQ[j*D];
104
+ float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
105
+
106
+ // Set memory to zero if out of bounds:
107
+ if (ncols > 2 && ic0 + j >= ne01) {
108
+ #pragma unroll
109
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
110
+ const int i = i0 + threadIdx.x;
111
+
112
+ tmp_q_i32[i] = 0;
113
+ }
114
+ if (threadIdx.x < D/QK8_1) {
115
+ tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
116
+ }
117
+ continue;
118
+ }
119
+
120
+ const float * Q_f = (const float *) (Q + j*nb01);
121
+ #pragma unroll
122
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
123
+ quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
124
+ }
125
+ }
126
+
127
+ __syncthreads();
128
+
129
+ #pragma unroll
130
+ for (int j = 0; j < ncols; ++j) {
131
+ int * tmp_q_i32 = (int *) &KQ[j*D];
132
+ float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
133
+
134
+ #pragma unroll
135
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
136
+ const int i = i0 + threadIdx.x;
137
+
138
+ Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
139
+ Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
140
+ }
141
+ }
142
+
143
+ __syncthreads();
144
+ } else {
145
+ #pragma unroll
146
+ for (int j = 0; j < ncols; ++j) {
147
+ const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
148
+ #pragma unroll
149
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
150
+ const int i = i0 + threadIdx.x;
151
+
152
+ Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
153
+ Q_f2[j][i0/WARP_SIZE].x *= scale;
154
+ Q_f2[j][i0/WARP_SIZE].y *= scale;
155
+ }
156
+ }
157
+ }
158
+
159
+ float VKQ[ncols] = {0.0f};
160
+
161
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
162
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
163
+ // Calculate KQ tile and keep track of new maximum KQ values:
164
+
165
+ float kqmax_new_arr[ncols];
166
+ #pragma unroll
167
+ for (int j = 0; j < ncols; ++j) {
168
+ kqmax_new_arr[j] = kqmax[j];
169
+ }
170
+
171
+ #pragma unroll
172
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
173
+ const int i_KQ = i_KQ_0 + threadIdx.y;
174
+
175
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
176
+ break;
177
+ }
178
+
179
+ #pragma unroll
180
+ for (int j = 0; j < ncols; ++j) {
181
+ float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
182
+ sum = warp_reduce_sum(sum);
183
+ sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
184
+
185
+ kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
186
+
187
+ if (threadIdx.x == 0) {
188
+ KQ[j*D + i_KQ] = sum;
189
+ }
190
+ }
191
+ }
192
+
193
+ #pragma unroll
194
+ for (int j = 0; j < ncols; ++j) {
195
+ float kqmax_new_j = kqmax_new_arr[j];
196
+
197
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
198
+ if (threadIdx.x == 0) {
199
+ kqmax_shared[j][threadIdx.y] = kqmax_new_j;
200
+ }
201
+ }
202
+
203
+ __syncthreads();
204
+
205
+ #pragma unroll
206
+ for (int j = 0; j < ncols; ++j) {
207
+ float kqmax_new_j = kqmax_shared[j][threadIdx.x];
208
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
209
+
210
+ const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
211
+ kqmax[j] = kqmax_new_j;
212
+
213
+ const float val = expf(KQ[j*D + tid] - kqmax[j]);
214
+ kqsum[j] = kqsum[j]*KQ_max_scale + val;
215
+ KQ[j*D + tid] = val;
216
+
217
+ VKQ[j] *= KQ_max_scale;
218
+ }
219
+
220
+ __syncthreads();
221
+
222
+ #pragma unroll
223
+ for (int k = 0; k < D; ++k) {
224
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
225
+ break;
226
+ }
227
+
228
+ const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
229
+ #pragma unroll
230
+ for (int j = 0; j < ncols; ++j) {
231
+ VKQ[j] += V_ki*KQ[j*D + k];
232
+ }
233
+ }
234
+
235
+ __syncthreads();
236
+ }
237
+
238
+ #pragma unroll
239
+ for (int j = 0; j < ncols; ++j) {
240
+ kqsum[j] = warp_reduce_sum(kqsum[j]);
241
+ if (threadIdx.x == 0) {
242
+ kqsum_shared[j][threadIdx.y] = kqsum[j];
243
+ }
244
+ }
245
+
246
+ __syncthreads();
247
+
248
+ #pragma unroll
249
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
250
+ if (ncols > 2 && ic0 + j_VKQ >= ne01) {
251
+ break;
252
+ }
253
+
254
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
255
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
256
+
257
+ float dst_val = VKQ[j_VKQ];
258
+ if (parallel_blocks == 1) {
259
+ dst_val /= kqsum[j_VKQ];
260
+ }
261
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
262
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
263
+ }
264
+
265
+ if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
266
+ dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
267
+ }
268
+ }
269
+
270
+ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
271
+ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
272
+ constexpr int nwarps = D/WARP_SIZE;
273
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
274
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
275
+ }
276
+
277
+ template <int D, ggml_type type_K, ggml_type type_V>
278
+ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
279
+ ggml_tensor * KQV = dst;
280
+ ggml_tensor * Q = dst->src[0];
281
+ ggml_tensor * K = dst->src[1];
282
+ ggml_tensor * V = dst->src[2];
283
+
284
+ const int32_t precision = KQV->op_params[2];
285
+ GGML_ASSERT(precision == GGML_PREC_DEFAULT);
286
+
287
+ GGML_ASSERT(K->type == type_K);
288
+ GGML_ASSERT(V->type == type_V);
289
+
290
+ if (Q->ne[1] == 1) {
291
+ constexpr int cols_per_block = 1;
292
+ constexpr int parallel_blocks = 4;
293
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
294
+ return;
295
+ }
296
+
297
+ if (Q->ne[1] == 2) {
298
+ constexpr int cols_per_block = 2;
299
+ constexpr int parallel_blocks = 4;
300
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
301
+ return;
302
+ }
303
+
304
+ if (Q->ne[1] <= 4) {
305
+ constexpr int cols_per_block = 4;
306
+ constexpr int parallel_blocks = 4;
307
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
308
+ return;
309
+ }
310
+
311
+ if (Q->ne[1] <= 8) {
312
+ constexpr int cols_per_block = 8;
313
+ constexpr int parallel_blocks = 4;
314
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
315
+ return;
316
+ }
317
+
318
+ constexpr int cols_per_block = 8;
319
+ constexpr int parallel_blocks = 1;
320
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
321
+ }
322
+
323
+ #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
324
+ template void ggml_cuda_flash_attn_ext_vec_f32_case \
325
+ <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
326
+
327
+ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
328
+ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
329
+ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
330
+ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
331
+ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
332
+ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
333
+
334
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
335
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
336
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
337
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
338
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
339
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
340
+
341
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
342
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
343
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
344
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
345
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
346
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
347
+
348
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
349
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
350
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
351
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
352
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
353
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
354
+
355
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
356
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
357
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
358
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
359
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
360
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
361
+
362
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
363
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
364
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
365
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
366
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
367
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
368
+
369
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
370
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
371
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
372
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
373
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
374
+ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
375
+
376
+ extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
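The extern declarations above pair with DECL_FATTN_VEC_F32_CASE to split compilation across translation units: the macro spells out an explicit instantiation of ggml_cuda_flash_attn_ext_vec_f32_case for one (head size, K type, V type) combination, and the extern keyword keeps every file that merely includes this header from instantiating it again. Each combination is then compiled exactly once, presumably in a small autogenerated instance file of its own. A minimal sketch of what such a file would contain (the file name and relative include path are assumptions here):

    // e.g. fattn-vec-f32-instance-hs128-q8_0-q8_0.cu (hypothetical name)
    #include "../fattn-vec-f32.cuh"

    DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);

Used without extern, the macro expands to the instantiation definition, i.e. template void ggml_cuda_flash_attn_ext_vec_f32_case<128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0>(ggml_backend_cuda_context & ctx, ggml_tensor * dst), which is the counterpart of the extern declarations listed above.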
ggml-cuda/fattn-wmma-f16.cuh ADDED
@@ -0,0 +1,490 @@
1
+ #include "common.cuh"
2
+ #include "fattn-common.cuh"
3
+
4
+ #if FP16_MMA_AVAILABLE
5
+ #include <mma.h>
6
+ #endif
7
+
8
+ // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
9
+ template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
10
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
11
+ __launch_bounds__(nwarps*WARP_SIZE, 1)
12
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
13
+ static __global__ void flash_attn_ext_f16(
14
+ const char * __restrict__ Q,
15
+ const char * __restrict__ K,
16
+ const char * __restrict__ V,
17
+ const char * __restrict__ mask,
18
+ float * __restrict__ dst,
19
+ float2 * __restrict__ dst_meta,
20
+ const float scale,
21
+ const float max_bias,
22
+ const float m0,
23
+ const float m1,
24
+ const uint32_t n_head_log2,
25
+ const int ne00,
26
+ const int ne01,
27
+ const int ne02,
28
+ const int ne03,
29
+ const int ne10,
30
+ const int ne11,
31
+ const int ne12,
32
+ const int ne13,
33
+ const int ne31,
34
+ const int nb31,
35
+ const int nb01,
36
+ const int nb02,
37
+ const int nb03,
38
+ const int nb11,
39
+ const int nb12,
40
+ const int nb13,
41
+ const int nb21,
42
+ const int nb22,
43
+ const int nb23,
44
+ const int ne0,
45
+ const int ne1,
46
+ const int ne2,
47
+ const int ne3) {
48
+ #if FP16_MMA_AVAILABLE
49
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
50
+
51
+ const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
52
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
53
+
54
+ static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
55
+ static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
56
+ constexpr int frag_m = ncols == 8 ? 32 : 16;
57
+ constexpr int frag_n = ncols == 8 ? 8 : 16;
58
+ static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
59
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
60
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
61
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
62
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
63
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
64
+
65
+ constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
66
+ constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
67
+ static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
68
+
69
+ // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
70
+ constexpr int D_padded = D + 8;
71
+ constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
72
+ constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
73
+
74
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
75
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
76
+ const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
77
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
78
+ const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
79
+ const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
80
+
81
+ const int stride_Q = nb01 / sizeof(float);
82
+ const int stride_KV = nb11 / sizeof(half);
83
+
84
+ const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
85
+ const half slopeh = __float2half(slopef);
86
+ const half2 slope2 = make_half2(slopef, slopef);
87
+
88
+ frag_b Q_b[D/16][ncols/frag_n];
89
+
90
+ // A single buffer for temporarily holding tiles of KQ and VKQ parts:
91
+ constexpr int mem_KQ = ncols*kqs_padded*kqar;
92
+ constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
93
+ __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
94
+ float * KQ_f = (float *) KQ;
95
+ half2 * KQ2 = (half2 *) KQ;
96
+
97
+ float KQ_rowsum_f[ncols/nwarps] = {0.0f};
98
+ float KQ_max_f[ncols/nwarps];
99
+ float KQ_max_scale_f[ncols/nwarps] = {0.0f};
100
+
101
+ #pragma unroll
102
+ for (int j = 0; j < ncols/nwarps; ++j) {
103
+ KQ_max_f[j] = -FLT_MAX/2.0f;
104
+ }
105
+
106
+ half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
107
+ half2 KQ_max_h2[ncols/nwarps];
108
+ half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
109
+
110
+ #pragma unroll
111
+ for (int j = 0; j < ncols/nwarps; ++j) {
112
+ KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
113
+ }
114
+
115
+ __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
116
+ half2 * VKQ2 = (half2 *) VKQ;
117
+ #pragma unroll
118
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
119
+ const int j = j0 + threadIdx.y;
120
+ #pragma unroll
121
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
122
+ const int i = i0 + threadIdx.x;
123
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
124
+ break;
125
+ }
126
+ VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
127
+ }
128
+ }
129
+
130
+ // Convert Q to half and apply scale, temporarily store in KQ:
131
+ #pragma unroll
132
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
133
+ const int j = j0 + threadIdx.y;
134
+ #pragma unroll
135
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
136
+ const int i = i0 + threadIdx.x;
137
+ if (i0 + WARP_SIZE > D && i >= D) {
138
+ break;
139
+ }
140
+ KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
141
+ }
142
+ }
143
+
144
+ __syncthreads();
145
+
146
+ // Load Q into tensor core fragments/registers since it will be used frequently:
147
+ #pragma unroll
148
+ for (int i0 = 0; i0 < D; i0 += 16) {
149
+ #pragma unroll
150
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
151
+ nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
152
+ }
153
+ }
154
+
155
+ __syncthreads();
156
+
157
+ // Iterate over ne11 == previous tokens:
158
+ for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
159
+ // Calculate tile of KQ:
160
+ #pragma unroll
161
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
162
+ frag_c_KQ KQ_c[ncols/frag_n];
163
+ #pragma unroll
164
+ for (int j = 0; j < ncols/frag_n; ++j) {
165
+ nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
166
+ }
167
+ #pragma unroll
168
+ for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
169
+ frag_a_K K_a;
170
+ nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
171
+ #pragma unroll
172
+ for (int j = 0; j < ncols/frag_n; ++j) {
173
+ nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
174
+ }
175
+ }
176
+ #pragma unroll
177
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
178
+ nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
179
+ }
180
+ }
181
+
182
+ __syncthreads();
183
+
184
+ // Calculate softmax for each KQ column using the current max. value.
185
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
186
+ #pragma unroll
187
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
188
+ const int j = j0 + threadIdx.y;
189
+
190
+ if (std::is_same<KQ_acc_t, float>::value) {
191
+ float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
192
+ #pragma unroll
193
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
194
+ const int k = k0 + threadIdx.x;
195
+
196
+ KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
197
+ }
198
+
199
+ float KQ_max_new = KQ_max_f[j0/nwarps];
200
+ #pragma unroll
201
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
202
+ const int k = k0 + threadIdx.x;
203
+
204
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
205
+ KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
206
+ }
207
+ KQ_max_new = warp_reduce_max(KQ_max_new);
208
+
209
+ const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
210
+ KQ_max_scale_f[j0/nwarps] = expf(diff);
211
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
212
+ KQ_max_scale_f[j0/nwarps] = 0.0f;
213
+ }
214
+ KQ_max_f[j0/nwarps] = KQ_max_new;
215
+
216
+ float KQ_rowsum_add = 0.0f;
217
+ #pragma unroll
218
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
219
+ const int k = k0 + threadIdx.x;
220
+
221
+ const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
222
+ KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
223
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
224
+ KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
225
+ }
226
+ KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
227
+ KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
228
+ }
229
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
230
+
231
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
232
+ KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
233
+ } else {
234
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
235
+ #pragma unroll
236
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
237
+ const int k = k0 + threadIdx.x;
238
+
239
+ KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
240
+ }
241
+
242
+ half2 KQ_max_new = KQ_max_h2[j0/nwarps];
243
+ #pragma unroll
244
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
245
+ const int k = k0 + threadIdx.x;
246
+
247
+ KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
248
+ KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
249
+ }
250
+ KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
251
+ const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
252
+ KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
253
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
254
+ *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
255
+ KQ_max_h2[j0/nwarps] = KQ_max_new;
256
+
257
+ half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
258
+ #pragma unroll
259
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
260
+ const int k = k0 + threadIdx.x;
261
+
262
+ const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
263
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
264
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
265
+ *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
266
+ KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
267
+ KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
268
+ }
269
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
270
+
271
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
272
+ KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
273
+ }
274
+ }
275
+
276
+ __syncthreads();
277
+
278
+ frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
279
+ #pragma unroll
280
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
281
+ #pragma unroll
282
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
283
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
284
+ nvcuda::wmma::load_matrix_sync(
285
+ KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
286
+ KQ + j0*(kqar*kqs_padded) + k,
287
+ kqar*kqs_padded);
288
+ }
289
+ }
290
+
291
+ frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
292
+ #pragma unroll
293
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
294
+ #pragma unroll
295
+ for (int j = 0; j < ncols/frag_n; ++j) {
296
+ nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
297
+ }
298
+
299
+ #pragma unroll
300
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
301
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
302
+
303
+ frag_a_V v_a;
304
+ nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
305
+ #pragma unroll
306
+ for (int j = 0; j < ncols/frag_n; ++j) {
307
+ nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
308
+ }
309
+ }
310
+ }
311
+
312
+ __syncthreads();
313
+
314
+ const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
315
+ #pragma unroll
316
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
317
+ #pragma unroll
318
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
319
+ nvcuda::wmma::store_matrix_sync(
320
+ KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
321
+ VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
322
+ D_padded, nvcuda::wmma::mem_col_major);
323
+ }
324
+ }
325
+
326
+ __syncthreads();
327
+
328
+ #pragma unroll
329
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
330
+ const int j = j0 + threadIdx.y;
331
+
332
+ half2 VKQ_scale;
333
+ if (std::is_same<KQ_acc_t, float>::value) {
334
+ VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
335
+ } else {
336
+ VKQ_scale = KQ_max_scale_h2[j0/nwarps];
337
+ }
338
+
339
+ #pragma unroll
340
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
341
+ const int i = i0 + threadIdx.x;
342
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
343
+ break;
344
+ }
345
+
346
+ half2 VKQ_add = make_half2(0.0f, 0.0f);
347
+ #pragma unroll
348
+ for (int l = 0; l < VKQ_ratio; ++l) {
349
+ VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
350
+ }
351
+ VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
352
+ }
353
+ }
354
+
355
+ __syncthreads();
356
+ }
357
+
358
+ #pragma unroll
359
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
360
+ const int j_VKQ = j0 + threadIdx.y;
361
+ if (ic0 + j_VKQ >= ne01) {
362
+ return;
363
+ }
364
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
365
+
366
+ float KQ_rowsum_j;
367
+ if (std::is_same<KQ_acc_t, float>::value) {
368
+ KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
369
+ } else {
370
+ KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
371
+ }
372
+
373
+ #pragma unroll
374
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
375
+ const int i = i0 + threadIdx.x;
376
+ if (i0 + WARP_SIZE > D && i >= D) {
377
+ break;
378
+ }
379
+ float dst_val = VKQ[j_VKQ*D_padded + i];
380
+ if (parallel_blocks == 1) {
381
+ dst_val /= KQ_rowsum_j;
382
+ }
383
+ dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
384
+ }
385
+
386
+ if (parallel_blocks == 1 || threadIdx.x != 0) {
387
+ continue;
388
+ }
389
+
390
+ float2 dst_meta_val;
391
+ if (std::is_same<KQ_acc_t, float>::value) {
392
+ dst_meta_val.x = KQ_max_f[j0/nwarps];
393
+ } else {
394
+ dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
395
+ }
396
+ dst_meta_val.y = KQ_rowsum_j;
397
+ dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
398
+ }
399
+ #else
400
+ NO_DEVICE_CODE;
401
+ #endif // FP16_MMA_AVAILABLE
402
+ }
403
+
404
+ constexpr int get_max_power_of_2(int x) {
405
+ return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
406
+ }
407
+
408
+ static_assert(get_max_power_of_2(1) == 1, "Test failed.");
409
+ static_assert(get_max_power_of_2(2) == 2, "Test failed.");
410
+ static_assert(get_max_power_of_2(4) == 4, "Test failed.");
411
+ static_assert(get_max_power_of_2(6) == 2, "Test failed.");
412
+
413
+ // Number of VKQ rows calculated in parallel:
414
+ constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
415
+ return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
416
+ }
417
+
418
+ static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
419
+ static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
420
+ static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
421
+ static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
422
+ static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
423
+ static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
424
+ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
425
+ static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
426
+ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
427
+
428
+ template <int D, int cols_per_block, typename KQ_acc_t>
429
+ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
430
+ const ggml_tensor * Q = dst->src[0];
431
+
432
+ constexpr int nwarps = 4;
433
+
434
+ constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
435
+ const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
436
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
437
+
438
+ if (4*blocks_num_pb1 < 2*nsm) {
439
+ constexpr int parallel_blocks = 4;
440
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
441
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
442
+ return;
443
+ }
444
+ if (2*blocks_num_pb1 < 2*nsm) {
445
+ constexpr int parallel_blocks = 2;
446
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
447
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
448
+ return;
449
+ }
450
+ constexpr int parallel_blocks = 1;
451
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
452
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
453
+ }
454
+
455
+ #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
456
+ template void ggml_cuda_flash_attn_ext_wmma_f16_case \
457
+ <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
458
+
459
+ extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
460
+ extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
461
+ extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
462
+ extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
463
+ extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
464
+ extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
465
+
466
+ extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
467
+ extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
468
+ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
469
+ extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
470
+ extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
471
+ // extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
472
+
473
+ extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half);
474
+ extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half);
475
+ extern DECL_FATTN_WMMA_F16_CASE(128, 8, half);
476
+ extern DECL_FATTN_WMMA_F16_CASE(256, 8, half);
477
+
478
+ extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
479
+ extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
480
+ extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
481
+ extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
482
+ extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
483
+ extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
484
+
485
+ extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
486
+ extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
487
+ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
488
+ extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
489
+ extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
490
+ extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
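For orientation, KQ_max, KQ_max_scale and KQ_rowsum in flash_attn_ext_f16 above implement the usual online (streaming) softmax: whenever a new KQ tile raises the running row maximum, the previously accumulated rowsum and VKQ values are rescaled by exp(old_max - new_max) before the new tile's contributions are added, and the division by the rowsum happens only once at the very end (or in the cross-block combine when parallel_blocks > 1). A scalar host-side sketch of that update, ignoring the half2 packing and the SOFTMAX_FTZ_THRESHOLD flush-to-zero handling:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One softmax row, processed tile by tile (scalar stand-in for KQ_max/KQ_rowsum/VKQ).
    struct online_softmax_row {
        float kq_max = -INFINITY;  // running maximum of the KQ scores seen so far
        float rowsum = 0.0f;       // running softmax denominator
        float vkq    = 0.0f;       // running numerator for a single output element

        void add_tile(const std::vector<float> & kq, const std::vector<float> & v) {
            float new_max = kq_max;
            for (float s : kq) {
                new_max = std::max(new_max, s);
            }
            const float scale = std::exp(kq_max - new_max);  // KQ_max_scale
            rowsum *= scale;
            vkq    *= scale;
            for (std::size_t i = 0; i < kq.size(); ++i) {
                const float p = std::exp(kq[i] - new_max);
                rowsum += p;
                vkq    += p*v[i];
            }
            kq_max = new_max;
        }

        float finalize() const { return vkq/rowsum; }  // the final dst_val /= KQ_rowsum step
    };

Processing the scores in one pass or in FATTN_KQ_STRIDE-sized tiles gives the same result, which is what lets the kernel stream over the KV cache one tile at a time.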
ggml-cuda/fattn.cu CHANGED
@@ -4,519 +4,38 @@
4
  #include "fattn-tile-f32.cuh"
5
  #include "fattn-vec-f16.cuh"
6
  #include "fattn-vec-f32.cuh"
 
7
  #include "fattn.cuh"
8
 
9
  #include <cstdint>
10
 
11
- #if FP16_MMA_AVAILABLE
12
- #include <mma.h>
13
- #endif
14
-
15
- // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
16
- template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
17
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
18
- __launch_bounds__(nwarps*WARP_SIZE, 1)
19
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
20
- static __global__ void flash_attn_ext_f16(
21
- const char * __restrict__ Q,
22
- const char * __restrict__ K,
23
- const char * __restrict__ V,
24
- const char * __restrict__ mask,
25
- float * __restrict__ dst,
26
- float2 * __restrict__ dst_meta,
27
- const float scale,
28
- const float max_bias,
29
- const float m0,
30
- const float m1,
31
- const uint32_t n_head_log2,
32
- const int ne00,
33
- const int ne01,
34
- const int ne02,
35
- const int ne03,
36
- const int ne10,
37
- const int ne11,
38
- const int ne12,
39
- const int ne13,
40
- const int ne31,
41
- const int nb31,
42
- const int nb01,
43
- const int nb02,
44
- const int nb03,
45
- const int nb11,
46
- const int nb12,
47
- const int nb13,
48
- const int ne0,
49
- const int ne1,
50
- const int ne2,
51
- const int ne3) {
52
- #if FP16_MMA_AVAILABLE
53
- //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
54
-
55
- const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
56
- const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
57
-
58
- static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
59
- static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
60
- constexpr int frag_m = ncols == 8 ? 32 : 16;
61
- constexpr int frag_n = ncols == 8 ? 8 : 16;
62
- static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
63
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
64
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
65
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
66
- typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
67
- typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
68
-
69
- constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
70
- constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
71
- static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
72
-
73
- // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
74
- constexpr int D_padded = D + 8;
75
- constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
76
- constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
77
-
78
- const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
79
- const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
80
- const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
81
- const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
82
- const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
83
- const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
84
-
85
- const int stride_Q = nb01 / sizeof(float);
86
- const int stride_KV = nb11 / sizeof(half);
87
-
88
- const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
89
- const half slopeh = __float2half(slopef);
90
- const half2 slope2 = make_half2(slopef, slopef);
91
-
92
- frag_b Q_b[D/16][ncols/frag_n];
93
-
94
- // A single buffer for temporarily holding tiles of KQ and VKQ parts:
95
- constexpr int mem_KQ = ncols*kqs_padded*kqar;
96
- constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
97
- __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
98
- float * KQ_f = (float *) KQ;
99
- half2 * KQ2 = (half2 *) KQ;
100
-
101
- float KQ_rowsum_f[ncols/nwarps] = {0.0f};
102
- float KQ_max_f[ncols/nwarps];
103
- float KQ_max_scale_f[ncols/nwarps] = {0.0f};
104
-
105
- #pragma unroll
106
- for (int j = 0; j < ncols/nwarps; ++j) {
107
- KQ_max_f[j] = -FLT_MAX/2.0f;
108
- }
109
-
110
- half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
111
- half2 KQ_max_h2[ncols/nwarps];
112
- half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
113
-
114
- #pragma unroll
115
- for (int j = 0; j < ncols/nwarps; ++j) {
116
- KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
117
- }
118
-
119
- __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
120
- half2 * VKQ2 = (half2 *) VKQ;
121
- #pragma unroll
122
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
123
- const int j = j0 + threadIdx.y;
124
- #pragma unroll
125
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
126
- const int i = i0 + threadIdx.x;
127
- if (i0 + WARP_SIZE > D/2 && i >= D/2) {
128
- break;
129
- }
130
- VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
131
- }
132
- }
133
-
134
- // Convert Q to half and apply scale, temporarily store in KQ:
135
- #pragma unroll
136
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
137
- const int j = j0 + threadIdx.y;
138
- #pragma unroll
139
- for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
140
- const int i = i0 + threadIdx.x;
141
- if (i0 + WARP_SIZE > D && i >= D) {
142
- break;
143
- }
144
- KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
145
- }
146
- }
147
-
148
- __syncthreads();
149
-
150
- // Load Q into tensor core fragments/registers since it will be used frequently:
151
- #pragma unroll
152
- for (int i0 = 0; i0 < D; i0 += 16) {
153
- #pragma unroll
154
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
155
- nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
156
- }
157
- }
158
-
159
- __syncthreads();
160
-
161
- // Iterate over ne11 == previous tokens:
162
- for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
163
- // Calculate tile of KQ:
164
- #pragma unroll
165
- for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
166
- frag_c_KQ KQ_c[ncols/frag_n];
167
- #pragma unroll
168
- for (int j = 0; j < ncols/frag_n; ++j) {
169
- nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
170
- }
171
- #pragma unroll
172
- for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
173
- frag_a_K K_a;
174
- nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
175
- #pragma unroll
176
- for (int j = 0; j < ncols/frag_n; ++j) {
177
- nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
178
- }
179
- }
180
- #pragma unroll
181
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
182
- nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
183
- }
184
- }
185
-
186
- __syncthreads();
187
-
188
- // Calculate softmax for each KQ column using the current max. value.
189
- // The divisor is stored in KQ_rowsum and will be applied at the end.
190
- #pragma unroll
191
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
192
- const int j = j0 + threadIdx.y;
193
-
194
- if (std::is_same<KQ_acc_t, float>::value) {
195
- float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
196
- #pragma unroll
197
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
198
- const int k = k0 + threadIdx.x;
199
-
200
- KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
201
- }
202
-
203
- float KQ_max_new = KQ_max_f[j0/nwarps];
204
- #pragma unroll
205
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
206
- const int k = k0 + threadIdx.x;
207
-
208
- KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
209
- KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
210
- }
211
- KQ_max_new = warp_reduce_max(KQ_max_new);
212
-
213
- const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
214
- KQ_max_scale_f[j0/nwarps] = expf(diff);
215
- if (diff <= SOFTMAX_FTZ_THRESHOLD) {
216
- KQ_max_scale_f[j0/nwarps] = 0.0f;
217
- }
218
- KQ_max_f[j0/nwarps] = KQ_max_new;
219
-
220
- float KQ_rowsum_add = 0.0f;
221
- #pragma unroll
222
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
223
- const int k = k0 + threadIdx.x;
224
-
225
- const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
226
- KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
227
- if (diff <= SOFTMAX_FTZ_THRESHOLD) {
228
- KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
229
- }
230
- KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
231
- KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
232
- }
233
- KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
234
-
235
- // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
236
- KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
237
- } else {
238
- half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
239
- #pragma unroll
240
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
241
- const int k = k0 + threadIdx.x;
242
-
243
- KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
244
- }
245
-
246
- half2 KQ_max_new = KQ_max_h2[j0/nwarps];
247
- #pragma unroll
248
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
249
- const int k = k0 + threadIdx.x;
250
-
251
- KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
252
- KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
253
- }
254
- KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
255
- const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
256
- KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
257
- const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
258
- *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
259
- KQ_max_h2[j0/nwarps] = KQ_max_new;
260
-
261
- half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
262
- #pragma unroll
263
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
264
- const int k = k0 + threadIdx.x;
265
-
266
- const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
267
- KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
268
- const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
269
- *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
270
- KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
271
- KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
272
- }
273
- KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
274
-
275
- // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
276
- KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
277
- }
278
- }
279
-
280
- __syncthreads();
281
-
282
- frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
283
- #pragma unroll
284
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
285
- #pragma unroll
286
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
287
- const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
288
- nvcuda::wmma::load_matrix_sync(
289
- KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
290
- KQ + j0*(kqar*kqs_padded) + k,
291
- kqar*kqs_padded);
292
- }
293
- }
294
-
295
- frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
296
- #pragma unroll
297
- for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
298
- #pragma unroll
299
- for (int j = 0; j < ncols/frag_n; ++j) {
300
- nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
301
- }
302
-
303
- #pragma unroll
304
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
305
- const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
306
-
307
- frag_a_V v_a;
308
- nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
309
- #pragma unroll
310
- for (int j = 0; j < ncols/frag_n; ++j) {
311
- nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
312
- }
313
- }
314
- }
315
-
316
- __syncthreads();
317
-
318
- const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
319
- #pragma unroll
320
- for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
321
- #pragma unroll
322
- for (int j0 = 0; j0 < ncols; j0 += frag_n) {
323
- nvcuda::wmma::store_matrix_sync(
324
- KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
325
- VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
326
- D_padded, nvcuda::wmma::mem_col_major);
327
- }
328
- }
329
-
330
- __syncthreads();
331
-
332
- #pragma unroll
333
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
334
- const int j = j0 + threadIdx.y;
335
-
336
- half2 VKQ_scale;
337
- if (std::is_same<KQ_acc_t, float>::value) {
338
- VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
339
- } else {
340
- VKQ_scale = KQ_max_scale_h2[j0/nwarps];
341
- }
342
-
343
- #pragma unroll
344
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
345
- const int i = i0 + threadIdx.x;
346
- if (i0 + WARP_SIZE > D/2 && i >= D/2) {
347
- break;
348
- }
349
-
350
- half2 VKQ_add = make_half2(0.0f, 0.0f);
351
- #pragma unroll
352
- for (int l = 0; l < VKQ_ratio; ++l) {
353
- VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
354
- }
355
- VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
356
- }
357
- }
358
-
359
- __syncthreads();
360
- }
361
-
362
- #pragma unroll
363
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
364
- const int j_VKQ = j0 + threadIdx.y;
365
- if (ic0 + j_VKQ >= ne01) {
366
- return;
367
- }
368
- const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
369
-
370
- float KQ_rowsum_j;
371
- if (std::is_same<KQ_acc_t, float>::value) {
372
- KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
373
- } else {
374
- KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
375
- }
376
-
377
- #pragma unroll
378
- for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
379
- const int i = i0 + threadIdx.x;
380
- if (i0 + WARP_SIZE > D && i >= D) {
381
- break;
382
- }
383
- float dst_val = VKQ[j_VKQ*D_padded + i];
384
- if (parallel_blocks == 1) {
385
- dst_val /= KQ_rowsum_j;
386
- }
387
- dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
388
- }
389
-
390
- if (parallel_blocks == 1 || threadIdx.x != 0) {
391
- continue;
392
- }
393
-
394
- float2 dst_meta_val;
395
- if (std::is_same<KQ_acc_t, float>::value) {
396
- dst_meta_val.x = KQ_max_f[j0/nwarps];
397
- } else {
398
- dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
399
- }
400
- dst_meta_val.y = KQ_rowsum_j;
401
- dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
402
- }
403
- #else
404
- NO_DEVICE_CODE;
405
- #endif // FP16_MMA_AVAILABLE
406
- }
407
-
408
- constexpr int get_max_power_of_2(int x) {
409
- return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
410
- }
411
-
412
- static_assert(get_max_power_of_2(1) == 1, "Test failed.");
413
- static_assert(get_max_power_of_2(2) == 2, "Test failed.");
414
- static_assert(get_max_power_of_2(4) == 4, "Test failed.");
415
- static_assert(get_max_power_of_2(6) == 2, "Test failed.");
416
-
417
- // Number of VKQ rows calculated in parallel:
418
- constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
419
- return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
420
- }
421
-
422
- static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
423
- static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
424
- static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
425
- static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
426
- static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
427
- static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
428
- static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
429
- static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
430
- static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
431
-
432
- template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
433
- void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
434
- const ggml_tensor * Q = dst->src[0];
435
-
436
- constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
437
- const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
438
- const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
439
-
440
- if (4*blocks_num_pb1 < 2*nsm) {
441
- constexpr int parallel_blocks = 4;
442
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
443
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
444
- return;
445
- }
446
- if (2*blocks_num_pb1 < 2*nsm) {
447
- constexpr int parallel_blocks = 2;
448
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
449
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
450
- return;
451
- }
452
- constexpr int parallel_blocks = 1;
453
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
454
- launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
455
- }
456
-
457
- void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
458
  const ggml_tensor * KQV = dst;
459
  const ggml_tensor * Q = dst->src[0];
460
 
461
- ggml_cuda_set_device(ctx.device);
462
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
463
  const int32_t precision = KQV->op_params[2];
464
 
465
- // On AMD the tile kernels perform poorly, use the vec kernel instead:
466
- if (cc >= CC_OFFSET_AMD) {
467
- if (precision == GGML_PREC_DEFAULT) {
468
- ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
469
- } else {
470
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
471
- }
472
- return;
473
- }
474
-
475
- if (!fast_fp16_available(cc)) {
476
- if (Q->ne[1] <= 8) {
477
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
478
- } else {
479
- ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
480
- }
481
- return;
482
- }
483
-
484
- if (!fp16_mma_available(cc)) {
485
- if (Q->ne[1] <= 8) {
486
- ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
487
- } else {
488
- ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
489
- }
490
- return;
491
- }
492
-
493
  if (precision != GGML_PREC_DEFAULT) {
494
- if (Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) {
495
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
496
- return;
497
- }
498
-
499
  if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
500
  constexpr int cols_per_block = 16;
501
- constexpr int nwarps = 4;
502
  switch (Q->ne[0]) {
503
  case 64:
504
- launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
505
  break;
506
  case 80:
507
- launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
508
  break;
509
  case 96:
510
- launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
511
  break;
512
  case 112:
513
- launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
514
  break;
515
  case 128:
516
- launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
517
  break;
518
  case 256:
519
- launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
520
  break;
521
  default:
522
  GGML_ASSERT(false);
@@ -524,25 +43,24 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
524
  }
525
  } else {
526
  constexpr int cols_per_block = 32;
527
- constexpr int nwarps = 4;
528
  switch (Q->ne[0]) {
529
  case 64:
530
- launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
531
  break;
532
  case 80:
533
- launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
534
  break;
535
  case 96:
536
- launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
537
  break;
538
  case 112:
539
- launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
540
  break;
541
  case 128:
542
- launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
543
  break;
544
  // case 256:
545
- // launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
546
  // break;
547
  default:
548
  GGML_ASSERT(false);
@@ -552,26 +70,20 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
552
  return;
553
  }
554
 
555
- if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
556
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
557
- return;
558
- }
559
-
560
  if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
561
  constexpr int cols_per_block = 8;
562
- constexpr int nwarps = 4;
563
  switch (Q->ne[0]) {
564
  case 64:
565
- launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
566
  break;
567
  case 96:
568
- launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
569
  break;
570
  case 128:
571
- launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
572
  break;
573
  case 256:
574
- launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
575
  break;
576
  default:
577
  GGML_ASSERT(false);
@@ -582,25 +94,24 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
582
 
583
  if (Q->ne[1] <= 32) {
584
  constexpr int cols_per_block = 16;
585
- constexpr int nwarps = 4;
586
  switch (Q->ne[0]) {
587
  case 64:
588
- launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
589
  break;
590
  case 80:
591
- launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
592
  break;
593
  case 96:
594
- launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
595
  break;
596
  case 112:
597
- launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
598
  break;
599
  case 128:
600
- launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
601
  break;
602
  case 256:
603
- launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
604
  break;
605
  default:
606
  GGML_ASSERT(false);
@@ -610,29 +121,229 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
610
  }
611
 
612
  constexpr int cols_per_block = 32;
613
- constexpr int nwarps = 4;
614
  switch (Q->ne[0]) {
615
  case 64:
616
- launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
617
  break;
618
  case 80:
619
- launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
620
  break;
621
  case 96:
622
- launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
623
  break;
624
  case 112:
625
- launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
626
  break;
627
  case 128:
628
- launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
629
  break;
630
  case 256:
631
- launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
632
  break;
633
  default:
634
  GGML_ASSERT(false);
635
  break;
636
  }
637
- return;
 
638
  }
 
4
  #include "fattn-tile-f32.cuh"
5
  #include "fattn-vec-f16.cuh"
6
  #include "fattn-vec-f32.cuh"
7
+ #include "fattn-wmma-f16.cuh"
8
  #include "fattn.cuh"
9
 
10
  #include <cstdint>
11
 
12
+ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
  const ggml_tensor * KQV = dst;
14
  const ggml_tensor * Q = dst->src[0];
15
 
16
  const int32_t precision = KQV->op_params[2];
17
 
18
  if (precision != GGML_PREC_DEFAULT) {
19
  if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
20
  constexpr int cols_per_block = 16;
 
21
  switch (Q->ne[0]) {
22
  case 64:
23
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
24
  break;
25
  case 80:
26
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
27
  break;
28
  case 96:
29
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
30
  break;
31
  case 112:
32
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
33
  break;
34
  case 128:
35
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
36
  break;
37
  case 256:
38
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
39
  break;
40
  default:
41
  GGML_ASSERT(false);
 
43
  }
44
  } else {
45
  constexpr int cols_per_block = 32;
 
46
  switch (Q->ne[0]) {
47
  case 64:
48
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
49
  break;
50
  case 80:
51
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
52
  break;
53
  case 96:
54
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
55
  break;
56
  case 112:
57
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
58
  break;
59
  case 128:
60
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
61
  break;
62
  // case 256:
63
+ // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
64
  // break;
65
  default:
66
  GGML_ASSERT(false);
 
70
  return;
71
  }
72
 
 
73
  if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
74
  constexpr int cols_per_block = 8;
 
75
  switch (Q->ne[0]) {
76
  case 64:
77
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
78
  break;
79
  case 96:
80
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
81
  break;
82
  case 128:
83
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
84
  break;
85
  case 256:
86
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
87
  break;
88
  default:
89
  GGML_ASSERT(false);
 
94
 
95
  if (Q->ne[1] <= 32) {
96
  constexpr int cols_per_block = 16;
 
97
  switch (Q->ne[0]) {
98
  case 64:
99
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
100
  break;
101
  case 80:
102
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
103
  break;
104
  case 96:
105
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
106
  break;
107
  case 112:
108
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
109
  break;
110
  case 128:
111
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
112
  break;
113
  case 256:
114
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
115
  break;
116
  default:
117
  GGML_ASSERT(false);
 
121
  }
122
 
123
  constexpr int cols_per_block = 32;
 
124
  switch (Q->ne[0]) {
125
  case 64:
126
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
127
  break;
128
  case 80:
129
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
130
  break;
131
  case 96:
132
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
133
  break;
134
  case 112:
135
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
136
  break;
137
  case 128:
138
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
139
  break;
140
  case 256:
141
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
142
  break;
143
  default:
144
  GGML_ASSERT(false);
145
  break;
146
  }
147
+ }
148
+ #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
149
+ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
150
+ ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
151
+ return; \
152
+ } \
153
+
154
+ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
155
+ ggml_tensor * Q = dst->src[0];
156
+ ggml_tensor * K = dst->src[1];
157
+ ggml_tensor * V = dst->src[2];
158
+
159
+ #ifdef GGML_CUDA_FA_ALL_QUANTS
160
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
161
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
162
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
163
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
164
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
165
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
166
+
167
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
168
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
169
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
170
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
171
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
172
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
173
+
174
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
175
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
176
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
177
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
178
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
179
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
180
+
181
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
182
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
183
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
184
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
185
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
186
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
187
+
188
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
189
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
190
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
191
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
192
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
193
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
194
+
195
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
196
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
197
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
198
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
199
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
200
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
201
+
202
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
203
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
204
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
205
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
206
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
207
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
208
+
209
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
210
+ #else
211
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
212
+
213
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
214
+
215
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
216
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
217
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
218
+ #endif // GGML_CUDA_FA_ALL_QUANTS
219
+
220
+ on_no_fattn_vec_case(Q->ne[0]);
221
+ }
222
+
223
+ #define FATTN_VEC_F32_CASE(D, type_K, type_V) \
224
+ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
225
+ ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
226
+ return; \
227
+ } \
228
+
229
+ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
230
+ ggml_tensor * Q = dst->src[0];
231
+ ggml_tensor * K = dst->src[1];
232
+ ggml_tensor * V = dst->src[2];
233
+
+ #ifdef GGML_CUDA_FA_ALL_QUANTS
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+ #else
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+ #endif // GGML_CUDA_FA_ALL_QUANTS
+
+ on_no_fattn_vec_case(Q->ne[0]);
+ }
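If no case matches, on_no_fattn_vec_case(Q->ne[0]) is reached. A minimal sketch of what such a fallthrough handler can look like, assuming it only reports the unsupported combination and aborts (the actual helper is declared in fattn-common.cuh, its exact message may differ, and it needs <cstdio> plus GGML_ASSERT in scope as elsewhere in ggml-cuda):

    // Hypothetical sketch, not the ggml-cuda implementation:
    static void on_no_fattn_vec_case_sketch(const int D) {
        fprintf(stderr, "Unsupported KV type combination for head size %d.\n", D);
        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS to enable all quantized K/V combinations.\n");
        GGML_ASSERT(false);
    }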
+
+ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+
+ ggml_cuda_set_device(ctx.device);
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const int32_t precision = KQV->op_params[2];
+
+ const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
+
+ // On AMD the tile kernels perform poorly, use the vec kernel instead:
+ if (cc >= CC_OFFSET_AMD || quantized_KV) {
+ if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ }
+ return;
+ }
+
+ if (!fast_fp16_available(cc)) {
+ if (Q->ne[1] <= 8) {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+ }
+ return;
+ }
+
+ if (!fp16_mma_available(cc)) {
+ if (Q->ne[1] <= 8) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+ }
+ return;
+ }
+
+ if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+ if (precision == GGML_PREC_DEFAULT) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ return;
+ } else if (Q->ne[0] <= 128) {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ return;
+ }
+ }
+
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
  }
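The key change of this commit in the dispatcher is the quantized_KV condition: any quantized K or V type is routed to the vec kernels, which are the only ones that dequantize on the fly. The net effect of the selection order can be summarized by the following standalone helper (a readability sketch only; the helper name and enum are hypothetical, and the authoritative logic is ggml_cuda_flash_attn_ext above):

    // Hypothetical summary of the kernel selection above (not part of the patch):
    enum class fattn_kernel_t { VEC_F16, VEC_F32, TILE_F16, TILE_F32, WMMA_F16 };

    static fattn_kernel_t choose_fattn_kernel(const int cc, const bool quantized_KV, const int32_t precision,
                                              const int64_t head_size, const int64_t n_q) {
        if (cc >= CC_OFFSET_AMD || quantized_KV) { // AMD or quantized K/V: always the vec path
            return (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) ?
                fattn_kernel_t::VEC_F16 : fattn_kernel_t::VEC_F32;
        }
        if (!fast_fp16_available(cc)) {            // no fast FP16: FP32 kernels
            return n_q <= 8 ? fattn_kernel_t::VEC_F32 : fattn_kernel_t::TILE_F32;
        }
        if (!fp16_mma_available(cc)) {             // no FP16 tensor cores: FP16 vec/tile kernels
            return n_q <= 8 ? fattn_kernel_t::VEC_F16 : fattn_kernel_t::TILE_F16;
        }
        if (n_q == 1 && head_size % (2*WARP_SIZE) == 0) { // batch size 1, head size multiple of 64
            if (precision == GGML_PREC_DEFAULT) return fattn_kernel_t::VEC_F16;
            if (head_size <= 128)               return fattn_kernel_t::VEC_F32;
        }
        return fattn_kernel_t::WMMA_F16;           // default: tensor-core WMMA kernel
    }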
ggml-cuda/mmq.cu CHANGED
@@ -386,7 +386,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
  u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
  }

- return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
+ return vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
  (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
  }

@@ -547,7 +547,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
  const float * x_dmf = (const float *) x_dm;
  const float * y_df = (const float *) y_ds;

- return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
+ return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
  (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
  y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
  }
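The two mmq.cu hunks are a side effect of the quantized-KV work: vec_dot_q8_0_q8_1_impl now takes the scale type as an explicit first template parameter (float for the MMQ call sites), presumably so the same dot-product body can also be instantiated with half2 scales by the new flash-attention code. A minimal sketch of the idea, using illustrative names rather than the actual ggml-cuda ones and showing only the float-scale version:

    // Hypothetical float-scale sketch (names are illustrative, not the ggml-cuda ones);
    // vdr is the number of packed int8x4 words handled per call:
    template <int vdr>
    static __device__ __forceinline__ float vec_dot_q8_0_q8_1_sketch(
            const int * v, const int * u, const float d8_0, const float d8_1) {
        int sumi = 0;
    #pragma unroll
        for (int i = 0; i < vdr; ++i) {
            sumi = __dp4a(v[i], u[i], sumi); // 4-way int8 multiply-accumulate per 32-bit word
        }
        return d8_0 * d8_1 * (float) sumi;   // scale the integer sum by both block scales
    }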
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
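Each of the autogenerated files in template-instances/ compiles exactly one (head size, K type, V type) combination; this is the "split fattn compile via extern templates" item from the commit message, letting the build process compile the many kernel variants as independent translation units instead of one monolithic fattn.cu. A plausible shape for the macro pair, assuming DECL_FATTN_VEC_F16_CASE wraps an explicit instantiation of the launcher template (the actual definition lives in fattn-vec-f16.cuh):

    // Assumed shape, for illustration only:
    #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
        template void ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>( \
            ggml_backend_cuda_context & ctx, ggml_tensor * dst)

    // In the header: declare the instantiation so other TUs link against it without recompiling it.
    extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);

    // In the autogenerated .cu file: provide the one definition.
    DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);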
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu ADDED
@@ -0,0 +1,5 @@
+ // This file has been autogenerated by generate-variants.py, do not edit manually.
+
+ #include "../fattn-vec-f16.cuh"
+
+ DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);