Spaces:
Running
Running
Commit
·
eb84e7e
1
Parent(s):
31edd77
CUDA: fix pointer incrementation in FA (llama/14916)
Browse files
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
|
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
| 174 |
K += blockIdx.y*D * nb11;
|
| 175 |
V += blockIdx.y*D * nb21;
|
| 176 |
maskh += blockIdx.y*D;
|
| 177 |
-
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D
|
|
|
|
|
|
|
|
|
|
| 178 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 179 |
|
| 180 |
if (mask) {
|
|
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
| 291 |
}
|
| 292 |
}
|
| 293 |
|
| 294 |
-
K += gridDim.y*D * nb11;
|
| 295 |
-
V += gridDim.y*D * nb21;
|
| 296 |
-
maskh += gridDim.y*D;
|
| 297 |
-
|
| 298 |
__syncthreads();
|
| 299 |
}
|
| 300 |
|
|
|
|
| 174 |
K += blockIdx.y*D * nb11;
|
| 175 |
V += blockIdx.y*D * nb21;
|
| 176 |
maskh += blockIdx.y*D;
|
| 177 |
+
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
|
| 178 |
+
// Increment pointers after each loop:
|
| 179 |
+
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
| 180 |
+
|
| 181 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 182 |
|
| 183 |
if (mask) {
|
|
|
|
| 294 |
}
|
| 295 |
}
|
| 296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
__syncthreads();
|
| 298 |
}
|
| 299 |
|
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED
|
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 180 |
K += blockIdx.y*D * nb11;
|
| 181 |
V += blockIdx.y*D * nb21;
|
| 182 |
maskh += blockIdx.y*D;
|
| 183 |
-
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D
|
|
|
|
|
|
|
|
|
|
| 184 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 185 |
|
| 186 |
if (mask) {
|
|
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
| 286 |
}
|
| 287 |
}
|
| 288 |
|
| 289 |
-
K += gridDim.y*D * nb11;
|
| 290 |
-
V += gridDim.y*D * nb21;
|
| 291 |
-
maskh += gridDim.y*D;
|
| 292 |
-
|
| 293 |
__syncthreads();
|
| 294 |
}
|
| 295 |
|
|
|
|
| 180 |
K += blockIdx.y*D * nb11;
|
| 181 |
V += blockIdx.y*D * nb21;
|
| 182 |
maskh += blockIdx.y*D;
|
| 183 |
+
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
|
| 184 |
+
// Increment pointers after each loop:
|
| 185 |
+
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
|
| 186 |
+
|
| 187 |
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 188 |
|
| 189 |
if (mask) {
|
|
|
|
| 289 |
}
|
| 290 |
}
|
| 291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
__syncthreads();
|
| 293 |
}
|
| 294 |
|