jeffbolznv committed
Commit f6b0b76 · 1 Parent(s): ebacb3e

vulkan: support softmax/FA batch and broadcast (llama/14449)

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -633,6 +633,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -643,7 +644,6 @@ struct vk_flash_attn_push_constants {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
@@ -658,6 +658,7 @@ struct vk_flash_attn_push_constants {
     uint32_t split_kv;
     uint32_t k_num;
 };
+static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
 
 struct vk_op_push_constants {
     uint32_t KX;
@@ -756,6 +757,14 @@ struct vk_op_rope_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t ne12;
+    uint32_t ne13;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
     float scale;
     float max_bias;
     float m0;
@@ -6040,7 +6049,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
-    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+    const uint32_t nem2 = mask ? mask->ne[2] : 0;
 
     const uint32_t D = neq0;
     uint32_t N = neq1;
@@ -6203,7 +6212,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
     if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
         if (split_k > 1) {
             // Try to evenly split KV into split_k chunks, but it needs to be a multiple
             // of "align", so recompute split_k based on that.
@@ -6213,9 +6222,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
-    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
-    // and the per-row m and L values (ne1 rows).
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
     if (split_k_size > ctx->device->max_memory_allocation_size) {
         GGML_ABORT("Requested preallocation size is too large");
     }
@@ -6307,11 +6316,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         (uint32_t)neq2, (uint32_t)neq3,
         (uint32_t)nek2, (uint32_t)nek3,
         (uint32_t)nev2, (uint32_t)nev3,
-        nem1,
+        nem1, nem2,
         q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
         k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
-        nbm1,
         scale, max_bias, logit_softcap,
         mask != nullptr, n_head_log2, m0, m1,
         gqa_ratio, split_kv, split_k };
@@ -6334,13 +6342,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
             pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
             {
                 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
             },
-            pc2, { (uint32_t)ne1, 1, 1 });
+            pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
     } else {
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
             {
@@ -7666,7 +7674,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
     const uint32_t nrows_y = (uint32_t)src0->ne[1];
 
-    const uint32_t n_head_kv = nrows_x/nrows_y;
+    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
+    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
+    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
+    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
+    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
+
+    const uint32_t n_head_kv = src0->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -7675,6 +7689,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
+        ne12, ne13,
+        nb11, nb12, nb13,
         scale, max_bias,
         m0, m1,
         n_head_log2,
@@ -10248,11 +10265,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
@@ -10413,13 +10425,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
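
Note on the split_k scratch buffer: the hunk around line 6225 reserves, per split and per batch, one O matrix of D x ne1 floats plus ne1 L values and ne1 m values, with all matrices stored first and the per-row L/m values after them. Below is a minimal host-side sketch of that arithmetic, assuming float storage; the helper names are illustrative and not part of the commit.

    #include <cstdint>
    #include <cstdio>

    // Scratch size mirroring the reservation in ggml_vk_flash_attn:
    // (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3.
    static uint64_t split_k_scratch_bytes(uint64_t D, uint64_t ne1, uint64_t ne3, uint64_t split_k) {
        return split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
    }

    // Start of the O matrix for a given (split, batch), in floats, matching
    // o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num) in the shaders.
    static uint64_t o_offset_floats(uint64_t D, uint64_t ne1, uint64_t k_num, uint64_t split, uint64_t iq3) {
        return D * ne1 * (split + iq3 * k_num);
    }

    int main() {
        const uint64_t D = 128, ne1 = 32, ne3 = 4, split_k = 8;
        std::printf("scratch bytes: %llu\n",
                    (unsigned long long) split_k_scratch_bytes(D, ne1, ne3, split_k));
        std::printf("O offset (split 3, batch 2): %llu floats\n",
                    (unsigned long long) o_offset_floats(D, ne1, split_k, 3, 2));
        return 0;
    }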
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp CHANGED
@@ -99,6 +99,10 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
@@ -150,7 +154,7 @@ void main() {
             uint32_t c = (idx + tid) % Bc;
             uint32_t r = (idx + tid) / Bc;
             if (idx + tid < Bc * Br) {
-                masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+                masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
             }
         }
         barrier();
@@ -277,7 +281,7 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
 
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
@@ -289,7 +293,7 @@ void main() {
             }
         }
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             if (r < N) {
                 perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -311,7 +315,7 @@ void main() {
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
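
Note on the mask broadcast: m_offset maps batch index iq3 onto the mask's third dimension modulo nem2, so a mask with nem2 == 1 is shared by all batches while a larger mask is indexed per batch. Below is a small scalar sketch of the same indexing, assuming one mask slice is nem1 * KV contiguous elements; the helper names are illustrative, not from the commit.

    #include <cstdint>
    #include <cstdio>

    // Element offset of the mask slice used by batch iq3, matching the shader's
    // m_offset = (iq3 % p.nem2) * p.nem1 * KV (0 when the mask is not broadcast).
    static uint32_t mask_slice_offset(uint32_t iq3, uint32_t nem1, uint32_t nem2, uint32_t KV) {
        return (nem2 != 1) ? (iq3 % nem2) * nem1 * KV : 0;
    }

    // Index of mask element (row, col) inside that slice, matching
    // data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)].
    static uint32_t mask_element(uint32_t m_offset, uint32_t row, uint32_t col, uint32_t m_stride) {
        return m_offset + row * m_stride + col;
    }

    int main() {
        const uint32_t nem1 = 64, nem2 = 2, KV = 256, m_stride = KV;
        const uint32_t off = mask_slice_offset(/*iq3=*/3, nem1, nem2, KV); // batch 3 -> slice 1
        std::printf("slice offset: %u, element (5, 17): %u\n", off, mask_element(off, 5, 17, m_stride));
        return 0;
    }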
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp CHANGED
@@ -24,6 +24,7 @@ layout (push_constant) uniform parameter {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -34,7 +35,6 @@ layout (push_constant) uniform parameter {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp CHANGED
@@ -123,6 +123,10 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
@@ -181,7 +185,7 @@ void main() {
             uint32_t c = (idx + tid) % Bc;
             uint32_t r = (idx + tid) / Bc;
             if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
+                sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
             }
         }
         barrier();
@@ -300,7 +304,7 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
 
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
@@ -312,7 +316,7 @@ void main() {
             }
         }
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
                 perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -334,7 +338,7 @@ void main() {
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -130,6 +130,11 @@ void main() {
         coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+    }
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
@@ -155,7 +160,7 @@ void main() {
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 
-            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
             S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
         }
@@ -229,10 +234,10 @@ void main() {
     if (p.k_num > 1) {
         coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
 
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
         coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
         coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
         return;
@@ -250,7 +255,7 @@ void main() {
 
     O = Ldiag*O;
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
     if (p.gqa_ratio > 1) {
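
Note on the coopmat2 path: unlike the scalar and cm1 shaders, which index data_m in elements, the offset handed to coopMatLoadTensorNV here is additionally scaled by 2, which the in-diff comment attributes to sizeof(float16_t), i.e. the broadcast mask offset appears to be expressed in bytes. A one-function sketch of that conversion (illustrative only, not from the commit):

    #include <cstdint>

    // Byte offset of the mask slice for batch iq3 in the cm2 shader:
    // element offset (iq3 % nem2) * nem1 * KV, times 2 bytes per float16_t.
    static uint32_t mask_slice_byte_offset(uint32_t iq3, uint32_t nem1, uint32_t nem2, uint32_t KV) {
        const uint32_t elems = (nem2 != 1) ? (iq3 % nem2) * nem1 * KV : 0;
        return elems * 2; // sizeof(float16_t)
    }

    int main() {
        // batch 1 of a broadcast mask with nem2 == 2, nem1 == 64, KV == 256
        return mask_slice_byte_offset(1, 64, 2, 256) == 64u * 256u * 2u ? 0 : 1;
    }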
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp CHANGED
@@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
 layout (push_constant) uniform parameter {
     uint D;
     uint N;
+    uint ne3;
     uint k_num;
 } p;
 
@@ -19,13 +20,14 @@ void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
+    const uint iq3 = gl_WorkGroupID.z;
 
     uint D = p.D;
     uint N = p.N;
     uint k_num = p.k_num;
 
-    uint l_offset = D * N * k_num + n;
-    uint m_offset = D * N * k_num + N + n;
+    uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
+    uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
     uint lm_stride = N * 2;
 
     // Compute the max m value for the row
@@ -49,11 +51,11 @@ void main() {
     for (uint d = tid; d < D; d += BLOCK_SIZE) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
-            uint o_offset = D * N * k + D * n + d;
+            uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
             float m = data_a[m_offset + k * lm_stride];
             O += exp(m - m_max) * data_a[o_offset];
         }
         O *= L;
-        data_d[D * n + d] = O;
+        data_d[iq3 * D * N + D * n + d] = O;
     }
 }
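
Note on the reduce shader's layout: with batching, the O matrices for every (split, batch) pair occupy the first D * N * ne3 * k_num floats of the scratch buffer, followed by one block of interleaved L/m rows per batch (N values each, per split). Below is a scalar sketch of the offsets used above; struct and helper names are illustrative, not part of the commit.

    #include <cstdint>
    #include <cstdio>

    struct SplitKOffsets {
        uint32_t l_offset;   // first L value for row n of batch iq3
        uint32_t m_offset;   // first m value for row n of batch iq3
        uint32_t lm_stride;  // distance between consecutive splits' L/m values
    };

    // Mirrors the index math in flash_attn_split_k_reduce.comp.
    static SplitKOffsets reduce_offsets(uint32_t D, uint32_t N, uint32_t ne3,
                                        uint32_t k_num, uint32_t n, uint32_t iq3) {
        SplitKOffsets o;
        o.l_offset  = D * N * ne3 * k_num + N * iq3 * k_num * 2 + n;
        o.m_offset  = D * N * ne3 * k_num + N * iq3 * k_num * 2 + N + n;
        o.lm_stride = N * 2;
        return o;
    }

    // Index of O[d] for split k, row n, batch iq3, matching
    // o_offset = D * N * (k + iq3 * k_num) + D * n + d.
    static uint32_t o_index(uint32_t D, uint32_t N, uint32_t k_num,
                            uint32_t k, uint32_t n, uint32_t d, uint32_t iq3) {
        return D * N * (k + iq3 * k_num) + D * n + d;
    }

    int main() {
        const SplitKOffsets o = reduce_offsets(128, 32, 4, 8, /*n=*/5, /*iq3=*/2);
        std::printf("l=%u m=%u stride=%u o(k=3,d=7)=%u\n",
                    o.l_offset, o.m_offset, o.lm_stride, o_index(128, 32, 8, 3, 5, 7, 2));
        return 0;
    }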
ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp CHANGED
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
 {
     uint KX;
     uint KY;
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint ne12;
+    uint ne13;
+    uint nb11;
+    uint nb12;
+    uint nb13;
     float scale;
     float max_bias;
     float m0;
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
 void soft_max(uint num_iters) {
     const uint tid = gl_LocalInvocationID.x;
     const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
 
     if (rowx >= p.nrows_x) {
         return;
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        const uint h = rowx/p.KY; // head index
+        const uint h = (rowx / p.ne01) % p.ne02; // head index
 
         const float base = h < p.n_head_log2 ? p.m0 : p.m1;
         const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
 
         FLOAT_TYPE b = FLOAT_TYPE(0);
         if (p.KY > 0 && col < p.KX) {
-            b = data_b[rowy * p.KX + col];
+            b = data_b[rowy_start + col];
         }
 
         FLOAT_TYPE v = a * p.scale + slope * b;
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
         if (idx < DATA_CACHE_SIZE) {
             val = exp(data_cache[idx] - max_val);
         } else {
-            val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+            val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
        }
         sum += val;
         if (idx < DATA_CACHE_SIZE) {
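
Note on the soft_max indexing: the flat row index rowx is decomposed into (i01, i02, i03) over src0's dims 1-3, the ALiBi head index becomes (rowx / ne01) % ne02, and the mask row is chosen by broadcasting i02 and i03 against src1's ne12/ne13 using element strides nb11/nb12/nb13 (the host passes nb[i] / nb[0]). Below is a scalar sketch of the same arithmetic; names mirror the shader but the helpers are illustrative, not from the commit.

    #include <cstdint>
    #include <cstdio>

    struct RowIdx { uint32_t i01, i02, i03; };

    // Decompose a flat src0 row index into (i01, i02, i03), as the shader does.
    static RowIdx decompose_row(uint32_t rowx, uint32_t ne01, uint32_t ne02) {
        RowIdx r;
        r.i03 = rowx / (ne01 * ne02);
        r.i02 = (rowx - r.i03 * ne01 * ne02) / ne01;
        r.i01 = rowx % ne01;
        return r;
    }

    // ALiBi head index for this row.
    static uint32_t head_index(uint32_t rowx, uint32_t ne01, uint32_t ne02) {
        return (rowx / ne01) % ne02;
    }

    // First element of the broadcast mask row, matching
    // rowy_start = i01*nb11 + (i02 % ne12)*nb12 + (i03 % ne13)*nb13.
    static uint32_t mask_row_start(RowIdx r, uint32_t ne12, uint32_t ne13,
                                   uint32_t nb11, uint32_t nb12, uint32_t nb13) {
        return r.i01 * nb11 + (r.i02 % ne12) * nb12 + (r.i03 % ne13) * nb13;
    }

    int main() {
        // src0 of shape [32, 7, 8, 2]; contiguous mask of shape [32, 7, 1, 1].
        const uint32_t ne00 = 32, ne01 = 7, ne02 = 8, ne12 = 1, ne13 = 1;
        const RowIdx r = decompose_row(/*rowx=*/61, ne01, ne02);
        std::printf("i01=%u i02=%u i03=%u head=%u rowy_start=%u\n",
                    r.i01, r.i02, r.i03, head_index(61, ne01, ne02),
                    mask_row_start(r, ne12, ne13, /*nb11=*/ne00, ne00 * ne01, ne00 * ne01));
        return 0;
    }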