jeffbolznv commited on
Commit
77d7613
·
1 Parent(s): a76ef69

vulkan: fix NaN issue in flash attention shader (llama/12776)

Browse files

Use -FLT_MAX/2 rather than -inf as the initial value for computing the maximum.

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp CHANGED
@@ -227,8 +227,11 @@ void main() {
227
 
228
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
229
 
 
 
 
230
  L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
231
- M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
232
 
233
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
234
 
@@ -278,7 +281,7 @@ void main() {
278
  uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
279
  uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
280
 
281
- coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
282
  }
283
 
284
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
 
227
 
228
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
229
 
230
+ // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
231
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
232
+
233
  L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
234
+ M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
235
 
236
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
237
 
 
281
  uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
282
  uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
283
 
284
+ coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
285
  }
286
 
287
  coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;