jeffbolznv committed
Commit a76ef69 · 1 Parent(s): b3bf710

vulkan: Use unclamped loads for flash attention mask (llama/12720)

nem1 must be a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a multiple
of the number of rows in the matrix. The KV dim is a multiple of the number of
columns for the aligned shader. Together these guarantee that mask loads stay
in-bounds in both dimensions for the aligned variant, so its loads no longer
need clamping.
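
A minimal sketch of this divisibility chain (C++; the value GGML_KQ_MASK_PAD == 64 and the helper name are illustrative assumptions taken from the diff comment, not code from this commit):

#include <cstdint>
#include <cstdio>

constexpr uint32_t GGML_KQ_MASK_PAD = 64; // mask dim1 padding, per the diff comment

// nem1 (mask dim1) is a multiple of GGML_KQ_MASK_PAD, and GGML_KQ_MASK_PAD is a
// multiple of the per-workgroup row count, so by transitivity nem1 is a multiple
// of the row count: every row tile the shader loads ends inside the padded mask.
bool mask_rows_never_need_clamp(uint32_t nem1, uint32_t rows_per_wg) {
    return (nem1 % GGML_KQ_MASK_PAD) == 0 && (GGML_KQ_MASK_PAD % rows_per_wg) == 0;
}

int main() {
    // e.g. a mask padded to 128 rows, read in 32-row tiles: no clamping needed
    std::printf("%d\n", mask_rows_never_need_clamp(128, 32)); // prints 1
    return 0;
}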

ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -1833,6 +1833,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // can't use 256 for D==80.
         uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
         auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
         return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
     };
 
@@ -5511,6 +5513,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         // the "aligned" shader variant will forcibly align strides, for performance
         (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
 
+    // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+    GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
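On the host side, the nem1 assert holds because the mask's second dimension is padded up to GGML_KQ_MASK_PAD before the tensor reaches the backend. A sketch of that rounding, assuming the GGML_PAD macro from ggml.h (the batch size below is arbitrary):

#include <cstdint>
#include <cstdio>

// Rounding macro as defined in ggml.h; requires n to be a power of two.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

constexpr uint32_t GGML_KQ_MASK_PAD = 64;

int main() {
    uint32_t n_tokens = 70;                               // arbitrary batch size
    uint32_t nem1 = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); // rounds 70 up to 128
    std::printf("nem1 = %u, nem1 %% GGML_KQ_MASK_PAD = %u\n",
                nem1, nem1 % GGML_KQ_MASK_PAD);           // prints 128, 0
    return 0;
}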
 
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -256,7 +256,7 @@ void main() {
     }
 
     if (p.mask != 0) {
-        tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+        tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
         tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
         // When using grouped query attention, all rows use the same mask.
         if (p.gqa_ratio > 1) {
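
Design note: Clamp here appears to be the shader's clamp-mode specialization constant (the clamp value in the spec-constant list {wg_size, rows_cols[0], rows_cols[1], (D), clamp} above), so the aligned pipeline variant now loads the mask with clamping disabled instead of always using gl_CooperativeMatrixClampModeConstantNV, and the two new GGML_ASSERTs document why that is safe.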