JohannesGaessler committed
Commit a867083 · 1 Parent(s): fd2d86d

CUDA: batched+noncont MMQ, refactor bs>1 MoE code (llama/13199)

ggml/src/ggml-cuda/getrows.cu CHANGED
@@ -33,8 +33,8 @@ static __global__ void k_get_rows(
33
  dfloat2 v;
34
  dequantize_kernel(src0_row, ib, iqs, v);
35
 
36
- dst_row[iybs + iqs + 0] = v.x;
37
- dst_row[iybs + iqs + y_offset] = v.y;
38
  }
39
 
40
  template<typename src0_t, typename dst_t>
@@ -60,7 +60,7 @@ static __global__ void k_get_rows_float(
60
  dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
61
  const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
62
 
63
- dst_row[i00] = src0_row[i00];
64
  }
65
 
66
  template<typename grad_t, typename dst_t>
@@ -86,122 +86,161 @@ static __global__ void k_get_rows_back_float(
86
  dst[dst_row*ncols + col] = sum;
87
  }
88
 
89
- template<int qk, int qr, dequantize_kernel_t dq>
90
- static void get_rows_cuda(
91
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
92
- const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
93
-
94
- GGML_TENSOR_BINARY_OP_LOCALS
95
-
96
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
97
  const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
98
  const dim3 block_nums(block_num_x, ne10, ne11*ne12);
99
 
100
  // strides in elements
101
- //const size_t s0 = nb0 / ggml_element_size(dst);
102
- const size_t s1 = nb1 / ggml_element_size(dst);
103
- const size_t s2 = nb2 / ggml_element_size(dst);
104
- const size_t s3 = nb3 / ggml_element_size(dst);
105
 
106
- const size_t s10 = nb10 / ggml_element_size(src1);
107
- const size_t s11 = nb11 / ggml_element_size(src1);
108
- const size_t s12 = nb12 / ggml_element_size(src1);
109
- //const size_t s13 = nb13 / ggml_element_size(src1);
110
 
111
  GGML_ASSERT(ne00 % 2 == 0);
112
 
113
  k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
114
- src0_dd, src1_dd, dst_dd,
115
  ne00, /*ne01, ne02, ne03,*/
116
  /*ne10, ne11,*/ ne12, /*ne13,*/
117
  /* s0,*/ s1, s2, s3,
118
  /* nb00,*/ nb01, nb02, nb03,
119
  s10, s11, s12/*, s13*/);
120
-
121
- GGML_UNUSED(dst);
122
  }
123
 
124
- template<typename src0_t>
125
  static void get_rows_cuda_float(
126
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
127
- const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
128
-
129
- GGML_TENSOR_BINARY_OP_LOCALS
130
-
131
- GGML_ASSERT(ne13 == 1);
132
-
133
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
134
  const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
135
  const dim3 block_nums(block_num_x, ne10, ne11*ne12);
136
 
137
  // strides in elements
138
- //const size_t s0 = nb0 / ggml_element_size(dst);
139
- const size_t s1 = nb1 / ggml_element_size(dst);
140
- const size_t s2 = nb2 / ggml_element_size(dst);
141
- const size_t s3 = nb3 / ggml_element_size(dst);
142
 
143
- const size_t s10 = nb10 / ggml_element_size(src1);
144
- const size_t s11 = nb11 / ggml_element_size(src1);
145
- const size_t s12 = nb12 / ggml_element_size(src1);
146
- //const size_t s13 = nb13 / ggml_element_size(src1);
147
 
148
  k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
149
- src0_dd, src1_dd, dst_dd,
150
  ne00, /*ne01, ne02, ne03,*/
151
  /*ne10, ne11,*/ ne12, /*ne13,*/
152
  /* s0,*/ s1, s2, s3,
153
  /* nb00,*/ nb01, nb02, nb03,
154
  s10, s11, s12/*, s13*/);
155
-
156
- GGML_UNUSED(dst);
157
  }
158
 
159
- void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
160
- const ggml_tensor * src0 = dst->src[0];
161
- const ggml_tensor * src1 = dst->src[1];
162
-
163
- const void * src0_d = (const void *) src0->data;
164
- const int32_t * src1_d = (const int32_t *) src1->data;
165
- float * dst_d = (float *) dst->data;
166
-
167
- cudaStream_t stream = ctx.stream();
168
-
169
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
170
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
171
-
172
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
173
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
174
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
175
-
176
- switch (src0->type) {
177
  case GGML_TYPE_F16:
178
- get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
 
179
  break;
180
  case GGML_TYPE_F32:
181
- get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
182
  break;
183
  case GGML_TYPE_Q4_0:
184
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
 
185
  break;
186
  case GGML_TYPE_Q4_1:
187
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
 
188
  break;
189
  case GGML_TYPE_Q5_0:
190
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
 
191
  break;
192
  case GGML_TYPE_Q5_1:
193
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
 
194
  break;
195
  case GGML_TYPE_Q8_0:
196
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
 
197
  break;
198
  default:
199
  // TODO: k-quants
200
- GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
201
  break;
202
  }
203
  }
204
 
205
  void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
206
  const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
207
  const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
 
33
  dfloat2 v;
34
  dequantize_kernel(src0_row, ib, iqs, v);
35
 
36
+ dst_row[iybs + iqs + 0] = float(v.x);
37
+ dst_row[iybs + iqs + y_offset] = float(v.y);
38
  }
39
 
40
  template<typename src0_t, typename dst_t>
 
60
  dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
61
  const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
62
 
63
+ dst_row[i00] = float(src0_row[i00]);
64
  }
65
 
66
  template<typename grad_t, typename dst_t>
 
86
  dst[dst_row*ncols + col] = sum;
87
  }
88
 
89
+ template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
90
+ static void get_rows_cuda_q(
91
+ const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
92
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
93
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
94
+ const size_t nb1, const size_t nb2, const size_t nb3,
95
+ cudaStream_t stream) {
96
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
97
  const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
98
  const dim3 block_nums(block_num_x, ne10, ne11*ne12);
99
 
100
  // strides in elements
101
+ // const size_t s0 = nb0 / sizeof(dst_t);
102
+ const size_t s1 = nb1 / sizeof(dst_t);
103
+ const size_t s2 = nb2 / sizeof(dst_t);
104
+ const size_t s3 = nb3 / sizeof(dst_t);
105
 
106
+ const size_t s10 = nb10 / sizeof(int32_t);
107
+ const size_t s11 = nb11 / sizeof(int32_t);
108
+ const size_t s12 = nb12 / sizeof(int32_t);
109
+ // const size_t s13 = nb13 / sizeof(int32_t);
110
 
111
  GGML_ASSERT(ne00 % 2 == 0);
112
 
113
  k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
114
+ src0_d, src1_d, dst_d,
115
  ne00, /*ne01, ne02, ne03,*/
116
  /*ne10, ne11,*/ ne12, /*ne13,*/
117
  /* s0,*/ s1, s2, s3,
118
  /* nb00,*/ nb01, nb02, nb03,
119
  s10, s11, s12/*, s13*/);
 
 
120
  }
121
 
122
+ template<typename src0_t, typename dst_t>
123
  static void get_rows_cuda_float(
124
+ const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
125
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
126
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
127
+ const size_t nb1, const size_t nb2, const size_t nb3,
128
+ cudaStream_t stream) {
 
 
129
  const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
130
  const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
131
  const dim3 block_nums(block_num_x, ne10, ne11*ne12);
132
 
133
  // strides in elements
134
+ // const size_t s0 = nb0 / sizeof(dst_t);
135
+ const size_t s1 = nb1 / sizeof(dst_t);
136
+ const size_t s2 = nb2 / sizeof(dst_t);
137
+ const size_t s3 = nb3 / sizeof(dst_t);
138
 
139
+ const size_t s10 = nb10 / sizeof(int32_t);
140
+ const size_t s11 = nb11 / sizeof(int32_t);
141
+ const size_t s12 = nb12 / sizeof(int32_t);
142
+ // const size_t s13 = nb13 / sizeof(int32_t);
143
 
144
  k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
145
+ src0_d, src1_d, dst_d,
146
  ne00, /*ne01, ne02, ne03,*/
147
  /*ne10, ne11,*/ ne12, /*ne13,*/
148
  /* s0,*/ s1, s2, s3,
149
  /* nb00,*/ nb01, nb02, nb03,
150
  s10, s11, s12/*, s13*/);
 
 
151
  }
152
 
153
+ template <typename dst_t>
154
+ static void ggml_cuda_get_rows_switch_src0_type(
155
+ const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
156
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
157
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
158
+ const size_t nb1, const size_t nb2, const size_t nb3,
159
+ cudaStream_t stream) {
160
+ switch (src0_type) {
161
  case GGML_TYPE_F16:
162
+ get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
163
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
164
  break;
165
  case GGML_TYPE_F32:
166
+ get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
167
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
168
+ break;
169
+ case GGML_TYPE_BF16:
170
+ get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
171
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
172
  break;
173
  case GGML_TYPE_Q4_0:
174
+ get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
175
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
176
  break;
177
  case GGML_TYPE_Q4_1:
178
+ get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
179
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
180
  break;
181
  case GGML_TYPE_Q5_0:
182
+ get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
183
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
184
  break;
185
  case GGML_TYPE_Q5_1:
186
+ get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
187
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
188
  break;
189
  case GGML_TYPE_Q8_0:
190
+ get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
191
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
192
  break;
193
  default:
194
  // TODO: k-quants
195
+ GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
196
  break;
197
  }
198
  }
199
 
200
+ void get_rows_cuda(
201
+ const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
202
+ int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
203
+ int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
204
+ size_t nb1, size_t nb2, size_t nb3,
205
+ cudaStream_t stream) {
206
+ switch (dst_type) {
207
+ case GGML_TYPE_F32:
208
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
209
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
210
+ break;
211
+ case GGML_TYPE_F16:
212
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
213
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
214
+ break;
215
+ case GGML_TYPE_BF16:
216
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
217
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
218
+ break;
219
+ default:
220
+ GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
221
+ break;
222
+ }
223
+ }
224
+
225
+ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
226
+ const ggml_tensor * src0 = dst->src[0];
227
+ const ggml_tensor * src1 = dst->src[1];
228
+
229
+ cudaStream_t stream = ctx.stream();
230
+
231
+ GGML_TENSOR_BINARY_OP_LOCALS
232
+
233
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
234
+ GGML_ASSERT(ne13 == 1);
235
+
236
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
237
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
238
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
239
+
240
+ get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
241
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
242
+ }
243
+
244
  void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
245
  const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
246
  const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
ggml/src/ggml-cuda/getrows.cuh CHANGED
@@ -3,6 +3,13 @@
3
  #define CUDA_GET_ROWS_BLOCK_SIZE 256
4
  #define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
5
 
6
  void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
7
 
8
  void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
3
  #define CUDA_GET_ROWS_BLOCK_SIZE 256
4
  #define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
5
 
6
+ void get_rows_cuda(
7
+ const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
8
+ int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
9
+ int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
10
+ size_t nb1, size_t nb2, size_t nb3,
11
+ cudaStream_t stream);
12
+
13
  void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
14
 
15
  void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
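Note: the header now exposes get_rows_cuda as a standalone entry point that works on raw device pointers, explicit ggml types, and byte strides, so the MoE code in ggml-cuda.cu and mmq.cu can gather and scatter rows without building ggml_tensor views. A minimal host-side sketch of a call is below; the wrapper name gather_rows_f32 and the contiguous-buffer layout are illustrative assumptions, only the get_rows_cuda signature comes from this header.

#include "getrows.cuh"

// Hypothetical helper: gather n_rows rows of length ne00 (elements) from a
// contiguous F32 matrix with n_src_rows rows into a contiguous F32 dst,
// using a device array of I32 row indices. All strides are passed in bytes,
// mirroring the call made by ggml_cuda_op_get_rows.
static void gather_rows_f32(const float * src0_d, const int64_t n_src_rows,
                            const int32_t * rows_d, float * dst_d,
                            const int64_t ne00, const int64_t n_rows, cudaStream_t stream) {
    const size_t ts_f32 = sizeof(float);
    const size_t ts_i32 = sizeof(int32_t);
    get_rows_cuda(
        src0_d, GGML_TYPE_F32, rows_d, dst_d, GGML_TYPE_F32,
        ne00,
        ne00*ts_f32, n_src_rows*ne00*ts_f32, n_src_rows*ne00*ts_f32, // nb01, nb02, nb03
        n_rows, 1, 1,                                                // ne10, ne11, ne12
        ts_i32, n_rows*ts_i32, n_rows*ts_i32,                        // nb10, nb11, nb12
        ne00*ts_f32, n_rows*ne00*ts_f32, n_rows*ne00*ts_f32,         // nb1,  nb2,  nb3
        stream);
}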
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -1551,7 +1551,7 @@ static void ggml_cuda_op_mul_mat(
1551
 
1552
  if (src1_on_device && src1_is_contiguous) {
1553
  quantize_src1(
1554
- dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
1555
  nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
1556
  src1_padded_col_size, ne11, ne12, ne13, stream);
1557
  CUDA_CHECK(cudaGetLastError());
@@ -1649,7 +1649,7 @@ static void ggml_cuda_op_mul_mat(
1649
 
1650
  if (quantize_src1 && !src1_is_contiguous) {
1651
  quantize_src1(
1652
- src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
1653
  src1_padded_col_size, src1_ncols, 1, 1, stream);
1654
  CUDA_CHECK(cudaGetLastError());
1655
  }
@@ -1949,6 +1949,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1949
  ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
1950
  } else if (!split && use_mul_mat_vec_q) {
1951
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
 
 
1952
  } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1953
  !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1954
  // general KQ + KQV multi-batch without FlashAttention
@@ -1964,183 +1966,145 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1964
  }
1965
  }
1966
 
1967
- struct mmid_row_mapping {
1968
- int32_t i1;
1969
- int32_t i2;
1970
- };
1971
-
1972
- static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
1973
- int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
1974
- const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
1975
- int64_t ne11, int64_t ne10,
1976
- size_t nb11, size_t nb12) {
1977
- int32_t iid1 = blockIdx.x;
1978
- int32_t id = blockIdx.y;
1979
-
1980
- const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
1981
-
1982
- if (row_id_i != i02) {
1983
- return;
1984
- }
1985
-
1986
- const int64_t i11 = id % ne11;
1987
- const int64_t i12 = iid1;
1988
-
1989
- __shared__ int src1_row;
1990
- if (threadIdx.x == 0) {
1991
- src1_row = atomicAdd(cur_src1_row, 1);
1992
- row_mapping[src1_row] = {id, iid1};
1993
- }
1994
- __syncthreads();
1995
-
1996
- const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
1997
- float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
1998
-
1999
- for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
2000
- src1_row_contiguous[i] = src1_row_original[i];
2001
- }
2002
- }
2003
-
2004
- static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
2005
- const mmid_row_mapping * __restrict__ row_mapping,
2006
- int64_t ne0,
2007
- size_t nb1, size_t nb2) {
2008
- int32_t i = blockIdx.x;
2009
-
2010
- const int32_t i1 = row_mapping[i].i1;
2011
- const int32_t i2 = row_mapping[i].i2;
2012
-
2013
- const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
2014
- float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
2015
-
2016
- for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
2017
- dst_row_original[j] = dst_row_contiguous[j];
2018
- }
2019
- }
2020
-
2021
  static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
2022
  const ggml_tensor * src0 = dst->src[0];
2023
  const ggml_tensor * src1 = dst->src[1];
2024
  const ggml_tensor * ids = dst->src[2];
2025
 
2026
- GGML_TENSOR_BINARY_OP_LOCALS
2027
-
2028
- if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
2029
- if (ggml_is_quantized(src0->type)) {
2030
- ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
2031
- } else {
2032
- ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
2033
- }
2034
- return;
2035
- }
2036
-
2037
  GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
2038
 
2039
- cudaStream_t stream = ctx.stream();
2040
 
2041
- const int64_t n_as = ne02;
2042
- const int64_t n_ids = ids->ne[0];
2043
 
2044
- std::vector<char> ids_host(ggml_nbytes(ids));
2045
- const char * ids_dev = (const char *) ids->data;
2046
- CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
2047
- CUDA_CHECK(cudaStreamSynchronize(stream));
2048
 
2049
- ggml_tensor src0_row = *src0;
2050
- ggml_tensor src1_row = *src1;
2051
- ggml_tensor dst_row = *dst;
 
 
2052
 
2053
- char * src0_original = (char *) src0->data;
2054
- char * src1_original = (char *) src1->data;
2055
- char * dst_original = (char *) dst->data;
2056
 
2057
- src0_row.ne[2] = 1;
2058
- src0_row.ne[3] = 1;
2059
- src0_row.nb[3] = nb02;
2060
 
2061
- src1_row.ne[1] = 1;
2062
- src1_row.ne[2] = 1;
2063
- src1_row.ne[3] = 1;
2064
- src1_row.nb[2] = nb11;
2065
- src1_row.nb[3] = nb11;
2066
 
2067
- dst_row.ne[1] = 1;
2068
- dst_row.ne[2] = 1;
2069
- dst_row.ne[3] = 1;
2070
- dst_row.nb[2] = nb1;
2071
- dst_row.nb[3] = nb1;
2072
 
2073
- ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
2074
- ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
2075
 
2076
- src1_row.data = src1_contiguous.get();
2077
- dst_row.data = dst_contiguous.get();
2078
 
2079
- for (int64_t i02 = 0; i02 < n_as; i02++) {
2080
- int64_t num_src1_rows = 0;
2081
 
2082
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2083
- for (int64_t id = 0; id < n_ids; id++) {
2084
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2085
 
2086
- GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
 
2087
 
2088
- if (row_id_i != i02) {
2089
- continue;
 
  }
2091
-
2092
- num_src1_rows++;
2093
  }
2094
  }
 
 
2095
 
2096
- if (num_src1_rows == 0) {
2097
- continue;
2098
- }
2099
-
2100
- ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
2101
- ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
2102
- CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
2103
-
2104
- {
2105
- dim3 block_dims(std::min((unsigned int)ne10, 768u));
2106
- dim3 grid_dims(ids->ne[1], n_ids);
2107
- k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
2108
- src1_original, src1_contiguous.get(),
2109
- dev_cur_src1_row.get(), dev_row_mapping.get(),
2110
- ids_dev, i02, ids->nb[1], ids->nb[0],
2111
- ne11, ne10,
2112
- nb11, nb12);
2113
- CUDA_CHECK(cudaGetLastError());
2114
- }
2115
 
2116
- src0_row.data = src0_original + i02*nb02;
 
2117
 
2118
- GGML_ASSERT(nb11 == sizeof(float)*ne10);
2119
- GGML_ASSERT(nb1 == sizeof(float)*ne0);
2120
 
2121
- src1_row.ne[1] = num_src1_rows;
2122
- src1_row.nb[1] = nb11;
2123
- src1_row.nb[2] = num_src1_rows*nb11;
2124
- src1_row.nb[3] = num_src1_rows*nb11;
 
2125
 
2126
- dst_row.ne[1] = num_src1_rows;
2127
- dst_row.nb[1] = nb1;
2128
- dst_row.nb[2] = num_src1_rows*nb1;
2129
- dst_row.nb[3] = num_src1_rows*nb1;
 
 
2130
 
2131
- ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
2132
 
2133
- {
2134
- dim3 block_dims(std::min((unsigned int)ne0, 768u));
2135
- dim3 grid_dims(num_src1_rows);
2136
- k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
2137
- dst_original, dst_contiguous.get(),
2138
- dev_row_mapping.get(),
2139
- ne0,
2140
- nb1, nb2);
2141
- CUDA_CHECK(cudaGetLastError());
2142
- }
2143
  }
2144
  }
2145
 
2146
  static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
 
1551
 
1552
  if (src1_on_device && src1_is_contiguous) {
1553
  quantize_src1(
1554
+ dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,
1555
  nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
1556
  src1_padded_col_size, ne11, ne12, ne13, stream);
1557
  CUDA_CHECK(cudaGetLastError());
 
1649
 
1650
  if (quantize_src1 && !src1_is_contiguous) {
1651
  quantize_src1(
1652
+ src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
1653
  src1_padded_col_size, src1_ncols, 1, 1, stream);
1654
  CUDA_CHECK(cudaGetLastError());
1655
  }
 
1949
  ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
1950
  } else if (!split && use_mul_mat_vec_q) {
1951
  ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
1952
+ } else if (!split && use_mul_mat_q) {
1953
+ ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
1954
  } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1955
  !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1956
  // general KQ + KQV multi-batch without FlashAttention
 
1966
  }
1967
  }
1968
 
1969
  static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1970
  const ggml_tensor * src0 = dst->src[0];
1971
  const ggml_tensor * src1 = dst->src[1];
1972
  const ggml_tensor * ids = dst->src[2];
1973
 
1974
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
1975
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
1976
  GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
1977
 
1978
+ GGML_TENSOR_BINARY_OP_LOCALS
1979
 
1980
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
1981
 
1982
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
1983
+ if (ne2 == 1) {
1984
+ if (ggml_is_quantized(src0->type)) {
1985
+ ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
1986
+ } else {
1987
+ ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
1988
+ }
1989
+ return;
1990
+ }
1991
 
1992
+ if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
1993
+ ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
1994
+ return;
1995
+ }
1996
+ }
1997
 
1998
+ cudaStream_t stream = ctx.stream();
 
 
1999
 
2000
+ GGML_ASSERT(nb12 % nb11 == 0);
2001
+ GGML_ASSERT(nb2 % nb1 == 0);
 
2002
 
2003
+ const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc))
2004
+ || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type;
2005
+ const ggml_type type_dst_sorted = GGML_TYPE_F32;
2006
+ const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted);
2007
+ const size_t ts_dst_sorted = ggml_type_size(type_dst_sorted);
2008
 
2009
+ const int64_t n_expert_used = ids->ne[0];
2010
+ const int64_t ne_get_rows = ne12 * n_expert_used;
 
 
 
2011
 
2012
+ std::vector<int32_t> ids_to_sorted_host;
2013
+ ids_to_sorted_host.reserve(2*ne_get_rows);
2014
+ std::vector<int32_t> ids_from_sorted_host(ne_get_rows);
2015
 
2016
+ ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows);
 
2017
 
2018
+ std::vector<int32_t> tokens_per_expert(ne02);
 
2019
 
2020
+ ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted);
2021
+ ggml_cuda_pool_alloc<char> dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted);
 
2022
 
2023
+ std::vector<char> ids_host(ggml_nbytes(ids));
2024
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
2025
+ CUDA_CHECK(cudaStreamSynchronize(stream));
2026
 
2027
+ for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
2028
+ for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
2029
+ for (int64_t iex = 0; iex < n_expert_used; ++iex) {
2030
+ const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
2031
+ assert(expert_to_use >= 0 && expert_to_use < ne02);
2032
+ if (expert_to_use == i02) {
2033
+ ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size();
2034
+ ids_to_sorted_host.push_back(i12*ne11 + iex % ne11);
2035
+ tokens_per_expert[i02]++;
2036
+ break;
2037
  }
 
 
2038
  }
2039
  }
2040
+ }
2041
+ GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows));
2042
 
2043
+ ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end());
2044
 
2045
+ CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
2046
+ CUDA_CHECK(cudaStreamSynchronize(stream));
2047
 
2048
+ const int32_t * ids_to_sorted = ids_buf_dev.ptr + 0*ne_get_rows;
2049
+ const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows;
2050
 
2051
+ get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted,
2052
+ ne10, nb11, nb12, nb13,
2053
+ ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
2054
+ ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);
2055
+ CUDA_CHECK(cudaGetLastError());
2056
 
2057
+ char * src1_data_cur = (char *) src1_sorted.ptr;
2058
+ char * dst_data_cur = (char *) dst_sorted.ptr;
2059
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
2060
+ if (tokens_per_expert[i02] == 0) {
2061
+ continue;
2062
+ }
2063
 
2064
+ ggml_tensor src0_slice = *src0;
2065
+ src0_slice.ne[2] = 1;
2066
+ src0_slice.nb[3] = src0_slice.nb[2];
2067
+ src0_slice.data = (char *) src0->data + i02*nb02;
2068
+
2069
+ ggml_tensor src1_slice;
2070
+ memset(&src1_slice, 0, sizeof(src1_slice));
2071
+ src1_slice.buffer = src1->buffer;
2072
+ src1_slice.type = type_src1_sorted;
2073
+ src1_slice.ne[0] = ne10;
2074
+ src1_slice.ne[1] = tokens_per_expert[i02];
2075
+ src1_slice.ne[2] = 1;
2076
+ src1_slice.ne[3] = 1;
2077
+ src1_slice.nb[0] = ts_src1_sorted;
2078
+ src1_slice.nb[1] = src1_slice.ne[0] * src1_slice.nb[0];
2079
+ src1_slice.nb[2] = src1_slice.ne[1] * src1_slice.nb[1];
2080
+ src1_slice.nb[3] = src1_slice.ne[2] * src1_slice.nb[2];
2081
+ src1_slice.data = src1_data_cur;
2082
+
2083
+ ggml_tensor dst_slice;
2084
+ memset(&dst_slice, 0, sizeof(dst_slice));
2085
+ dst_slice.buffer = dst->buffer;
2086
+ dst_slice.type = type_dst_sorted;
2087
+ dst_slice.ne[0] = ne0;
2088
+ dst_slice.ne[1] = tokens_per_expert[i02];
2089
+ dst_slice.ne[2] = 1;
2090
+ dst_slice.ne[3] = 1;
2091
+ dst_slice.nb[0] = ts_dst_sorted;
2092
+ dst_slice.nb[1] = dst_slice.ne[0] * dst_slice.nb[0];
2093
+ dst_slice.nb[2] = dst_slice.ne[1] * dst_slice.nb[1];
2094
+ dst_slice.nb[3] = dst_slice.ne[2] * dst_slice.nb[2];
2095
+ dst_slice.data = dst_data_cur;
2096
+
2097
+ ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);
2098
+ CUDA_CHECK(cudaGetLastError());
2099
 
2100
+ src1_data_cur += src1_slice.nb[2];
2101
+ dst_data_cur += dst_slice.nb[2];
2102
  }
2103
+
2104
+ get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,
2105
+ ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,
2106
+ ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
2107
+ nb1, nb2, nb3, stream);
2108
  }
2109
 
2110
  static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
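Note: the refactored ggml_cuda_mul_mat_id above works in three phases: gather the used src1 rows into expert-sorted order with get_rows_cuda, run one ggml_cuda_mul_mat per expert on a contiguous slice, then scatter the results back into token order with a second get_rows_cuda. The index bookkeeping is easiest to follow on a toy case; the standalone snippet below (an illustration, not code from this commit) reruns the id-sorting loop for ne02 = 3 experts, ne12 = 2 tokens, n_expert_used = 2 and ne11 = 1.

#include <cstdio>
#include <vector>

int main() {
    const int ne02 = 3, ne12 = 2, n_expert_used = 2, ne11 = 1;
    // ids[i12][iex]: expert selected for token i12 in slot iex.
    const int ids[2][2] = { {0, 2}, {2, 1} };

    std::vector<int> ids_to_sorted;
    std::vector<int> ids_from_sorted(ne12*n_expert_used);
    std::vector<int> tokens_per_expert(ne02, 0);

    for (int i02 = 0; i02 < ne02; ++i02) {         // expert matrices
        for (int i12 = 0; i12 < ne12; ++i12) {     // tokens
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[i12][iex] == i02) {
                    ids_from_sorted[i12*n_expert_used + iex] = (int) ids_to_sorted.size();
                    ids_to_sorted.push_back(i12*ne11 + iex % ne11);
                    tokens_per_expert[i02]++;
                    break;
                }
            }
        }
    }

    // tokens_per_expert = {1, 1, 2}: experts 0 and 1 each own one sorted row, expert 2 owns two.
    // ids_to_sorted     = {0, 1, 0, 1}: src1 row copied into each slot of the sorted buffer.
    // ids_from_sorted   = {0, 2, 3, 1}: sorted row holding the result for each (token, slot).
    for (size_t i = 0; i < ids_from_sorted.size(); ++i) {
        printf("(token %zu, slot %zu) -> sorted row %d\n",
               i/n_expert_used, i%n_expert_used, ids_from_sorted[i]);
    }
    return 0;
}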
ggml/src/ggml-cuda/mmq.cu CHANGED
@@ -1,37 +1,10 @@
1
  #include "mmq.cuh"
 
2
 
3
- void ggml_cuda_op_mul_mat_q(
4
- ggml_backend_cuda_context & ctx,
5
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
6
- const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
7
- const int64_t src1_padded_row_size, cudaStream_t stream) {
8
-
9
- const int64_t ne00 = src0->ne[0];
10
-
11
- const int64_t ne10 = src1->ne[0];
12
- const int64_t ne11 = src1->ne[1];
13
- GGML_ASSERT(ne10 % QK8_1 == 0);
14
 
15
- const int64_t ne0 = dst->ne[0];
16
-
17
- const int64_t row_diff = row_high - row_low;
18
- const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
19
-
20
- int id = ggml_cuda_get_device();
21
- const int cc = ggml_cuda_info().devices[id].cc;
22
-
23
- // the main device has a larger memory buffer to hold the results from all GPUs
24
- // nrows_dst == nrows of the matrix that the kernel writes into
25
- const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
26
-
27
- // The stream-k decomposition is only faster for recent NVIDIA GPUs.
28
- // Also its fixup needs to allocate a temporary buffer in the memory pool.
29
- // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
30
- const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
31
- ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
32
- const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
33
-
34
- switch (src0->type) {
35
  case GGML_TYPE_Q4_0:
36
  mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
37
  break;
@@ -90,10 +63,195 @@ void ggml_cuda_op_mul_mat_q(
90
  GGML_ABORT("fatal error");
91
  break;
92
  }
93
 
94
  GGML_UNUSED(src1);
95
  GGML_UNUSED(dst);
96
  GGML_UNUSED(src1_ddf_i);
 
97
  }
98
 
99
  bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
 
1
  #include "mmq.cuh"
2
+ #include "quantize.cuh"
3
 
4
+ #include <vector>
5
 
6
+ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
7
+ switch (args.type_x) {
8
  case GGML_TYPE_Q4_0:
9
  mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
10
  break;
 
63
  GGML_ABORT("fatal error");
64
  break;
65
  }
66
+ }
67
+
68
+ void ggml_cuda_mul_mat_q(
69
+ ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
70
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
71
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
72
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
73
+
74
+ GGML_TENSOR_BINARY_OP_LOCALS;
75
+
76
+ cudaStream_t stream = ctx.stream();
77
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
78
+
79
+ const size_t ts_src0 = ggml_type_size(src0->type);
80
+ const size_t ts_src1 = ggml_type_size(src1->type);
81
+ const size_t ts_dst = ggml_type_size(dst->type);
82
+
83
+ GGML_ASSERT( nb00 == ts_src0);
84
+ GGML_ASSERT( nb10 == ts_src1);
85
+ GGML_ASSERT( nb0 == ts_dst);
86
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
87
+
88
+ const char * src0_d = (const char *) src0->data;
89
+ const float * src1_d = (const float *) src1->data;
90
+ float * dst_d = (float *) dst->data;
91
+
92
+ const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
93
+
94
+ const int64_t s01 = src0->nb[1] / ts_src0;
95
+ const int64_t s1 = dst->nb[1] / ts_dst;
96
+ const int64_t s02 = src0->nb[2] / ts_src0;
97
+ const int64_t s2 = dst->nb[2] / ts_dst;
98
+ const int64_t s03 = src0->nb[3] / ts_src0;
99
+ const int64_t s3 = dst->nb[3] / ts_dst;
100
+
101
+ const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
102
+
103
+ if (!ids) {
104
+ const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
105
+ get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
106
+ ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
107
+
108
+ {
109
+ const int64_t s11 = src1->nb[1] / ts_src1;
110
+ const int64_t s12 = src1->nb[2] / ts_src1;
111
+ const int64_t s13 = src1->nb[3] / ts_src1;
112
+ quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
113
+ ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
114
+ }
115
+
116
+ const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
117
+ const int64_t s13 = ne12*s12;
118
+
119
+ const mmq_args args = {
120
+ src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
121
+ ne00, ne01, ne1, s01, s1,
122
+ ne02, ne12, s02, s12, s2,
123
+ ne03, ne13, s03, s13, s3,
124
+ use_stream_k};
125
+ ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
126
+ return;
127
+ }
128
+
129
+ GGML_ASSERT(ne13 == 1);
130
+ GGML_ASSERT(nb12 % nb11 == 0);
131
+ GGML_ASSERT(nb2 % nb1 == 0);
132
+
133
+ const int64_t n_expert_used = ids->ne[0];
134
+ const int64_t ne_get_rows = ne12 * n_expert_used;
135
+
136
+ std::vector<char> ids_host(ggml_nbytes(ids));
137
+ std::vector<int32_t> ids_src1_host;
138
+ ids_src1_host.reserve(ne_get_rows);
139
+ std::vector<int32_t> ids_dst_host;
140
+ ids_dst_host.reserve(ne_get_rows);
141
+ std::vector<int32_t> tokens_per_expert_host(ne02);
142
+ std::vector<int32_t> expert_bounds_host(ne02 + 1);
143
+ ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
144
+
145
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
146
+ CUDA_CHECK(cudaStreamSynchronize(stream));
147
+
148
+ for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
149
+ for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
150
+ for (int64_t iex = 0; iex < n_expert_used; ++iex) {
151
+ const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
152
+ assert(expert_to_use >= 0 && expert_to_use < ne02);
153
+ if (expert_to_use == i02) {
154
+ ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
155
+ ids_dst_host.push_back(i12*ne1 + iex);
156
+ tokens_per_expert_host[i02]++;
157
+ break;
158
+ }
159
+ }
160
+ }
161
+ }
162
+
163
+ int32_t cumsum = 0;
164
+ for (int64_t i = 0; i < ne02; ++i) {
165
+ expert_bounds_host[i] = cumsum;
166
+ cumsum += tokens_per_expert_host[i];
167
+ }
168
+ expert_bounds_host[ne02] = cumsum;
169
+
170
+ std::vector<int32_t> ids_buf_host;
171
+ ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
172
+ ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
173
+ ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
174
+ ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
175
+ ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
176
+ CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
177
+ CUDA_CHECK(cudaStreamSynchronize(stream));
178
+
179
+ const int32_t * ids_src1_dev = ids_buf_dev.ptr;
180
+ const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
181
+ const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
182
+
183
+ const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
184
+ get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
185
+ ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
186
+
187
+ const int64_t ne11_flat = ne12*n_expert_used;
188
+ const int64_t ne12_flat = 1;
189
+ const int64_t ne13_flat = 1;
190
+
191
+ {
192
+ const int64_t s11 = src1->nb[1] / ts_src1;
193
+ const int64_t s12 = src1->nb[2] / ts_src1;
194
+ const int64_t s13 = src1->nb[2] / ts_src1;
195
+ quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
196
+ ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
197
+ }
198
+
199
+ const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
200
+ const int64_t s13 = ne12*s12;
201
+
202
+ // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
203
+ const mmq_args args = {
204
+ src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
205
+ ne00, ne01, ne_get_rows, s01, s1,
206
+ ne02, ne02, s02, s12, s2,
207
+ ne03, ne13, s03, s13, s3,
208
+ use_stream_k};
209
+
210
+ ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
211
+ }
212
+
213
+ void ggml_cuda_op_mul_mat_q(
214
+ ggml_backend_cuda_context & ctx,
215
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
216
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
217
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
218
+
219
+ const int64_t ne00 = src0->ne[0];
220
+
221
+ const int64_t ne10 = src1->ne[0];
222
+ const int64_t ne11 = src1->ne[1];
223
+ GGML_ASSERT(ne10 % QK8_1 == 0);
224
+
225
+ const int64_t ne0 = dst->ne[0];
226
+
227
+ const int64_t row_diff = row_high - row_low;
228
+ const int64_t stride01 = ne00 / ggml_blck_size(src0->type);
229
+
230
+ const int id = ggml_cuda_get_device();
231
+ const int cc = ggml_cuda_info().devices[id].cc;
232
+
233
+ // the main device has a larger memory buffer to hold the results from all GPUs
234
+ // nrows_dst == nrows of the matrix that the kernel writes into
235
+ const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
236
+
237
+ // The stream-k decomposition is only faster for recent NVIDIA GPUs.
238
+ // Also its fixup needs to allocate a temporary buffer in the memory pool.
239
+ // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
240
+ const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
241
+ ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
242
+ const mmq_args args = {
243
+ src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
244
+ ne00, row_diff, src1_ncols, stride01, nrows_dst,
245
+ 1, 1, 0, 0, 0,
246
+ 1, 1, 0, 0, 0,
247
+ use_stream_k};
248
+
249
+ ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
250
 
251
  GGML_UNUSED(src1);
252
  GGML_UNUSED(dst);
253
  GGML_UNUSED(src1_ddf_i);
254
+ GGML_UNUSED(src1_padded_row_size);
255
  }
256
 
257
  bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
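Note: instead of launching MMQ once per expert, the batched path in ggml_cuda_mul_mat_q builds expert_bounds, an exclusive prefix sum over tokens_per_expert, so a single dispatch per type can tell which contiguous span of expert-sorted rows each expert matrix owns (the bounds are padded on the device side, as the comment in the diff notes). A minimal standalone version of that cumsum with illustrative values:

#include <cstdio>
#include <vector>

int main() {
    // tokens_per_expert as produced by the id-sorting pass (illustrative values).
    const std::vector<int> tokens_per_expert = {3, 0, 2};
    const int n_expert = (int) tokens_per_expert.size();

    // Same loop as in ggml_cuda_mul_mat_q: exclusive prefix sum plus total.
    std::vector<int> expert_bounds(n_expert + 1);
    int cumsum = 0;
    for (int i = 0; i < n_expert; ++i) {
        expert_bounds[i] = cumsum;
        cumsum += tokens_per_expert[i];
    }
    expert_bounds[n_expert] = cumsum;

    // Expert i owns sorted rows [expert_bounds[i], expert_bounds[i+1]).
    for (int i = 0; i <= n_expert; ++i) {
        printf("%d ", expert_bounds[i]); // prints: 0 3 3 5
    }
    printf("\n");
    return 0;
}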
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -13,9 +13,10 @@ using namespace ggml_cuda_mma;
13
  #define MMQ_ITER_K 256
14
  #define MMQ_NWARPS 8
15
 
16
- typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
17
- typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
18
- typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
 
19
 
20
  enum mmq_q8_1_ds_layout {
21
  MMQ_Q8_1_DS_LAYOUT_D4,
@@ -233,7 +234,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */
233
  // ------------------------------------------------------------
234
 
235
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
236
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
237
 
238
  #ifdef NEW_MMA_AVAILABLE
239
  int * x_qs = (int *) x_tile;
@@ -289,7 +290,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
289
 
290
  template <int mmq_x, int mmq_y, int nwarps>
291
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
292
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
293
 
294
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
295
  const int * x_qs = (const int *) x;
@@ -328,7 +329,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
328
  }
329
 
330
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
331
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
332
 
333
  #ifdef NEW_MMA_AVAILABLE
334
  int * x_qs = (int *) x_tile;
@@ -384,7 +385,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
384
 
385
  template <int mmq_x, int mmq_y, int nwarps>
386
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
387
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
388
 
389
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
390
  const int * x_qs = (const int *) x;
@@ -423,7 +424,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
423
  }
424
 
425
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
426
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
427
 
428
  #ifdef NEW_MMA_AVAILABLE
429
  int * x_qs = (int *) x_tile;
@@ -495,7 +496,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
495
  }
496
 
497
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
498
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
499
 
500
  #ifdef NEW_MMA_AVAILABLE
501
  int * x_qs = (int *) x_tile;
@@ -565,7 +566,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
565
  }
566
 
567
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
568
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
569
 
570
  #ifdef NEW_MMA_AVAILABLE
571
  int * x_qs = (int *) x_tile;
@@ -621,7 +622,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
621
 
622
  template <int mmq_x, int mmq_y, int nwarps>
623
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
624
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
625
 
626
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
627
  const int * x_qs = (const int *) x;
@@ -651,7 +652,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
651
 
652
  template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
653
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
654
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
655
 
656
  typedef tile<16, 8, int> tile_A;
657
  typedef tile< 8, 8, int> tile_B;
@@ -732,7 +733,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
732
 
733
  template <int mmq_x, int mmq_y, int nwarps>
734
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
735
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
736
 
737
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
738
  const int * x_qs = (const int *) x;
@@ -762,7 +763,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
762
 
763
  template <int mmq_x, int mmq_y, int nwarps>
764
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
765
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
766
 
767
  typedef tile<16, 8, int> tile_A;
768
  typedef tile< 8, 8, int> tile_B;
@@ -839,7 +840,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
839
 
840
  template <int mmq_x, int mmq_y, int nwarps>
841
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
842
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
843
 
844
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
845
  const int * x_qs = (const int *) x;
@@ -871,7 +872,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
871
 
872
  template <int mmq_x, int mmq_y, int nwarps>
873
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
874
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
875
  #ifdef NEW_MMA_AVAILABLE
876
 
877
  typedef tile<16, 4, int> tile_A;
@@ -955,7 +956,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
955
  }
956
 
957
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
958
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
959
 
960
  #ifdef NEW_MMA_AVAILABLE
961
  int * x_qs = (int *) x_tile;
@@ -1011,7 +1012,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1011
 
1012
  template <int mmq_x, int mmq_y, int nwarps>
1013
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1014
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1015
 
1016
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1017
  const int * x_qs = (const int *) x;
@@ -1074,7 +1075,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1074
 
1075
  template <int mmq_x, int mmq_y, int nwarps>
1076
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1077
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1078
  #ifdef NEW_MMA_AVAILABLE
1079
 
1080
  typedef tile<16, 4, int> tile_A;
@@ -1201,7 +1202,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1201
  }
1202
 
1203
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1204
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1205
 
1206
  #ifdef NEW_MMA_AVAILABLE
1207
  int * x_qs = (int *) x_tile;
@@ -1298,7 +1299,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1298
 
1299
  template <int mmq_x, int mmq_y, int nwarps>
1300
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1301
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1302
 
1303
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1304
  const int * x_qs = (const int *) x;
@@ -1340,7 +1341,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
1340
  }
1341
 
1342
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1343
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1344
 
1345
  #ifdef NEW_MMA_AVAILABLE
1346
  int * x_qs = (int *) x_tile;
@@ -1437,7 +1438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1437
 
1438
  template <int mmq_x, int mmq_y, int nwarps>
1439
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1440
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1441
 
1442
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1443
  const int * x_qs = (const int *) x;
@@ -1469,7 +1470,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1469
  }
1470
 
1471
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1472
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1473
 
1474
  #ifdef NEW_MMA_AVAILABLE
1475
  int * x_qs = (int *) x_tile;
@@ -1578,7 +1579,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1578
 
1579
  template <int mmq_x, int mmq_y, int nwarps>
1580
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1581
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1582
 
1583
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1584
  const int * x_qs = (const int *) x;
@@ -1610,7 +1611,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1610
  }
1611
 
1612
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1613
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1614
 
1615
  #ifdef NEW_MMA_AVAILABLE
1616
  int * x_qs = (int *) x_tile;
@@ -1693,7 +1694,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1693
 
1694
  template <int mmq_x, int mmq_y, int nwarps>
1695
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1696
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1697
 
1698
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1699
  const int * x_qs = (const int *) x;
@@ -1726,7 +1727,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1726
 
1727
  template <int mmq_x, int mmq_y, int nwarps>
1728
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1729
- const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
1730
  #ifdef NEW_MMA_AVAILABLE
1731
 
1732
  typedef tile<16, 4, int> tile_A;
@@ -1835,7 +1836,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1835
  }
1836
 
1837
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
1838
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1839
 
1840
  #ifdef NEW_MMA_AVAILABLE
1841
  int * x_qs = (int *) x_tile;
@@ -1893,7 +1894,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1893
  }
1894
 
1895
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1896
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1897
 
1898
  #ifdef NEW_MMA_AVAILABLE
1899
  int * x_qs = (int *) x_tile;
@@ -1951,7 +1952,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
1951
  }
1952
 
1953
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1954
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
1955
 
1956
  #ifdef NEW_MMA_AVAILABLE
1957
  int * x_qs = (int *) x_tile;
@@ -2007,7 +2008,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2007
  }
2008
 
2009
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2010
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2011
 
2012
  #ifdef NEW_MMA_AVAILABLE
2013
  int * x_qs = (int *) x_tile;
@@ -2070,7 +2071,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2070
  }
2071
 
2072
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2073
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2074
 
2075
  #ifdef NEW_MMA_AVAILABLE
2076
  int * x_qs = (int *) x_tile;
@@ -2126,7 +2127,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2126
  }
2127
 
2128
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2129
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2130
 
2131
  #ifdef NEW_MMA_AVAILABLE
2132
  int * x_qs = (int *) x_tile;
@@ -2189,7 +2190,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2189
  }
2190
 
2191
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2192
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2193
 
2194
  #ifdef NEW_MMA_AVAILABLE
2195
  int * x_qs = (int *) x_tile;
@@ -2245,7 +2246,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2245
  }
2246
 
2247
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2248
- const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2249
 
2250
  #ifdef NEW_MMA_AVAILABLE
2251
  int * x_qs = (int *) x_tile;
@@ -2306,8 +2307,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
2306
 
2307
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2308
  static __device__ __forceinline__ void mmq_write_back_dp4a(
2309
- const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
2310
-
2311
  #pragma unroll
2312
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2313
  const int j = j0 + threadIdx.y;
@@ -2324,15 +2325,15 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
2324
  continue;
2325
  }
2326
 
2327
- dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
2328
  }
2329
  }
2330
  }
2331
 
2332
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2333
  static __device__ __forceinline__ void mmq_write_back_mma(
2334
- const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
2335
-
2336
  typedef tile<16, 8, int> tile_C;
2337
 
2338
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
@@ -2362,7 +2363,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
2362
  continue;
2363
  }
2364
 
2365
- dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
2366
  }
2367
  }
2368
  }
@@ -2518,17 +2519,18 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
2518
  };
2519
 
2520
  template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
2521
- static __device__ void mul_mat_q_process_tile(
2522
- const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2523
- const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
2524
- const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
 
2525
 
2526
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2527
  constexpr int mmq_y = get_mmq_y_device();
2528
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
2529
 
2530
- extern __shared__ char data_mul_mat_q[];
2531
- int * tile_y = (int *) data_mul_mat_q;
2532
  int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
2533
 
2534
  #ifdef NEW_MMA_AVAILABLE
@@ -2543,16 +2545,11 @@ static __device__ void mul_mat_q_process_tile(
2543
 
2544
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
2545
 
2546
- const int tile_x_max_i = ne01 - it*mmq_y - 1;
2547
- const int tile_y_max_j = ne11 - jt*mmq_x - 1;
2548
-
2549
- const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
2550
-
2551
  for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
2552
- load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
2553
 
2554
  {
2555
- const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
2556
  #pragma unroll
2557
  for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2558
  int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2568,7 +2565,7 @@ static __device__ void mul_mat_q_process_tile(
2568
  __syncthreads();
2569
 
2570
  {
2571
- const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
2572
  #pragma unroll
2573
  for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2574
  int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2585,12 +2582,10 @@ static __device__ void mul_mat_q_process_tile(
2585
  }
2586
 
2587
  if (fixup) {
2588
- write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
2589
  } else {
2590
- write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
2591
  }
2592
-
2593
- GGML_UNUSED(ne00); GGML_UNUSED(ne10);
2594
  }
2595
 
2596
 
@@ -2609,8 +2604,11 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2609
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
2610
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2611
  static __global__ void mul_mat_q(
2612
- const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2613
- const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
2614
 
2615
  // Skip unused template specializations for faster compilation:
2616
  if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -2621,26 +2619,85 @@ static __global__ void mul_mat_q(
2621
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2622
  constexpr int mmq_y = get_mmq_y_device();
2623
 
2624
  // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
2625
  #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
2626
  {
2627
  constexpr bool fixup = false;
2628
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2629
- (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
2630
- blockIdx.x, blockIdx.y, 0, ne00/qk);
2631
  return;
2632
  }
2633
  #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
2634
 
2635
- const int64_t blocks_per_ne00 = ne00 / qk;
2636
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2637
 
2638
- const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
2639
- const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
2640
-
2641
  // kbc == k block continuous, current index in continuous ijk space.
2642
- int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
2643
- int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
2644
 
2645
  kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
2646
  kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
@@ -2649,13 +2706,64 @@ static __global__ void mul_mat_q(
2649
  int kb0_start = kbc % blocks_per_ne00;
2650
  int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
2651
  while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
2652
- const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
2653
- const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
2654
 
2655
  constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
2656
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2657
- (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
2658
- it, jt, kb0_start, kb0_stop);
2659
 
2660
  kbc += blocks_per_ne00;
2661
  kbc -= kbc % blocks_per_ne00;
@@ -2668,55 +2776,106 @@ static __global__ void mul_mat_q(
2668
  return;
2669
  }
2670
 
2671
- const int jt = kbc / (blocks_per_ne00*nty);
2672
- const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
2673
 
2674
  constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
2675
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2676
- (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
2677
- it, jt, kb0_start, kb0_stop);
2678
  }
2679
 
2680
 
2681
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2682
  static __global__ void mul_mat_q_stream_k_fixup(
2683
- float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
2684
-
 
2685
  constexpr int mmq_y = get_mmq_y_device();
2686
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2687
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2688
- const int64_t blocks_per_ne00 = ne00 / qk;
2689
 
2690
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
2691
 
2692
- const int ntx = (ne11 + mmq_x - 1) / mmq_x;
2693
- const int nty = (ne01 + mmq_y - 1) / mmq_y;
2694
-
2695
- bool any_fixup = false;
2696
 
2697
- const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x);
2698
- const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
2699
 
2700
- int64_t kbc_0;
2701
- int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
2702
-
2703
- for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
2704
- kbc_0 = kbc_stop_0;
2705
- kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
2706
 
2707
- const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter;
2708
- const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
2709
 
2710
- // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
2711
- if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
2712
- continue;
2713
- }
 
 
2714
 
2715
- const int jt = kbc_stop / (blocks_per_ne00*nty);
2716
- const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
2717
 
2718
- // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
2719
- if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
2720
  continue;
2721
  }
2722
 
@@ -2733,16 +2892,71 @@ static __global__ void mul_mat_q_stream_k_fixup(
2733
  sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
2734
  }
2735
  }
2736
  }
2737
 
2738
  if (!any_fixup) {
2739
  return;
2740
  }
2741
 
2742
- dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
2743
 
2744
- const int i_max = ne01 - blockIdx.x*mmq_y - 1;
2745
- const int j_max = ne11 - blockIdx.y*mmq_x - 1;
2746
 
2747
  #pragma unroll
2748
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -2760,26 +2974,27 @@ static __global__ void mul_mat_q_stream_k_fixup(
2760
  continue;
2761
  }
2762
 
2763
- dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
2764
  }
2765
  }
2766
  }
2767
 
2768
  struct mmq_args {
2769
- const char * x; const char * y; float * dst;
2770
- int64_t ne00; int64_t ne01; int64_t stride01;
2771
- int64_t ne10; int64_t ne11; int64_t stride11;
2772
- int64_t ne0;
2773
  bool use_stream_k;
2774
  };
2775
 
2776
  template<ggml_type type>
2777
- static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
2778
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
2779
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
2780
- const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
2781
- const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
2782
- return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 
2783
  }
2784
 
2785
  template <ggml_type type, int mmq_x>
@@ -2791,86 +3006,114 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
2791
 
2792
  const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
2793
 
2794
- const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
2795
 
2796
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
2797
- static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
2798
- if (!shmem_limit_raised[id]) {
2799
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
2800
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
2801
- shmem_limit_raised[id] = true;
2802
  }
2803
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
2804
 
2805
- const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
2806
- const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
2807
- const dim3 block_nums_xy_tiling(nty, ntx, 1);
2808
 
2809
  if (!args.use_stream_k) {
2810
- if (args.ne01 % mmq_y == 0) {
2811
  constexpr bool need_check = false;
2812
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
2813
- (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
2814
  } else {
2815
  constexpr bool need_check = true;
2816
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
2817
- (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
2818
  }
2819
  return;
2820
  }
2821
 
2822
- const dim3 block_nums_mmq(nsm, 1, 1);
 
2823
 
2824
  ggml_cuda_pool & pool = ctx.pool(id);
2825
- ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
2826
 
2827
- if (args.ne01 % mmq_y == 0) {
2828
  constexpr bool need_check = false;
2829
 
2830
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
2831
- (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
2832
 
2833
- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
2834
- (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
 
2835
  } else {
2836
  constexpr bool need_check = true;
2837
 
2838
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
2839
- (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
2840
 
2841
- mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
2842
- (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
 
2843
  }
2844
  }
2845
 
2846
  template <ggml_type type>
2847
  void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
2848
- const int id = ggml_cuda_get_device();
2849
- const int cc = ggml_cuda_info().devices[id].cc;
2850
- const int smpbo = ggml_cuda_info().devices[id].smpbo;
2851
 
2852
  const int mmq_x_max = get_mmq_x_max_host(cc);
2853
  const int mmq_y = get_mmq_y_host(cc);
2854
- const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
2855
- const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
2856
 
2857
  int mmq_x_best = 0;
2858
- int nparts_best = INT_MAX;
2859
 
2860
- for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
2861
  const int granularity = mmq_get_granularity_host(mmq_x, cc);
2862
 
2863
- if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
2864
  continue;
2865
  }
2866
 
2867
- const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
2868
- const int nwaves_xy_tiling = ntiles_x*block_num_y;
2869
- const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
2870
 
2871
- if (nparts < nparts_best) {
2872
- mmq_x_best = mmq_x;
2873
- nparts_best = nparts;
2874
  }
2875
  }
2876
 
@@ -2954,6 +3197,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
2954
 
2955
  // -------------------------------------------------------------------------------------------------------------------------
2956
 
2957
  void ggml_cuda_op_mul_mat_q(
2958
  ggml_backend_cuda_context & ctx,
2959
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
 
13
  #define MMQ_ITER_K 256
14
  #define MMQ_NWARPS 8
15
 
16
+ typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
17
+ typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
18
+ typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
19
+ float * __restrict__ dst, const int stride, const int i_max, const int j_max);
20
 
21
  enum mmq_q8_1_ds_layout {
22
  MMQ_Q8_1_DS_LAYOUT_D4,
 
234
  // ------------------------------------------------------------
235
 
236
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
237
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
238
 
239
  #ifdef NEW_MMA_AVAILABLE
240
  int * x_qs = (int *) x_tile;
 
290
 
291
  template <int mmq_x, int mmq_y, int nwarps>
292
  static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
293
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
294
 
295
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
296
  const int * x_qs = (const int *) x;
 
329
  }
330
 
331
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
332
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
333
 
334
  #ifdef NEW_MMA_AVAILABLE
335
  int * x_qs = (int *) x_tile;
 
385
 
386
  template <int mmq_x, int mmq_y, int nwarps>
387
  static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
388
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
389
 
390
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
391
  const int * x_qs = (const int *) x;
 
424
  }
425
 
426
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
427
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
428
 
429
  #ifdef NEW_MMA_AVAILABLE
430
  int * x_qs = (int *) x_tile;
 
496
  }
497
 
498
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
499
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
500
 
501
  #ifdef NEW_MMA_AVAILABLE
502
  int * x_qs = (int *) x_tile;
 
566
  }
567
 
568
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
569
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
570
 
571
  #ifdef NEW_MMA_AVAILABLE
572
  int * x_qs = (int *) x_tile;
 
622
 
623
  template <int mmq_x, int mmq_y, int nwarps>
624
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
625
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
626
 
627
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
628
  const int * x_qs = (const int *) x;
 
652
 
653
  template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
654
  static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
655
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
656
 
657
  typedef tile<16, 8, int> tile_A;
658
  typedef tile< 8, 8, int> tile_B;
 
733
 
734
  template <int mmq_x, int mmq_y, int nwarps>
735
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
736
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
737
 
738
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
739
  const int * x_qs = (const int *) x;
 
763
 
764
  template <int mmq_x, int mmq_y, int nwarps>
765
  static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
766
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
767
 
768
  typedef tile<16, 8, int> tile_A;
769
  typedef tile< 8, 8, int> tile_B;
 
840
 
841
  template <int mmq_x, int mmq_y, int nwarps>
842
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
843
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
844
 
845
  constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
846
  const int * x_qs = (const int *) x;
 
872
 
873
  template <int mmq_x, int mmq_y, int nwarps>
874
  static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
875
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
876
  #ifdef NEW_MMA_AVAILABLE
877
 
878
  typedef tile<16, 4, int> tile_A;
 
956
  }
957
 
958
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
959
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
960
 
961
  #ifdef NEW_MMA_AVAILABLE
962
  int * x_qs = (int *) x_tile;
 
1012
 
1013
  template <int mmq_x, int mmq_y, int nwarps>
1014
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1015
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1016
 
1017
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1018
  const int * x_qs = (const int *) x;
 
1075
 
1076
  template <int mmq_x, int mmq_y, int nwarps>
1077
  static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1078
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1079
  #ifdef NEW_MMA_AVAILABLE
1080
 
1081
  typedef tile<16, 4, int> tile_A;
 
1202
  }
1203
 
1204
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1205
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1206
 
1207
  #ifdef NEW_MMA_AVAILABLE
1208
  int * x_qs = (int *) x_tile;
 
1299
 
1300
  template <int mmq_x, int mmq_y, int nwarps>
1301
  static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1302
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1303
 
1304
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1305
  const int * x_qs = (const int *) x;
 
1341
  }
1342
 
1343
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1344
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1345
 
1346
  #ifdef NEW_MMA_AVAILABLE
1347
  int * x_qs = (int *) x_tile;
 
1438
 
1439
  template <int mmq_x, int mmq_y, int nwarps>
1440
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1441
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1442
 
1443
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1444
  const int * x_qs = (const int *) x;
 
1470
  }
1471
 
1472
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1473
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1474
 
1475
  #ifdef NEW_MMA_AVAILABLE
1476
  int * x_qs = (int *) x_tile;
 
1579
 
1580
  template <int mmq_x, int mmq_y, int nwarps>
1581
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1582
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1583
 
1584
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1585
  const int * x_qs = (const int *) x;
 
1611
  }
1612
 
1613
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
1614
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1615
 
1616
  #ifdef NEW_MMA_AVAILABLE
1617
  int * x_qs = (int *) x_tile;
 
1694
 
1695
  template <int mmq_x, int mmq_y, int nwarps>
1696
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
1697
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1698
 
1699
  constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
1700
  const int * x_qs = (const int *) x;
 
1727
 
1728
  template <int mmq_x, int mmq_y, int nwarps>
1729
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
1730
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1731
  #ifdef NEW_MMA_AVAILABLE
1732
 
1733
  typedef tile<16, 4, int> tile_A;
 
1836
  }
1837
 
1838
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
1839
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1840
 
1841
  #ifdef NEW_MMA_AVAILABLE
1842
  int * x_qs = (int *) x_tile;
 
1894
  }
1895
 
1896
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
1897
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1898
 
1899
  #ifdef NEW_MMA_AVAILABLE
1900
  int * x_qs = (int *) x_tile;
 
1952
  }
1953
 
1954
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
1955
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1956
 
1957
  #ifdef NEW_MMA_AVAILABLE
1958
  int * x_qs = (int *) x_tile;
 
2008
  }
2009
 
2010
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2011
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2012
 
2013
  #ifdef NEW_MMA_AVAILABLE
2014
  int * x_qs = (int *) x_tile;
 
2071
  }
2072
 
2073
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2074
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2075
 
2076
  #ifdef NEW_MMA_AVAILABLE
2077
  int * x_qs = (int *) x_tile;
 
2127
  }
2128
 
2129
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2130
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2131
 
2132
  #ifdef NEW_MMA_AVAILABLE
2133
  int * x_qs = (int *) x_tile;
 
2190
  }
2191
 
2192
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2193
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2194
 
2195
  #ifdef NEW_MMA_AVAILABLE
2196
  int * x_qs = (int *) x_tile;
 
2246
  }
2247
 
2248
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2249
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2250
 
2251
  #ifdef NEW_MMA_AVAILABLE
2252
  int * x_qs = (int *) x_tile;
 
2307
 
2308
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2309
  static __device__ __forceinline__ void mmq_write_back_dp4a(
2310
+ const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
2311
+ const int stride, const int i_max, const int j_max) {
2312
  #pragma unroll
2313
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2314
  const int j = j0 + threadIdx.y;
 
2325
  continue;
2326
  }
2327
 
2328
+ dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
2329
  }
2330
  }
2331
  }
2332
 
2333
  template<int mmq_x, int mmq_y, int nwarps, bool need_check>
2334
  static __device__ __forceinline__ void mmq_write_back_mma(
2335
+ const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
2336
+ const int stride, const int i_max, const int j_max) {
2337
  typedef tile<16, 8, int> tile_C;
2338
 
2339
  constexpr int granularity = mmq_get_granularity_device(mmq_x);
 
2363
  continue;
2364
  }
2365
 
2366
+ dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
2367
  }
2368
  }
2369
  }
 
2519
  };
2520
 
2521
  template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
2522
+ static __device__ __forceinline__ void mul_mat_q_process_tile(
2523
+ const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
2524
+ const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2525
+ const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
2526
+ const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
2527
 
2528
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2529
  constexpr int mmq_y = get_mmq_y_device();
2530
  constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
2531
 
2532
+ extern __shared__ int data_mul_mat_q[];
2533
+ int * tile_y = data_mul_mat_q + mmq_x;
2534
  int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
2535
 
2536
  #ifdef NEW_MMA_AVAILABLE
 
2545
 
2546
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
2547
 
2548
  for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
2549
+ load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
2550
 
2551
  {
2552
+ const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
2553
  #pragma unroll
2554
  for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2555
  int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
 
2565
  __syncthreads();
2566
 
2567
  {
2568
+ const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
2569
  #pragma unroll
2570
  for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
2571
  int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
 
2582
  }
2583
 
2584
  if (fixup) {
2585
+ write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
2586
  } else {
2587
+ write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
2588
  }
 
 
2589
  }
2590
 
2591
 
 
2604
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
2605
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2606
  static __global__ void mul_mat_q(
2607
+ const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
2608
+ const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2609
+ const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
2610
+ const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
2611
+ const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
2612
 
2613
  // Skip unused template specializations for faster compilation:
2614
  if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
 
2619
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2620
  constexpr int mmq_y = get_mmq_y_device();
2621
 
2622
+ const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
2623
+ const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
2624
+
2625
+ // Initialize the ids for writing back data with just the index.
2626
+ // For regular matrix multiplications this is never changed.
2627
+ // For MoE the correct indices are loaded from ids_dst.
2628
+ extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
2629
+ #pragma unroll
2630
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2631
+ const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
2632
+
2633
+ if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
2634
+ break;
2635
+ }
2636
+
2637
+ ids_dst_shared[j] = j;
2638
+ }
2639
+
2640
  // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
2641
  #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
2642
  {
2643
+ const int wt = blockIdx.z / nchannels_y;
2644
+ const int zt = blockIdx.z - wt*nchannels_y;
2645
+ const int jt = blockIdx.y;
2646
+ const int it = blockIdx.x;
2647
+
2648
+ // Defaults for regular matrix multiplication:
2649
+ int col_low = 0;
2650
+ int col_high = ncols_y;
2651
+ int col_diff = ncols_y;
2652
+ int offset_y = wt*stride_sample_y + zt*stride_channel_y;
2653
+ int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
2654
+
2655
+ if (ids_dst) {
2656
+ col_low = expert_bounds[zt + 0];
2657
+ col_high = expert_bounds[zt + 1];
2658
+ col_diff = col_high - col_low;
2659
+
2660
+ offset_y = 0;
2661
+ offset_dst = 0;
2662
+
2663
+ if (jt*mmq_x >= col_diff) {
2664
+ return;
2665
+ }
2666
+
2667
+ #pragma unroll
2668
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2669
+ const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
2670
+
2671
+ if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
2672
+ break;
2673
+ }
2674
+
2675
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
2676
+ }
2677
+ }
2678
+
2679
+ offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
2680
+ offset_dst += it*mmq_y;
2681
+
2682
+ const int tile_x_max_i = nrows_x - it*mmq_y - 1;
2683
+ const int tile_y_max_j = col_diff - jt*mmq_x - 1;
2684
+
2685
+ const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2686
+
2687
  constexpr bool fixup = false;
2688
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2689
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
2690
+ tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
2691
  return;
2692
  }
2693
  #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
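
In the MoE path, expert_bounds[zt]..expert_bounds[zt + 1] is the half-open range of expert-sorted y columns handled as channel zt; column tiles with jt*mmq_x beyond that range are skipped. An illustrative host-side reading of such a bounds array (the values are made up):

#include <cstdio>
#include <vector>

int main() {
    // Hypothetical bounds for 3 experts over 10 routed tokens; expert 1 received none.
    const std::vector<int> expert_bounds = {0, 4, 4, 10};
    const int mmq_x = 8;
    for (size_t e = 0; e + 1 < expert_bounds.size(); ++e) {
        const int col_low  = expert_bounds[e];
        const int col_high = expert_bounds[e + 1];
        const int col_diff = col_high - col_low;
        const int ntiles   = (col_diff + mmq_x - 1) / mmq_x;
        printf("expert %zu: cols [%d, %d) -> %d column tile(s) of width %d\n",
               e, col_low, col_high, ntiles, mmq_x);
    }
    return 0;
}
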
2694
 
2695
+ const int64_t blocks_per_ne00 = ncols_x / qk;
2696
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2697
 
 
 
 
2698
  // kbc == k block continuous, current index in continuous ijk space.
2699
+ int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
2700
+ int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
2701
 
2702
  kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
2703
  kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
 
2706
  int kb0_start = kbc % blocks_per_ne00;
2707
  int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
2708
  while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
2709
+ int tmp = kbc;
2710
+ const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
2711
+ tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
2712
+ const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
2713
+ tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
2714
+ const int zt = tmp / (ntx*blocks_per_ne00);
2715
+ tmp -= zt * (ntx*blocks_per_ne00);
2716
+ const int jt = tmp / blocks_per_ne00;
2717
+
2718
+ // Defaults for regular matrix multiplication:
2719
+ int col_low = 0;
2720
+ int col_high = ncols_y;
2721
+ int col_diff = ncols_y;
2722
+ int offset_y = wt*stride_sample_y + zt*stride_channel_y;
2723
+ int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
2724
+
2725
+ if (ids_dst) {
2726
+ col_low = expert_bounds[zt + 0];
2727
+ col_high = expert_bounds[zt + 1];
2728
+ col_diff = col_high - col_low;
2729
+
2730
+ offset_y = 0;
2731
+ offset_dst = 0;
2732
+
2733
+ if (jt*mmq_x >= col_diff) {
2734
+ kbc += blocks_per_ne00;
2735
+ kbc -= kbc % blocks_per_ne00;
2736
+
2737
+ kb0_start = 0;
2738
+ kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
2739
+
2740
+ continue;
2741
+ }
2742
+
2743
+ #pragma unroll
2744
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2745
+ const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
2746
+
2747
+ if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
2748
+ break;
2749
+ }
2750
+
2751
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
2752
+ }
2753
+ }
2754
+
2755
+ offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
2756
+ offset_dst += it*mmq_y;
2757
+
2758
+ const int tile_x_max_i = nrows_x - it*mmq_y - 1;
2759
+ const int tile_y_max_j = col_diff - jt*mmq_x - 1;
2760
+
2761
+ const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2762
 
2763
  constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
2764
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2765
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
2766
+ tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2767
 
2768
  kbc += blocks_per_ne00;
2769
  kbc -= kbc % blocks_per_ne00;
 
2776
  return;
2777
  }
2778
 
2779
+ int tmp = kbc;
2780
+ const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
2781
+ tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
2782
+ const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
2783
+ tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
2784
+ const int zt = tmp / (ntx*blocks_per_ne00);
2785
+ tmp -= zt * (ntx*blocks_per_ne00);
2786
+ const int jt = tmp / blocks_per_ne00;
2787
+
2788
+ // Defaults for regular matrix multiplication:
2789
+ int col_low = 0;
2790
+ int col_high = ncols_y;
2791
+ int col_diff = ncols_y;
2792
+ int offset_y = wt*stride_sample_y + zt*stride_channel_y;
2793
+ int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
2794
+
2795
+ if (ids_dst) {
2796
+ col_low = expert_bounds[zt + 0];
2797
+ col_high = expert_bounds[zt + 1];
2798
+ col_diff = col_high - col_low;
2799
+
2800
+ offset_y = 0;
2801
+ offset_dst = 0;
2802
+
2803
+ if (jt*mmq_x >= col_diff) {
2804
+ return;
2805
+ }
2806
+
2807
+ // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
2808
+ #pragma unroll
2809
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2810
+ const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
2811
+
2812
+ if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
2813
+ break;
2814
+ }
2815
+
2816
+ ids_dst_shared[j] = j;
2817
+ }
2818
+ }
2819
+
2820
+ offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
2821
+ offset_dst += it*mmq_y;
2822
+
2823
+ const int tile_x_max_i = nrows_x - it*mmq_y - 1;
2824
+ const int tile_y_max_j = col_diff - jt*mmq_x - 1;
2825
+
2826
+ const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
2827
 
2828
  constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
2829
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2830
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
2831
+ tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
2832
  }
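
The stream-k scheduling above flattens (tile_y, sample, channel, tile_x, k block) into one contiguous index space, splits it evenly over gridDim.x CUDA blocks, and has each block recover its coordinates by repeated division. A host-side C++ sketch of that decomposition under made-up sizes (helper names are illustrative):

#include <cstdint>
#include <cstdio>

struct tile_coords { int it, wt, zt, jt, kb0; };

// Recover (tile_y, sample, channel, tile_x, k block) from a flattened stream-k index,
// using the same order of significance as mul_mat_q.
static tile_coords decompose(int64_t kbc, int64_t nsamples, int64_t nchannels, int64_t ntx, int64_t blocks_per_ne00) {
    tile_coords c;
    c.it  = (int)(kbc / (nsamples*nchannels*ntx*blocks_per_ne00)); kbc -= c.it*nsamples*nchannels*ntx*blocks_per_ne00;
    c.wt  = (int)(kbc / (nchannels*ntx*blocks_per_ne00));          kbc -= c.wt*nchannels*ntx*blocks_per_ne00;
    c.zt  = (int)(kbc / (ntx*blocks_per_ne00));                    kbc -= c.zt*ntx*blocks_per_ne00;
    c.jt  = (int)(kbc / blocks_per_ne00);
    c.kb0 = (int)(kbc % blocks_per_ne00);
    return c;
}

int main() {
    // Made-up sizes: 2 samples, 3 channels, 4 x tiles, 8 k blocks per tile row.
    const tile_coords c = decompose(123, 2, 3, 4, 8);
    printf("it=%d wt=%d zt=%d jt=%d kb0=%d\n", c.it, c.wt, c.zt, c.jt, c.kb0); // it=0 wt=1 zt=0 jt=3 kb0=3
    return 0;
}
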
2833
 
2834
 
2835
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2836
  static __global__ void mul_mat_q_stream_k_fixup(
2837
+ const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
2838
+ const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
2839
+ const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
2840
  constexpr int mmq_y = get_mmq_y_device();
2841
  constexpr int qk = ggml_cuda_type_traits<type>::qk;
2842
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
2843
+ const int64_t blocks_per_ne00 = ncols_x / qk;
2844
 
2845
  float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
2846
 
2847
+ const int ntx = (ncols_y + mmq_x - 1) / mmq_x;
2848
+ const int nty = (nrows_x + mmq_y - 1) / mmq_y;
 
 
2849
 
2850
+ const int bidx0 = blockIdx.x;
 
2851
 
2852
+ // kbc == k block continuous, current index in continuous ijk space.
2853
+ int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
2854
+ int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
2855
 
2856
+ kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
2857
+ kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
2858
 
2859
+ const bool did_not_have_any_data = kbc0 == kbc0_stop;
2860
+ const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
2861
+ const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
2862
+ if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
2863
+ return;
2864
+ }
2865
 
2866
+ bool any_fixup = false;
 
2867
 
2868
+ // Iterate over previous blocks and sum up partial sums written to fixup buffer.
2869
+ // All CUDA blocks that get here must have a previous block that needs a fixup.
2870
+ int64_t bidx = bidx0 - 1;
2871
+ int64_t kbc_stop = kbc0;
2872
+ while(true) {
2873
+ int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
2874
+ kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
2875
+
2876
+ if (kbc == kbc_stop) { // Did not have any data.
2877
+ bidx--;
2878
+ kbc_stop = kbc;
2879
  continue;
2880
  }
2881
 
 
2892
  sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
2893
  }
2894
  }
2895
+
2896
+ // If this block started in a previous tile we are done and don't need to combine additional partial results.
2897
+ if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
2898
+ break;
2899
+ }
2900
+ bidx--;
2901
+ kbc_stop = kbc;
2902
  }
2903
 
2904
  if (!any_fixup) {
2905
  return;
2906
  }
2907
 
2908
+ int tmp = kbc0;
2909
+ const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
2910
+ tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
2911
+ const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
2912
+ tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
2913
+ const int zt = tmp / (ntx*blocks_per_ne00);
2914
+ tmp -= zt * (ntx*blocks_per_ne00);
2915
+ const int jt = tmp / blocks_per_ne00;
2916
 
2917
+ if (!ids_dst) {
2918
+ const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
2919
+ dst += offset_dst;
2920
+
2921
+ const int i_max = nrows_x - it*mmq_y - 1;
2922
+ const int j_max = ncols_y - jt*mmq_x - 1;
2923
+
2924
+ #pragma unroll
2925
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2926
+ const int j = j0 + threadIdx.y;
2927
+
2928
+ if (j > j_max) {
2929
+ return;
2930
+ }
2931
+
2932
+ #pragma unroll
2933
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
2934
+ const int i = i0 + threadIdx.x;
2935
+
2936
+ if (need_check && i > i_max) {
2937
+ continue;
2938
+ }
2939
+
2940
+ dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
2941
+ }
2942
+ }
2943
+ return;
2944
+ }
2945
+
2946
+ __shared__ int ids_dst_shared[mmq_x];
2947
+ const int col_low = expert_bounds[zt + 0];
2948
+ const int col_high = expert_bounds[zt + 1];
2949
+ const int col_diff = col_high - col_low;
2950
+
2951
+ for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
2952
+ ids_dst_shared[j] = ids_dst[col_low + j];
2953
+ }
2954
+
2955
+ const int offset_dst = it*mmq_y;
2956
+ dst += offset_dst;
2957
+
2958
+ const int i_max = nrows_x - it*mmq_y - 1;
2959
+ const int j_max = col_diff - jt*mmq_x - 1;
2960
 
2961
  #pragma unroll
2962
  for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 
2974
  continue;
2975
  }
2976
 
2977
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
2978
  }
2979
  }
2980
  }
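
The early exits at the top of this kernel reduce to one rule: only the stream-k block that finished an output tile whose first k blocks were computed by earlier blocks has to perform a fixup. A host-side mirror of that check under made-up sizes (not code from this commit):

#include <cstdint>
#include <cstdio>

// Returns true if stream-k block bidx0 must accumulate partial sums for the tile it started in,
// mirroring the did_not_have_any_data / wrote_beginning_of_tile / did_not_write_last exits.
static bool needs_fixup(int bidx0, int nblocks, int64_t total_k_blocks, int64_t blocks_per_ne00, int64_t blocks_per_iter) {
    int64_t kbc0      = (int64_t) bidx0     *total_k_blocks / nblocks;
    int64_t kbc0_stop = (int64_t)(bidx0 + 1)*total_k_blocks / nblocks;

    kbc0      -= (kbc0      % blocks_per_ne00) % blocks_per_iter;
    kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;

    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
    const bool did_not_write_last      = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
    return !(did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last);
}

int main() {
    // Made-up geometry: 10 output tiles, 8 k blocks per tile row, 4 blocks per iteration, 6 stream-k blocks.
    for (int b = 0; b < 6; ++b) {
        printf("block %d needs fixup: %d\n", b, (int) needs_fixup(b, 6, 10*8, 8, 4));
    }
    return 0;
}
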
2981
 
2982
  struct mmq_args {
2983
+ const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
2984
+ int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
2985
+ int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
2986
+ int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
2987
  bool use_stream_k;
2988
  };
2989
 
2990
  template<ggml_type type>
2991
+ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
2992
  const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
2993
  const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
2994
+ const size_t nbs_ids = mmq_x*sizeof(int);
2995
+ const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
2996
+ const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
2997
+ return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
2998
  }
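
Dynamic shared memory is now carved into three regions: mmq_x ints for ids_dst_shared, the quantized y tile (padded so the cooperative copy loop runs in full warps), and the x tile; mmq_get_nbytes_shared sums exactly those contributions. A host-side sketch of the same arithmetic with placeholder tile sizes (the real values come from mmq_get_dp4a_tile_x_sizes or mmq_get_mma_tile_x_k and sizeof(block_q8_1_mmq)):

#include <cstddef>
#include <cstdio>

// Round x up to a multiple of n (what GGML_PAD does for the sizes used here).
static size_t pad_to(size_t x, size_t n) { return (x + n - 1) / n * n; }

int main() {
    const int mmq_x = 8, nwarps = 8, warp_size = 32;
    const size_t nbs_ids = mmq_x*sizeof(int); // ids_dst_shared
    const size_t nbs_x   = 4096;              // placeholder for the quantized x tile
    const size_t nbs_y   = mmq_x*144;         // placeholder for mmq_x*sizeof(block_q8_1_mmq)
    const size_t total   = nbs_ids + nbs_x + pad_to(nbs_y, nwarps*warp_size*sizeof(int));
    printf("dynamic shared memory: %zu bytes per CUDA block\n", total);
    return 0;
}
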
2999
 
3000
  template <ggml_type type, int mmq_x>
 
3006
 
3007
  const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
3008
 
3009
+ const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
3010
 
3011
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3012
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
3013
+ if (!shared_memory_limit_raised[id]) {
3014
+ CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3015
+ CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3016
+ shared_memory_limit_raised[id] = true;
3017
  }
3018
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
3019
 
3020
+ const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3021
+ const int ntx = (args.ncols_y + mmq_x - 1) / mmq_x;
3022
+ const int ntzw = args.nchannels_y * args.nsamples_y;
3023
+ const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
3024
+
3025
+ GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
3026
+ GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
3027
+ const int channel_ratio = args.nchannels_y / args.nchannels_x;
3028
+ const int sample_ratio = args.nsamples_y / args.nsamples_x;
3029
 
3030
  if (!args.use_stream_k) {
3031
+ if (args.nrows_x % mmq_y == 0) {
3032
  constexpr bool need_check = false;
3033
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3034
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3035
+ args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3036
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3037
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3038
  } else {
3039
  constexpr bool need_check = true;
3040
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3041
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3042
+ args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3043
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3044
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3045
  }
3046
  return;
3047
  }
3048
 
3049
+ const dim3 block_nums_stream_k(nsm, 1, 1);
3050
+ const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
3051
 
3052
  ggml_cuda_pool & pool = ctx.pool(id);
3053
+ ggml_cuda_pool_alloc<float> tmp_fixup(pool);
3054
+ if (fixup_needed) {
3055
+ tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
3056
+ }
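
The temporary buffer and the fixup kernel are only needed when the total tile count does not divide evenly across the stream-k blocks; an even split never cuts an output tile in the middle of its k range. A tiny worked example with made-up counts:

#include <cstdio>

int main() {
    // Made-up tile counts and SM count; mirrors fixup_needed in launch_mul_mat_q.
    const int ntx = 7, nty = 12, ntzw = 4, nsm = 108;
    const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
    printf("%d tiles over %d SMs -> fixup %s\n", ntx*nty*ntzw, nsm, fixup_needed ? "needed" : "not needed");
    return 0;
}
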
3057
 
3058
+ if (args.nrows_x % mmq_y == 0) {
3059
  constexpr bool need_check = false;
3060
 
3061
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3062
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3063
+ args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3064
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3065
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3066
+
3067
+ if (!fixup_needed) {
3068
+ return;
3069
+ }
3070
 
3071
+ mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3072
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
3073
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
3074
  } else {
3075
  constexpr bool need_check = true;
3076
 
3077
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3078
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3079
+ args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
3080
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3081
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
3082
+
3083
+ if (!fixup_needed) {
3084
+ return;
3085
+ }
3086
 
3087
+ mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3088
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
3089
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
3090
  }
3091
  }
3092
 
3093
  template <ggml_type type>
3094
  void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3095
+ const int id = ggml_cuda_get_device();
3096
+ const int cc = ggml_cuda_info().devices[id].cc;
3097
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
3098
 
3099
  const int mmq_x_max = get_mmq_x_max_host(cc);
3100
  const int mmq_y = get_mmq_y_host(cc);
 
 
3101
 
3102
  int mmq_x_best = 0;
3103
+ int ntiles_x_best = INT_MAX;
3104
 
3105
+ for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
3106
  const int granularity = mmq_get_granularity_host(mmq_x, cc);
3107
 
3108
+ if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
3109
  continue;
3110
  }
3111
 
3112
+ const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
 
 
3113
 
3114
+ if (ntiles_x < ntiles_x_best) {
3115
+ mmq_x_best = mmq_x;
3116
+ ntiles_x_best = ntiles_x;
3117
  }
3118
  }
3119
 
 
3197
 
3198
  // -------------------------------------------------------------------------------------------------------------------------
3199
 
3200
+ void ggml_cuda_mul_mat_q(
3201
+ ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
3202
+
3203
  void ggml_cuda_op_mul_mat_q(
3204
  ggml_backend_cuda_context & ctx,
3205
  const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
ggml/src/ggml-cuda/mmvq.cu CHANGED
@@ -158,7 +158,7 @@ static __global__ void mul_mat_vec_q(
158
  const int blocks_per_row_x = ncols_x / qk;
159
  constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
160
 
161
- // The MUL_MAT_ID code path with ids != nullptr is only implemetned for ncols_dst == 1.
162
  const int channel_dst = blockIdx.y;
163
  const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
164
  const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
@@ -507,7 +507,7 @@ void ggml_cuda_mul_mat_vec_q(
507
  GGML_ASSERT( nb0 == ts_dst);
508
  GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
509
 
510
- GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
511
 
512
  const float * src1_d = (const float *) src1->data;
513
  const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
@@ -519,7 +519,7 @@ void ggml_cuda_mul_mat_vec_q(
519
  const int64_t s11 = src1->nb[1] / ts_src1;
520
  const int64_t s12 = src1->nb[2] / ts_src1;
521
  const int64_t s13 = src1->nb[3] / ts_src1;
522
- quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
523
  }
524
 
525
  const int64_t s01 = src0->nb[1] / ts_src0;
 
158
  const int blocks_per_row_x = ncols_x / qk;
159
  constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
160
 
161
+ // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
162
  const int channel_dst = blockIdx.y;
163
  const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
164
  const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
 
507
  GGML_ASSERT( nb0 == ts_dst);
508
  GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
509
 
510
+ GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
511
 
512
  const float * src1_d = (const float *) src1->data;
513
  const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
 
519
  const int64_t s11 = src1->nb[1] / ts_src1;
520
  const int64_t s12 = src1->nb[2] / ts_src1;
521
  const int64_t s13 = src1->nb[3] / ts_src1;
522
+ quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
523
  }
524
 
525
  const int64_t s01 = src0->nb[1] / ts_src0;
ggml/src/ggml-cuda/quantize.cu CHANGED
@@ -49,29 +49,38 @@ static __global__ void quantize_q8_1(
49
 
50
  template <mmq_q8_1_ds_layout ds_layout>
51
  static __global__ void quantize_mmq_q8_1(
52
- const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
 
 
53
 
54
  constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
55
  constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
56
 
57
- const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
58
 
59
- if (ix0 >= kx0_padded) {
60
  return;
61
  }
62
 
63
- const float4 * x4 = (const float4 *) x;
 
 
64
 
65
- const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
66
 
67
  block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
68
 
69
  const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
70
- const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
71
- const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
72
 
73
  // Load 4 floats per thread and calculate max. abs. value between them:
74
- const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
75
  float amax = fabsf(xi.x);
76
  amax = fmaxf(amax, fabsf(xi.y));
77
  amax = fmaxf(amax, fabsf(xi.z));
@@ -87,7 +96,7 @@ static __global__ void quantize_mmq_q8_1(
87
  if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
88
  sum = xi.x + xi.y + xi.z + xi.w;
89
 
90
- // Exchange calculate sum across vals_per_sum/4 threads.
91
  #pragma unroll
92
  for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
93
  sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
@@ -137,9 +146,10 @@ static __global__ void quantize_mmq_q8_1(
137
  }
138
 
139
  void quantize_row_q8_1_cuda(
140
- const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
141
- const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
142
-
 
143
  GGML_ASSERT(ne0 % QK8_1 == 0);
144
 
145
  const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
@@ -150,9 +160,9 @@ void quantize_row_q8_1_cuda(
150
  }
151
 
152
  void quantize_mmq_q8_1_cuda(
153
- const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
154
- const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
155
-
156
  GGML_ASSERT(ne0 % (4*QK8_1) == 0);
157
 
158
  const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
@@ -161,21 +171,18 @@ void quantize_mmq_q8_1_cuda(
161
  switch (mmq_get_q8_1_ds_layout(type_src0)) {
162
  case MMQ_Q8_1_DS_LAYOUT_D4:
163
  quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
164
- <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
165
  break;
166
  case MMQ_Q8_1_DS_LAYOUT_DS4:
167
  quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
168
- <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
169
  break;
170
  case MMQ_Q8_1_DS_LAYOUT_D2S6:
171
  quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
172
- <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
173
  break;
174
  default:
175
  GGML_ABORT("fatal error");
176
  break;
177
  }
178
- GGML_UNUSED(s01);
179
- GGML_UNUSED(s02);
180
- GGML_UNUSED(s03);
181
  }
 
49
 
50
  template <mmq_q8_1_ds_layout ds_layout>
51
  static __global__ void quantize_mmq_q8_1(
52
+ const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
53
+ const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
54
+ const int64_t ne0, const int ne1, const int ne2) {
55
 
56
  constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
57
  constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
58
 
59
+ const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
60
 
61
+ if (i0 >= ne0) {
62
  return;
63
  }
64
 
65
+ const int64_t i1 = blockIdx.y;
66
+ const int64_t i2 = blockIdx.z % ne2;
67
+ const int64_t i3 = blockIdx.z / ne2;
68
 
69
+ const int64_t i00 = i0;
70
+ const int64_t i01 = ids ? ids[i1] : i1;
71
+ const int64_t i02 = i2;
72
+ const int64_t i03 = i3;
73
+
74
+ const float4 * x4 = (const float4 *) x;
75
 
76
  block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
77
 
78
  const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
79
+ const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
80
+ const int64_t iqs = i0 % (4*QK8_1); // quant index in block
81
 
82
  // Load 4 floats per thread and calculate max. abs. value between them:
83
+ const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
84
  float amax = fabsf(xi.x);
85
  amax = fmaxf(amax, fabsf(xi.y));
86
  amax = fmaxf(amax, fabsf(xi.z));
 
96
  if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
97
  sum = xi.x + xi.y + xi.z + xi.w;
98
 
99
+ // Calculate sums across vals_per_sum/4 threads.
100
  #pragma unroll
101
  for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
102
  sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
 
146
  }
147
 
148
  void quantize_row_q8_1_cuda(
149
+ const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
150
+ const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
151
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
152
+ GGML_ASSERT(!ids);
153
  GGML_ASSERT(ne0 % QK8_1 == 0);
154
 
155
  const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
 
160
  }
161
 
162
  void quantize_mmq_q8_1_cuda(
163
+ const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
164
+ const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
165
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
166
  GGML_ASSERT(ne0 % (4*QK8_1) == 0);
167
 
168
  const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
 
171
  switch (mmq_get_q8_1_ds_layout(type_src0)) {
172
  case MMQ_Q8_1_DS_LAYOUT_D4:
173
  quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
174
+ <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
175
  break;
176
  case MMQ_Q8_1_DS_LAYOUT_DS4:
177
  quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
178
+ <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
179
  break;
180
  case MMQ_Q8_1_DS_LAYOUT_D2S6:
181
  quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
182
+ <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
183
  break;
184
  default:
185
  GGML_ABORT("fatal error");
186
  break;
187
  }
 
 
 
188
  }
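For reference, a minimal host-side sketch (illustrative only, not part of the commit) of the indexing the rewritten quantize_mmq_q8_1 uses: a thread's logical coordinates are mapped to a source element through explicit element strides, with ids providing an optional row indirection for MoE expert selection. The helper name src_offset is made up; in the kernel itself the resulting offset is divided by 4 because the source is read through a float4 pointer.

// Illustrative reference of the gather indexing in quantize_mmq_q8_1 (sketch only).
static int64_t src_offset(const int32_t * ids, int64_t i0, int64_t i1, int64_t i2, int64_t i3,
                          int64_t s01, int64_t s02, int64_t s03) {
    const int64_t i00 = i0;                  // element index within the row
    const int64_t i01 = ids ? ids[i1] : i1;  // row indirection when ids is given, else identity
    const int64_t i02 = i2;                  // channel index
    const int64_t i03 = i3;                  // sample index
    return i03*s03 + i02*s02 + i01*s01 + i00;
}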
ggml/src/ggml-cuda/quantize.cuh CHANGED
@@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk
  static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
  typedef void (*quantize_cuda_t)(
- const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
- const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+ const float * x, const int32_t * ids, void * vy,
+ ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
  void quantize_row_q8_1_cuda(
- const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
- const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+ const float * x, const int32_t * ids, void * vy,
+ ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
  void quantize_mmq_q8_1_cuda(
- const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
- const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
+ const float * x, const int32_t * ids, void * vy,
+ ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
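With the header change, both entry points match the extended quantize_cuda_t signature. A hedged usage sketch follows; the variable names (use_mmq, src1_d, ids_d, src1_q8_1_d, src0_type, the strides) are assumptions for illustration, not taken from this diff.

// Sketch: selecting a quantization entry point through the quantize_cuda_t pointer.
// quantize_row_q8_1_cuda asserts that ids is nullptr (see quantize.cu above), so the
// ids-based row gather is only usable on the MMQ path.
quantize_cuda_t quantize_src1 = use_mmq ? quantize_mmq_q8_1_cuda : quantize_row_q8_1_cuda;
quantize_src1(src1_d, use_mmq ? ids_d : nullptr, src1_q8_1_d, src0_type,
              ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);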