Commit f6b0b76 · Parent: ebacb3e

vulkan: support softmax/FA batch and broadcast (llama/14449)

Files changed:
- ggml/src/ggml-vulkan/ggml-vulkan.cpp (+28, -23)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp (+8, -4)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp (+1, -1)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp (+8, -4)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp (+9, -4)
- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp (+6, -4)
- ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp (+20, -4)
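Both ops follow ggml's usual broadcast convention for the mask: higher dimensions of src0 are wrapped into the mask's extents with a modulo, so a mask with extent 1 in a dimension is reused everywhere. A minimal host-side sketch of that convention (the function and parameter names are illustrative, not from this patch):

```cpp
#include <cstdint>

// Hypothetical helper: pick the mask element offset for src0 coordinates
// (i2, i3) when the mask may have smaller extents (mask_ne2, mask_ne3).
// A mask with extent 1 in a dimension is broadcast to every index.
uint32_t broadcast_offset(uint32_t i2, uint32_t i3,
                          uint32_t mask_ne2, uint32_t mask_ne3,
                          uint32_t stride2, uint32_t stride3) {
    return (i2 % mask_ne2) * stride2 + (i3 % mask_ne3) * stride3;
}
```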
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
@@ -633,6 +633,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -643,7 +644,6 @@ struct vk_flash_attn_push_constants {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
@@ -658,6 +658,7 @@ struct vk_flash_attn_push_constants {
     uint32_t split_kv;
     uint32_t k_num;
 };
+static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
 
 struct vk_op_push_constants {
     uint32_t KX;
@@ -756,6 +757,14 @@ struct vk_op_rope_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t ne12;
+    uint32_t ne13;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
     float scale;
     float max_bias;
     float m0;
@@ -6040,7 +6049,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
-    const uint32_t nbm1 = mask ? mask->nb[1] : 0;
+    const uint32_t nem2 = mask ? mask->ne[2] : 0;
 
     const uint32_t D = neq0;
     uint32_t N = neq1;
@@ -6203,7 +6212,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Try to use split_k when KV is large enough to be worth the overhead
    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
        // Try to run two workgroups per SM.
-        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
        if (split_k > 1) {
            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
            // of "align", so recompute split_k based on that.
@@ -6213,9 +6222,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
        }
    }
 
-    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
-    // and the per-row m and L values (ne1 rows).
-    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
    if (split_k_size > ctx->device->max_memory_allocation_size) {
        GGML_ABORT("Requested preallocation size is too large");
    }
@@ -6307,11 +6316,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
        (uint32_t)neq2, (uint32_t)neq3,
        (uint32_t)nek2, (uint32_t)nek3,
        (uint32_t)nev2, (uint32_t)nev3,
-        nem1,
+        nem1, nem2,
        q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
        k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
        v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
-        nbm1,
        scale, max_bias, logit_softcap,
        mask != nullptr, n_head_log2, m0, m1,
        gqa_ratio, split_kv, split_k };
@@ -6334,13 +6342,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
            pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
 
        ggml_vk_sync_buffers(subctx);
-        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
        ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
            {
                vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
                vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
            },
-            pc2, { (uint32_t)ne1, 1, 1 });
+            pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
    } else {
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        {
@@ -7666,7 +7674,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
    const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
    const uint32_t nrows_y = (uint32_t)src0->ne[1];
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
+    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
+    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
+    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
+    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
+
+    const uint32_t n_head_kv   = src0->ne[2];
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
    const float m0 = powf(2.0f, -(max_bias    ) / n_head_log2);
@@ -7675,6 +7689,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
        ncols,
        src1 != nullptr ? nrows_y : (uint32_t)0,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
+        ne12, ne13,
+        nb11, nb12, nb13,
        scale, max_bias,
        m0, m1,
        n_head_log2,
@@ -10248,11 +10265,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                return false;
            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
            // It's straightforward to support different K/V dequant, but would
            // significantly increase the number of pipelines
            if (op->src[1]->type != op->src[2]->type) {
@@ -10413,13 +10425,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_DIAG_MASK_INF:
            return true;
        case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
        case GGML_OP_SOFT_MAX_BACK:
        case GGML_OP_ARGSORT:
        case GGML_OP_SUM:
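A note on the new soft_max host constants above: ggml stores `nb[i]` as byte strides, and `nb[0]` is the element size, so `src1->nb[i] / src1->nb[0]` converts each mask stride to elements, which is what the shader needs to index `data_b[]`. A minimal sketch of that conversion (the `tensor_like` type is illustrative, not a ggml type):

```cpp
#include <cstddef>
#include <cstdint>

struct tensor_like {
    int64_t ne[4]; // extents per dimension
    size_t  nb[4]; // byte strides; nb[0] == sizeof(element) when dim 0 is contiguous
};

// Element stride for dimension `dim`, mirroring src1->nb[dim] / src1->nb[0].
uint32_t element_stride(const tensor_like & t, int dim) {
    return (uint32_t)(t.nb[dim] / t.nb[0]);
}
```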
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
CHANGED
@@ -99,6 +99,10 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
@@ -150,7 +154,7 @@ void main() {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br) {
-                    masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+                    masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                }
            }
            barrier();
@@ -277,7 +281,7 @@ void main() {
    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
 
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
@@ -289,7 +293,7 @@ void main() {
            }
        }
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -311,7 +315,7 @@ void main() {
        }
    }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
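The new `m_offset` selects one mask plane per batch: with `nem2 == 1` a single plane is broadcast to every batch, otherwise plane `iq3 % p.nem2` is used. (In flash_attn_cm2.comp below, the same offset is additionally scaled by 2 because the tensor-layout offset there is in bytes of float16.) A sketch of the selection logic, with names mirroring the push constants but written as illustrative host code:

```cpp
#include <cstdint>

// Offset, in mask elements, of the mask plane used by batch iq3.
uint32_t mask_plane_offset(uint32_t iq3,   // Q batch index
                           uint32_t nem2,  // mask planes (mask->ne[2])
                           uint32_t nem1,  // rows per plane (mask->ne[1])
                           uint32_t KV) {  // columns per row
    // nem2 == 1: one plane broadcast to all batches, offset stays 0.
    return (nem2 != 1) ? (iq3 % nem2) * nem1 * KV : 0u;
}
```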
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
CHANGED
@@ -24,6 +24,7 @@ layout (push_constant) uniform parameter {
     uint32_t nev2;
     uint32_t nev3;
     uint32_t nem1;
+    uint32_t nem2;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -34,7 +35,6 @@ layout (push_constant) uniform parameter {
     uint32_t nb21;
     uint32_t nb22;
     uint32_t nb23;
-    uint32_t nb31;
 
     float scale;
     float max_bias;
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
CHANGED
@@ -123,6 +123,10 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
+    }
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
@@ -181,7 +185,7 @@ void main() {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
+                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
                }
            }
            barrier();
@@ -300,7 +304,7 @@ void main() {
    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
 
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
@@ -312,7 +316,7 @@ void main() {
            }
        }
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            if (tile_row(r) < N) {
                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -334,7 +338,7 @@ void main() {
        }
    }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED
@@ -130,6 +130,11 @@ void main() {
         coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
+    uint32_t m_offset = 0;
+    if (p.nem2 != 1) {
+        m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+    }
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
@@ -155,7 +160,7 @@ void main() {
 
            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 
-            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }
@@ -229,10 +234,10 @@ void main() {
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
 
-        uint32_t o_offset = D * p.ne1 * split_k_index;
+        uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 
-        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
        return;
@@ -250,7 +255,7 @@ void main() {
 
    O = Ldiag*O;
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1;
+    uint32_t o_offset = iq3*p.ne2*p.ne1*D;
 
    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
    if (p.gqa_ratio > 1) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
CHANGED
@@ -12,6 +12,7 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
 layout (push_constant) uniform parameter {
     uint D;
     uint N;
+    uint ne3;
     uint k_num;
 } p;
 
@@ -19,13 +20,14 @@ void main() {
    // Each workgroup handles a row
    const uint n = gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;
+    const uint iq3 = gl_WorkGroupID.z;
 
    uint D = p.D;
    uint N = p.N;
    uint k_num = p.k_num;
 
-    uint l_offset = D * N * k_num + n;
-    uint m_offset = D * N * k_num + N + n;
+    uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
+    uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
    uint lm_stride = N * 2;
 
    // Compute the max m value for the row
@@ -49,11 +51,11 @@ void main() {
    for (uint d = tid; d < D; d += BLOCK_SIZE) {
        float O = 0.0;
        [[unroll]] for (uint k = 0; k < k_num; ++k) {
-            uint o_offset = D * N * k + D * n + d;
+            uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
            float m = data_a[m_offset + k * lm_stride];
            O += exp(m - m_max) * data_a[o_offset];
        }
        O *= L;
-        data_d[D * n + d] = O;
+        data_d[iq3 * D * N + D * n + d] = O;
    }
 }
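The reduce shader's offsets follow the scratch-buffer layout described in ggml_vk_flash_attn: all O matrices for every (split, batch) pair come first, then the interleaved per-row L and m values. A sketch of that layout as a C++ struct (struct and method names are illustrative, not from the patch):

```cpp
#include <cstdint>

struct split_k_layout {
    uint32_t D, N, ne3, k_num; // head dim, rows, batches, splits

    // Start of the O matrix (D x N floats) for split k of batch i3.
    uint32_t o_offset(uint32_t k, uint32_t i3) const {
        return D * N * (k + i3 * k_num);
    }
    // Start of the L row for split k of batch i3; m follows N floats later.
    uint32_t l_offset(uint32_t k, uint32_t i3) const {
        return D * N * ne3 * k_num + N * (k + i3 * k_num) * 2;
    }
    // Total float count; matches the host-side split_k_size / sizeof(float).
    uint32_t total() const {
        return (D * N + N * 2) * k_num * ne3;
    }
};
```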
ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
CHANGED
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
 {
     uint KX;
     uint KY;
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint ne12;
+    uint ne13;
+    uint nb11;
+    uint nb12;
+    uint nb13;
     float scale;
     float max_bias;
     float m0;
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
 void soft_max(uint num_iters) {
     const uint tid = gl_LocalInvocationID.x;
     const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
 
     if (rowx >= p.nrows_x) {
         return;
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        const uint h = rowx/p.KY; // head index
+        const uint h = (rowx / p.ne01) % p.ne02; // head index
 
         const float base = h < p.n_head_log2 ? p.m0 : p.m1;
         const uint  exp  = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
 
        FLOAT_TYPE b = FLOAT_TYPE(0);
        if (p.KY > 0 && col < p.KX) {
-            b = data_b[rowy * p.KX + col];
+            b = data_b[rowy_start + col];
        }
 
        FLOAT_TYPE v = a * p.scale + slope * b;
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
            if (idx < DATA_CACHE_SIZE) {
                val = exp(data_cache[idx] - max_val);
            } else {
-                val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+                val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
            }
            sum += val;
            if (idx < DATA_CACHE_SIZE) {
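The new indexing decomposes the flat row index `rowx` into ggml coordinates and then applies the modulo broadcast to find where the mask row starts. A host-side reference sketch of the same math (hypothetical helper; strides are in elements):

```cpp
#include <cstdint>

uint32_t mask_row_start(uint32_t rowx,
                        uint32_t ne01, uint32_t ne02,           // src0 rows per head, heads
                        uint32_t ne12, uint32_t ne13,           // mask extents in dims 2 and 3
                        uint32_t nb11, uint32_t nb12, uint32_t nb13) {
    const uint32_t i03 = rowx / (ne01 * ne02);              // batch
    const uint32_t i02 = (rowx - i03 * ne01 * ne02) / ne01; // head
    const uint32_t i01 = rowx % ne01;                       // row within head
    // Dims 2 and 3 broadcast: wrap into the mask's extents.
    return i01 * nb11 + (i02 % ne12) * nb12 + (i03 % ne13) * nb13;
}
```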