Spaces:
Running
Running
Commit
·
efbb7be
1
Parent(s):
6dc2887
CUDA: fix broken oob check for FA vec f32 kernel (llama/7904)
Browse files
ggml-cuda/fattn-vec-f32.cuh
CHANGED
|
@@ -149,7 +149,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 149 |
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 150 |
const int i = i0 + threadIdx.x;
|
| 151 |
|
| 152 |
-
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
| 153 |
Q_f2[j][i0/WARP_SIZE].x *= scale;
|
| 154 |
Q_f2[j][i0/WARP_SIZE].y *= scale;
|
| 155 |
}
|
|
|
|
| 149 |
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 150 |
const int i = i0 + threadIdx.x;
|
| 151 |
|
| 152 |
+
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
| 153 |
Q_f2[j][i0/WARP_SIZE].x *= scale;
|
| 154 |
Q_f2[j][i0/WARP_SIZE].y *= scale;
|
| 155 |
}
|