uvos committed
Commit 2adc060 · 1 Parent(s): 1e50161

CUDA/HIP: Fix fattn-vec-* when device warp size is not 32 (llama/12315)


When fattn-wmma was ported over to warp64, various bits that also touch fattn-vec were converted to a
selectable warp size. However, the fattn-vec kernels don't work with 64-wide warps for now, so we need
to avoid launching them with warp64 parameters.
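
For context, here is a minimal standalone sketch of the pattern this fix relies on (launch_helper,
dummy_kernel and the program around them are illustrative placeholders, not ggml code): the warp size
becomes an optional trailing parameter that defaults to a compile-time WARP_SIZE of 32, the host queries
the device's physical warp size once, and only launch sites whose kernels can cope with wide warps pass
that value through.

    // Sketch only: a defaulted warp_size launch parameter, not the ggml implementation.
    #include <cstdio>
    #include <cuda_runtime.h>

    constexpr int WARP_SIZE = 32; // width the vector-style kernels are written for

    __global__ void dummy_kernel() {}

    // Mirrors the shape of the patched launch helper: warp_size is a trailing
    // argument with a default, so existing call sites compile unchanged and
    // keep launching 32-wide warps.
    static void launch_helper(const int nwarps, const int warp_size = WARP_SIZE) {
        const dim3 block_dim(warp_size, nwarps, 1);
        dummy_kernel<<<1, block_dim>>>();
    }

    int main() {
        int device = 0;
        cudaGetDevice(&device);

        int device_warp_size = WARP_SIZE;
        cudaDeviceGetAttribute(&device_warp_size, cudaDevAttrWarpSize, device); // 32 on NVIDIA, 32 or 64 on AMD

        launch_helper(/*nwarps=*/4);                    // vec-style call site: stays at the 32-wide default
        launch_helper(/*nwarps=*/4, device_warp_size);  // wmma-style call site: uses the device warp size
        cudaDeviceSynchronize();

        printf("device warp size: %d\n", device_warp_size);
        return 0;
    }

In the diff below, launch_fattn gains a defaulted warp_size parameter in exactly this way, and only the
wmma launcher in fattn-wmma-f16.cu passes the device value; the fattn-vec call sites keep the default.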

ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
        nullptr;
 }
 
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+    const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -704,8 +699,6 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -469,6 +469,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -485,7 +486,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,7 +501,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -514,7 +515,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {