JohannesGaessler commited on
Commit
eb84e7e
·
1 Parent(s): 31edd77

CUDA: fix pointer incrementation in FA (llama/14916)

Browse files
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
174
  K += blockIdx.y*D * nb11;
175
  V += blockIdx.y*D * nb21;
176
  maskh += blockIdx.y*D;
177
- for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
 
 
 
178
  // Calculate KQ tile and keep track of new maximum KQ values:
179
 
180
  if (mask) {
@@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
291
  }
292
  }
293
 
294
- K += gridDim.y*D * nb11;
295
- V += gridDim.y*D * nb21;
296
- maskh += gridDim.y*D;
297
-
298
  __syncthreads();
299
  }
300
 
 
174
  K += blockIdx.y*D * nb11;
175
  V += blockIdx.y*D * nb21;
176
  maskh += blockIdx.y*D;
177
+ for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
178
+ // Increment pointers after each loop:
179
+ K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
180
+
181
  // Calculate KQ tile and keep track of new maximum KQ values:
182
 
183
  if (mask) {
 
294
  }
295
  }
296
 
 
 
 
 
297
  __syncthreads();
298
  }
299
 
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
180
  K += blockIdx.y*D * nb11;
181
  V += blockIdx.y*D * nb21;
182
  maskh += blockIdx.y*D;
183
- for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
 
 
 
184
  // Calculate KQ tile and keep track of new maximum KQ values:
185
 
186
  if (mask) {
@@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
286
  }
287
  }
288
 
289
- K += gridDim.y*D * nb11;
290
- V += gridDim.y*D * nb21;
291
- maskh += gridDim.y*D;
292
-
293
  __syncthreads();
294
  }
295
 
 
180
  K += blockIdx.y*D * nb11;
181
  V += blockIdx.y*D * nb21;
182
  maskh += blockIdx.y*D;
183
+ for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
184
+ // Increment pointers after each loop:
185
+ K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
186
+
187
  // Calculate KQ tile and keep track of new maximum KQ values:
188
 
189
  if (mask) {
 
289
  }
290
  }
291
 
 
 
 
 
292
  __syncthreads();
293
  }
294