Spaces:
Running
Running
Commit
·
77d7613
1
Parent(s):
a76ef69
vulkan: fix NaN issue in flash attention shader (llama/12776)
Browse filesUse -FLT_MAX/2 rather than -inf as the initial value for computing the maximum.
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
CHANGED
|
@@ -227,8 +227,11 @@ void main() {
|
|
| 227 |
|
| 228 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
| 229 |
|
|
|
|
|
|
|
|
|
|
| 230 |
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 231 |
-
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(
|
| 232 |
|
| 233 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
|
| 234 |
|
|
@@ -278,7 +281,7 @@ void main() {
|
|
| 278 |
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
|
| 279 |
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
|
| 280 |
|
| 281 |
-
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(
|
| 282 |
}
|
| 283 |
|
| 284 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
|
|
|
|
| 227 |
|
| 228 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
| 229 |
|
| 230 |
+
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
| 231 |
+
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
| 232 |
+
|
| 233 |
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
| 234 |
+
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
|
| 235 |
|
| 236 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
|
| 237 |
|
|
|
|
| 281 |
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
|
| 282 |
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
|
| 283 |
|
| 284 |
+
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
|
| 285 |
}
|
| 286 |
|
| 287 |
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
|