Commit · a76ef69
1 Parent(s): b3bf710
vulkan: Use unclamped loads for flash attention mask (llama/12720)
nem1 must be a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a multiple
of the number of rows in the matrix. The KV dim is a multiple of the number of
columns for the aligned shader.
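
The divisibility chain the commit relies on can be sketched in a few lines of C++ (illustrative only: n_tokens and Br are made-up values, GGML_PAD mirrors the rounding macro from ggml.h, and 64 for GGML_KQ_MASK_PAD matches the "padded to 64" comment added below):

    #include <cassert>
    #include <cstdint>

    // Mirrors ggml.h: round x up to the next multiple of n (n a power of two).
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
    // 64 here matches the "mask dim1 is padded to 64" comment in the diff.
    #define GGML_KQ_MASK_PAD 64

    int main() {
        uint32_t n_tokens = 100; // hypothetical unpadded mask row count
        uint32_t Br       = 32;  // hypothetical tile row count (rows_cols[0] in the diff)

        // The mask tensor's dim1 is created pre-padded, so nem1 is always a
        // multiple of GGML_KQ_MASK_PAD.
        uint32_t nem1 = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); // 100 -> 128

        assert((nem1 % GGML_KQ_MASK_PAD) == 0); // checked per dispatch in ggml_vk_flash_attn
        assert((GGML_KQ_MASK_PAD % Br) == 0);   // checked per pipeline in ggml_vk_load_shaders
        assert((nem1 % Br) == 0);               // consequence: every Br-row mask tile is in bounds
        return 0;
    }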
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED

@@ -1833,6 +1833,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // can't use 256 for D==80.
         uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
         auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
         return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
     };

@@ -5511,6 +5513,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         // the "aligned" shader variant will forcibly align strides, for performance
         (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;

+    // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+    GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);

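Taken together, the two asserts cover both halves of the invariant: the GGML_KQ_MASK_PAD % rows_cols[0] check runs once per pipeline in ggml_vk_load_shaders, while the nem1 % GGML_KQ_MASK_PAD check runs on every dispatch in ggml_vk_flash_attn. By transitivity nem1 is a multiple of the tile row count, so no row tile of the mask can read past the padded bounds.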
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED

@@ -256,7 +256,7 @@ void main() {
     }

     if (p.mask != 0) {
-        tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+        tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
         tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
         // When using grouped query attention, all rows use the same mask.
         if (p.gqa_ratio > 1) {
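
Judging from the spec-constant array built in the first hunk (return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};), Clamp appears to be a specialization constant, so the aligned pipeline variant can be compiled with mask clamping disabled while the unaligned variant keeps it. As a rough scalar model of the two clamp modes (a sketch; load_mask_element and its parameters are invented for illustration, not shader API):

    #include <cstdint>

    // Scalar model of a cooperative-matrix mask load. With constant clamping
    // (the old gl_CooperativeMatrixClampModeConstantNV behavior) out-of-bounds
    // elements produce a constant instead of faulting; an unclamped load skips
    // the bounds test, which is only safe because nem1 is a multiple of the
    // tile rows and KV is a multiple of the tile columns in the aligned shader.
    static float load_mask_element(const float * mask, uint32_t row, uint32_t col,
                                   uint32_t nem1, uint32_t kv, bool clamped) {
        if (clamped && (row >= nem1 || col >= kv)) {
            return 0.0f; // constant substituted for out-of-bounds elements
        }
        return mask[row * kv + col]; // unclamped path: caller guarantees in-bounds
    }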