Commit a867083 · Parent: fd2d86d
CUDA: batched+noncont MMQ, refactor bs>1 MoE code (llama/13199)

Files changed:
- ggml/src/ggml-cuda/getrows.cu   +105 -66
- ggml/src/ggml-cuda/getrows.cuh    +7 -0
- ggml/src/ggml-cuda/ggml-cuda.cu +112 -148
- ggml/src/ggml-cuda/mmq.cu       +189 -31
- ggml/src/ggml-cuda/mmq.cuh      +400 -154
- ggml/src/ggml-cuda/mmvq.cu        +3 -3
- ggml/src/ggml-cuda/quantize.cu   +28 -21
- ggml/src/ggml-cuda/quantize.cuh   +9 -6
ggml/src/ggml-cuda/getrows.cu

@@ -33,8 +33,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst_row[iybs + iqs + 0] = v.x;
-    dst_row[iybs + iqs + y_offset] = v.y;
+    dst_row[iybs + iqs + 0] = float(v.x);
+    dst_row[iybs + iqs + y_offset] = float(v.y);
 }
 
 template<typename src0_t, typename dst_t>
@@ -60,7 +60,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst_row[i00] = src0_row[i00];
+    dst_row[i00] = float(src0_row[i00]);
 }
 
 template<typename grad_t, typename dst_t>
@@ -86,122 +86,161 @@ static __global__ void k_get_rows_back_float(
     dst[dst_row*ncols + col] = sum;
 }
 
-template<int qk, int qr, dequantize_kernel_t dq>
+template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
+static void get_rows_cuda_q(
+        const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
 
     // strides in elements
+    // const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);
 
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    // const size_t s13 = nb13 / sizeof(int32_t);
 
     GGML_ASSERT(ne00 % 2 == 0);
 
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }
 
-template<typename src0_t>
+template<typename src0_t, typename dst_t>
 static void get_rows_cuda_float(
+        const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
-    GGML_ASSERT(ne13 == 1);
-
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
 
     // strides in elements
+    // const size_t s0 = nb0 / sizeof(dst_t);
+    const size_t s1 = nb1 / sizeof(dst_t);
+    const size_t s2 = nb2 / sizeof(dst_t);
+    const size_t s3 = nb3 / sizeof(dst_t);
 
+    const size_t s10 = nb10 / sizeof(int32_t);
+    const size_t s11 = nb11 / sizeof(int32_t);
+    const size_t s12 = nb12 / sizeof(int32_t);
+    // const size_t s13 = nb13 / sizeof(int32_t);
 
     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+        src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
         /*ne10, ne11,*/ ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
-
-    GGML_UNUSED(dst);
 }
 
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    switch (src0->type) {
+template <typename dst_t>
+static void ggml_cuda_get_rows_switch_src0_type(
+        const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+    switch (src0_type) {
         case GGML_TYPE_F16:
+            get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_F32:
+            get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_0:
+            get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
        case GGML_TYPE_Q4_1:
+            get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_0:
+            get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_1:
+            get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q8_0:
+            get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
         default:
             // TODO: k-quants
+            GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
             break;
     }
 }
 
+void get_rows_cuda(
+        const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+        int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+        int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+        size_t nb1, size_t nb2, size_t nb3,
+        cudaStream_t stream) {
+    switch (dst_type) {
+        case GGML_TYPE_F32:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_F16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        case GGML_TYPE_BF16:
+            ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
+            break;
+    }
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ne13 == 1);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
+        ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+}
+
 void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
     const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
ggml/src/ggml-cuda/getrows.cuh

@@ -3,6 +3,13 @@
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
 #define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
 
+void get_rows_cuda(
+    const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+    int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+    int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+    size_t nb1, size_t nb2, size_t nb3,
+    cudaStream_t stream);
+
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -1551,7 +1551,7 @@ static void ggml_cuda_op_mul_mat(
 
         if (src1_on_device && src1_is_contiguous) {
             quantize_src1(
-                dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
+                dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,
                 nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
                 src1_padded_col_size, ne11, ne12, ne13, stream);
             CUDA_CHECK(cudaGetLastError());
@@ -1649,7 +1649,7 @@ static void ggml_cuda_op_mul_mat(
 
             if (quantize_src1 && !src1_is_contiguous) {
                 quantize_src1(
-                    src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+                    src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
                     src1_padded_col_size, src1_ncols, 1, 1, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
@@ -1949,6 +1949,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_q) {
+        ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
         !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
@@ -1964,183 +1966,145 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
-struct mmid_row_mapping {
-    int32_t i1;
-    int32_t i2;
-};
-
-static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
-        int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
-        const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-        int64_t ne11, int64_t ne10,
-        size_t nb11, size_t nb12) {
-    int32_t iid1 = blockIdx.x;
-    int32_t id = blockIdx.y;
-
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
-
-    if (row_id_i != i02) {
-        return;
-    }
-
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
-
-    __shared__ int src1_row;
-    if (threadIdx.x == 0) {
-        src1_row = atomicAdd(cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    __syncthreads();
-
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
-
-    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
-        src1_row_contiguous[i] = src1_row_original[i];
-    }
-}
-
-static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
-        const mmid_row_mapping * __restrict__ row_mapping,
-        int64_t ne0,
-        size_t nb1, size_t nb2) {
-    int32_t i = blockIdx.x;
-
-    const int32_t i1 = row_mapping[i].i1;
-    const int32_t i2 = row_mapping[i].i2;
-
-    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
-    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
-
-    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
-        dst_row_original[j] = dst_row_contiguous[j];
-    }
-}
-
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];
 
-    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
-        if (ggml_is_quantized(src0->type)) {
-            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
-        } else {
-            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
-        }
-        return;
-    }
-
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
 
-    const int64_t n_ids = ids->ne[0];
-    char * src1_original = (char *) src1->data;
-    char * dst_original = (char *) dst->data;
-
-    ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-    ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-    CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
-
-    {
-        dim3 block_dims(std::min((unsigned int)ne10, 768u));
-        dim3 grid_dims(ids->ne[1], n_ids);
-        k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-            src1_original, src1_contiguous.get(),
-            dev_cur_src1_row.get(), dev_row_mapping.get(),
-            ids_dev, i02, ids->nb[1], ids->nb[0],
-            ne11, ne10,
-            nb11, nb12);
-        CUDA_CHECK(cudaGetLastError());
-    }
-
-    {
-        dim3 grid_dims(num_src1_rows);
-        k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-            dst_original, dst_contiguous.get(),
-            dev_row_mapping.get(),
-            ne0,
-            nb1, nb2);
-        CUDA_CHECK(cudaGetLastError());
-    }
-
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (ne2 == 1) {
+            if (ggml_is_quantized(src0->type)) {
+                ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+            } else {
+                ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+            }
+            return;
+        }
+
+        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
+            ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
+            return;
+        }
+    }
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(nb12 % nb11 == 0);
+    GGML_ASSERT(nb2 % nb1 == 0);
+
+    const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc))
+        || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type;
+    const ggml_type type_dst_sorted = GGML_TYPE_F32;
+    const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted);
+    const size_t ts_dst_sorted = ggml_type_size(type_dst_sorted);
+
+    const int64_t n_expert_used = ids->ne[0];
+    const int64_t ne_get_rows = ne12 * n_expert_used;
+
+    std::vector<int32_t> ids_to_sorted_host;
+    ids_to_sorted_host.reserve(2*ne_get_rows);
+    std::vector<int32_t> ids_from_sorted_host(ne_get_rows);
+
+    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows);
+
+    std::vector<int32_t> tokens_per_expert(ne02);
+
+    ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted);
+    ggml_cuda_pool_alloc<char> dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted);
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+                assert(expert_to_use >= 0 && expert_to_use < ne02);
+                if (expert_to_use == i02) {
+                    ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size();
+                    ids_to_sorted_host.push_back(i12*ne11 + iex % ne11);
+                    tokens_per_expert[i02]++;
+                    break;
+                }
             }
         }
     }
+    GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows));
+
+    ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end());
+
+    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    const int32_t * ids_to_sorted   = ids_buf_dev.ptr + 0*ne_get_rows;
+    const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows;
+
+    get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted,
+        ne10, nb11, nb12, nb13,
+        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+        ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    char * src1_data_cur = (char *) src1_sorted.ptr;
+    char * dst_data_cur  = (char *)  dst_sorted.ptr;
+    for (int64_t i02 = 0; i02 < ne02; ++i02) {
+        if (tokens_per_expert[i02] == 0) {
+            continue;
+        }
+
+        ggml_tensor src0_slice = *src0;
+        src0_slice.ne[2] = 1;
+        src0_slice.nb[3] = src0_slice.nb[2];
+        src0_slice.data  = (char *) src0->data + i02*nb02;
+
+        ggml_tensor src1_slice;
+        memset(&src1_slice, 0, sizeof(src1_slice));
+        src1_slice.buffer = src1->buffer;
+        src1_slice.type   = type_src1_sorted;
+        src1_slice.ne[0]  = ne10;
+        src1_slice.ne[1]  = tokens_per_expert[i02];
+        src1_slice.ne[2]  = 1;
+        src1_slice.ne[3]  = 1;
+        src1_slice.nb[0]  = ts_src1_sorted;
+        src1_slice.nb[1]  = src1_slice.ne[0] * src1_slice.nb[0];
+        src1_slice.nb[2]  = src1_slice.ne[1] * src1_slice.nb[1];
+        src1_slice.nb[3]  = src1_slice.ne[2] * src1_slice.nb[2];
+        src1_slice.data   = src1_data_cur;
+
+        ggml_tensor dst_slice;
+        memset(&dst_slice, 0, sizeof(dst_slice));
+        dst_slice.buffer = dst->buffer;
+        dst_slice.type   = type_dst_sorted;
+        dst_slice.ne[0]  = ne0;
+        dst_slice.ne[1]  = tokens_per_expert[i02];
+        dst_slice.ne[2]  = 1;
+        dst_slice.ne[3]  = 1;
+        dst_slice.nb[0]  = ts_dst_sorted;
+        dst_slice.nb[1]  = dst_slice.ne[0] * dst_slice.nb[0];
+        dst_slice.nb[2]  = dst_slice.ne[1] * dst_slice.nb[1];
+        dst_slice.nb[3]  = dst_slice.ne[2] * dst_slice.nb[2];
+        dst_slice.data   = dst_data_cur;
+
+        ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);
+        CUDA_CHECK(cudaGetLastError());
+
+        src1_data_cur += src1_slice.nb[2];
+        dst_data_cur  += dst_slice.nb[2];
     }
+
+    get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,
+        ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,
+        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+        nb1, nb2, nb3, stream);
 }
 
 static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
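
The core of the refactored ggml_cuda_mul_mat_id is the host-side pass that groups src1 rows by expert before any kernel launch, so that each expert's rows are contiguous after the gather. The sketch below isolates that bookkeeping as standalone C++ under simplifying assumptions: the ids matrix is a dense row-major [n_tokens, n_expert_used] array already on the host, and a source row is identified simply as token*n_expert_used + slot (the real code instead maps through src1's ne11 stride and appends the inverse mapping to the same upload buffer). The names moe_sort and sort_rows_by_expert are illustrative, not from the tree:

    #include <cstdint>
    #include <vector>

    struct moe_sort {
        std::vector<int32_t> ids_to_sorted;    // sorted position -> source row
        std::vector<int32_t> ids_from_sorted;  // (token, slot)   -> sorted position
        std::vector<int32_t> tokens_per_expert;
    };

    static moe_sort sort_rows_by_expert(
            const int32_t * ids, int64_t n_tokens, int64_t n_expert_used, int64_t n_expert) {
        moe_sort s;
        s.ids_from_sorted.resize(n_tokens*n_expert_used);
        s.tokens_per_expert.assign(n_expert, 0);
        for (int64_t e = 0; e < n_expert; ++e) {           // expert matrices
            for (int64_t t = 0; t < n_tokens; ++t) {       // tokens
                for (int64_t slot = 0; slot < n_expert_used; ++slot) {
                    if (ids[t*n_expert_used + slot] == e) {
                        // this (token, slot) pair is routed to expert e: record where
                        // its row will land in the expert-sorted buffer
                        s.ids_from_sorted[t*n_expert_used + slot] = (int32_t) s.ids_to_sorted.size();
                        s.ids_to_sorted.push_back((int32_t) (t*n_expert_used + slot));
                        s.tokens_per_expert[e]++;
                        break;
                    }
                }
            }
        }
        return s;
    }

Running the gather with ids_to_sorted produces the per-expert contiguous src1 slices that are fed to ggml_cuda_mul_mat one expert at a time, and ids_from_sorted is used afterwards to scatter the sorted results back into the user-visible dst layout.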
ggml/src/ggml-cuda/mmq.cu

@@ -1,37 +1,10 @@
 #include "mmq.cuh"
+#include "quantize.cuh"
 
+#include <vector>
 
-        ggml_backend_cuda_context & ctx,
-        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
-        const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
-        const int64_t src1_padded_row_size, cudaStream_t stream) {
-
-    const int64_t ne00 = src0->ne[0];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    GGML_ASSERT(ne10 % QK8_1 == 0);
-
-    const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
-
-    int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // nrows_dst == nrows of the matrix that the kernel writes into
-    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
-
-    // The stream-k decomposition is only faster for recent NVIDIA GPUs.
-    // Also its fixup needs to allocate a temporary buffer in the memory pool.
-    // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
-    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
-
-    switch (src0->type) {
+static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+    switch (args.type_x) {
         case GGML_TYPE_Q4_0:
             mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
@@ -90,10 +63,195 @@ void ggml_cuda_op_mul_mat_q(
             GGML_ABORT("fatal error");
             break;
     }
+}
+
+void ggml_cuda_mul_mat_q(
+        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT( src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    cudaStream_t stream = ctx.stream();
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT( nb00 == ts_src0);
+    GGML_ASSERT( nb10 == ts_src1);
+    GGML_ASSERT( nb0  == ts_dst);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+
+    const char  * src0_d = (const char  *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float * dst_d = (float *) dst->data;
+
+    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+
+    if (!ids) {
+        const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
+            get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+        ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+
+        {
+            const int64_t s11 = src1->nb[1] / ts_src1;
+            const int64_t s12 = src1->nb[2] / ts_src1;
+            const int64_t s13 = src1->nb[3] / ts_src1;
+            quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
+                ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+        }
+
+        const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+        const int64_t s13 = ne12*s12;
+
+        const mmq_args args = {
+            src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
+            ne00, ne01, ne1, s01, s1,
+            ne02, ne12, s02, s12, s2,
+            ne03, ne13, s03, s13, s3,
+            use_stream_k};
+        ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+        return;
+    }
+
+    GGML_ASSERT(ne13 == 1);
+    GGML_ASSERT(nb12 % nb11 == 0);
+    GGML_ASSERT(nb2 % nb1 == 0);
+
+    const int64_t n_expert_used = ids->ne[0];
+    const int64_t ne_get_rows = ne12 * n_expert_used;
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    std::vector<int32_t> ids_src1_host;
+    ids_src1_host.reserve(ne_get_rows);
+    std::vector<int32_t> ids_dst_host;
+    ids_dst_host.reserve(ne_get_rows);
+    std::vector<int32_t> tokens_per_expert_host(ne02);
+    std::vector<int32_t> expert_bounds_host(ne02 + 1);
+    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
+
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+                assert(expert_to_use >= 0 && expert_to_use < ne02);
+                if (expert_to_use == i02) {
+                    ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
+                    ids_dst_host.push_back(i12*ne1 + iex);
+                    tokens_per_expert_host[i02]++;
+                    break;
+                }
+            }
+        }
+    }
+
+    int32_t cumsum = 0;
+    for (int64_t i = 0; i < ne02; ++i) {
+        expert_bounds_host[i] = cumsum;
+        cumsum += tokens_per_expert_host[i];
+    }
+    expert_bounds_host[ne02] = cumsum;
+
+    std::vector<int32_t> ids_buf_host;
+    ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
+    ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
+    ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
+    ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
+    ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
+    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    const int32_t * ids_src1_dev      = ids_buf_dev.ptr;
+    const int32_t * ids_dst_dev       = ids_src1_dev + ids_src1_host.size();
+    const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
+
+    const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
+        get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+
+    const int64_t ne11_flat = ne12*n_expert_used;
+    const int64_t ne12_flat = 1;
+    const int64_t ne13_flat = 1;
+
+    {
+        const int64_t s11 = src1->nb[1] / ts_src1;
+        const int64_t s12 = src1->nb[2] / ts_src1;
+        const int64_t s13 = src1->nb[2] / ts_src1;
+        quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+            ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+    }
+
+    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+    const int64_t s13 = ne12*s12;
+
+    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
+    const mmq_args args = {
+        src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+        ne00, ne01, ne_get_rows, s01, s1,
+        ne02, ne02, s02, s12, s2,
+        ne03, ne13, s03, s13, s3,
+        use_stream_k};
+
+    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+}
+
+void ggml_cuda_op_mul_mat_q(
+        ggml_backend_cuda_context & ctx,
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+        const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+        const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+    const int64_t stride01 = ne00 / ggml_blck_size(src0->type);
+
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+    // The stream-k decomposition is only faster for recent NVIDIA GPUs.
+    // Also its fixup needs to allocate a temporary buffer in the memory pool.
+    // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
+    const mmq_args args = {
+        src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
+        ne00, row_diff, src1_ncols, stride01, nrows_dst,
+        1, 1, 0, 0, 0,
+        1, 1, 0, 0, 0,
+        use_stream_k};
+
+    ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 
     GGML_UNUSED(src1);
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddf_i);
+    GGML_UNUSED(src1_padded_row_size);
 }
 
 bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
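
For the MoE path of ggml_cuda_mul_mat_q, the kernel needs to know which contiguous range of the sorted rows belongs to each expert. The diff builds this as expert_bounds, an exclusive prefix sum over tokens_per_expert with one trailing entry holding the total, so expert i owns sorted rows [expert_bounds[i], expert_bounds[i+1]), the interval handed to the kernel via expert_bounds_dev. A standalone sketch of just that step (the real code additionally pads the device copy by get_mmq_x_max_host(cc) entries and packs it into the same int32 upload as the row index tables):

    #include <cstdint>
    #include <vector>

    // Exclusive prefix sum over per-expert row counts, plus a trailing total,
    // mirroring the expert_bounds_host construction in ggml_cuda_mul_mat_q.
    static std::vector<int32_t> make_expert_bounds(const std::vector<int32_t> & tokens_per_expert) {
        std::vector<int32_t> expert_bounds(tokens_per_expert.size() + 1);
        int32_t cumsum = 0;
        for (size_t i = 0; i < tokens_per_expert.size(); ++i) {
            expert_bounds[i] = cumsum;
            cumsum += tokens_per_expert[i];
        }
        expert_bounds.back() = cumsum;
        return expert_bounds;
    }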
ggml/src/ggml-cuda/mmq.cuh

@@ -13,9 +13,10 @@ using namespace ggml_cuda_mma;
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
 
-typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int
-typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum,
 
 enum mmq_q8_1_ds_layout {
     MMQ_Q8_1_DS_LAYOUT_D4,
@@ -233,7 +234,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -289,7 +290,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -328,7 +329,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -384,7 +385,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -423,7 +424,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -495,7 +496,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -565,7 +566,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -621,7 +622,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -651,7 +652,7 @@
 template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -732,7 +733,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -762,7 +763,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -839,7 +840,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -871,7 +872,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -955,7 +956,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -1011,7 +1012,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -1074,7 +1075,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -1201,7 +1202,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -1298,7 +1299,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -1340,7 +1341,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -1437,7 +1438,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -1469,7 +1470,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -1578,7 +1579,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -1610,7 +1611,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -1693,7 +1694,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -1726,7 +1727,7 @@
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
-    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int
@@ -1835,7 +1836,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -1893,7 +1894,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -1951,7 +1952,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -2007,7 +2008,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -2070,7 +2071,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -2126,7 +2127,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -2189,7 +2190,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -2245,7 +2246,7 @@
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
-    const char * __restrict__ x, int * __restrict__ x_tile, const int
@@ -2306,8 +2307,8 @@
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
| 2313 |
const int j = j0 + threadIdx.y;
|
|
@@ -2324,15 +2325,15 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
|
|
| 2324 |
continue;
|
| 2325 |
}
|
| 2326 |
|
| 2327 |
-
dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
| 2328 |
}
|
| 2329 |
}
|
| 2330 |
}
|
| 2331 |
|
| 2332 |
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 2333 |
static __device__ __forceinline__ void mmq_write_back_mma(
|
| 2334 |
-
|
| 2335 |
-
|
| 2336 |
typedef tile<16, 8, int> tile_C;
|
| 2337 |
|
| 2338 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
@@ -2362,7 +2363,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
| 2362 |
continue;
|
| 2363 |
}
|
| 2364 |
|
| 2365 |
-
dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
|
| 2366 |
}
|
| 2367 |
}
|
| 2368 |
}
|
|
@@ -2518,17 +2519,18 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
|
|
| 2518 |
};
|
| 2519 |
|
| 2520 |
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
|
| 2521 |
-
static __device__ void mul_mat_q_process_tile(
|
| 2522 |
-
|
| 2523 |
-
|
| 2524 |
-
|
|
|
|
| 2525 |
|
| 2526 |
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
| 2527 |
constexpr int mmq_y = get_mmq_y_device();
|
| 2528 |
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
| 2529 |
|
| 2530 |
-
extern __shared__
|
| 2531 |
-
int * tile_y =
|
| 2532 |
int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
|
| 2533 |
|
| 2534 |
#ifdef NEW_MMA_AVAILABLE
|
|
@@ -2543,16 +2545,11 @@ static __device__ void mul_mat_q_process_tile(
|
|
| 2543 |
|
| 2544 |
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
| 2545 |
|
| 2546 |
-
const int tile_x_max_i = ne01 - it*mmq_y - 1;
|
| 2547 |
-
const int tile_y_max_j = ne11 - jt*mmq_x - 1;
|
| 2548 |
-
|
| 2549 |
-
const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
|
| 2550 |
-
|
| 2551 |
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
|
| 2552 |
-
load_tiles(x, tile_x,
|
| 2553 |
|
| 2554 |
{
|
| 2555 |
-
const int * by0 = y +
|
| 2556 |
#pragma unroll
|
| 2557 |
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
|
| 2558 |
int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
|
@@ -2568,7 +2565,7 @@ static __device__ void mul_mat_q_process_tile(
|
|
| 2568 |
__syncthreads();
|
| 2569 |
|
| 2570 |
{
|
| 2571 |
-
const int * by0 = y +
|
| 2572 |
#pragma unroll
|
| 2573 |
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
|
| 2574 |
int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
|
@@ -2585,12 +2582,10 @@ static __device__ void mul_mat_q_process_tile(
|
|
| 2585 |
}
|
| 2586 |
|
| 2587 |
if (fixup) {
|
| 2588 |
-
write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
|
| 2589 |
} else {
|
| 2590 |
-
write_back(sum, dst
|
| 2591 |
}
|
| 2592 |
-
|
| 2593 |
-
GGML_UNUSED(ne00); GGML_UNUSED(ne10);
|
| 2594 |
}
|
| 2595 |
|
| 2596 |
|
|
@@ -2609,8 +2604,11 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
|
| 2609 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
| 2610 |
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
| 2611 |
static __global__ void mul_mat_q(
|
| 2612 |
-
|
| 2613 |
-
|
| 2614 |
|
| 2615 |
// Skip unused template specializations for faster compilation:
|
| 2616 |
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
|
@@ -2621,26 +2619,85 @@ static __global__ void mul_mat_q(
|
|
| 2621 |
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
| 2622 |
constexpr int mmq_y = get_mmq_y_device();
|
| 2623 |
|
| 2624 |
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
| 2625 |
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
| 2626 |
{
|
| 2627 |
constexpr bool fixup = false;
|
| 2628 |
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
| 2629 |
-
(x,
|
| 2630 |
-
|
| 2631 |
return;
|
| 2632 |
}
|
| 2633 |
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
| 2634 |
|
| 2635 |
-
const int64_t blocks_per_ne00 =
|
| 2636 |
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
| 2637 |
|
| 2638 |
-
const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
|
| 2639 |
-
const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
|
| 2640 |
-
|
| 2641 |
// kbc == k block continuous, current index in continuous ijk space.
|
| 2642 |
-
int64_t kbc = (int64_t) blockIdx.x *
|
| 2643 |
-
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*
|
| 2644 |
|
| 2645 |
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
|
| 2646 |
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
|
|
@@ -2649,13 +2706,64 @@ static __global__ void mul_mat_q(
|
|
| 2649 |
int kb0_start = kbc % blocks_per_ne00;
|
| 2650 |
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
|
| 2651 |
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
|
| 2652 |
-
|
| 2653 |
-
const int it =
|
| 2654 |
|
| 2655 |
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 2656 |
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
| 2657 |
-
(x,
|
| 2658 |
-
|
| 2659 |
|
| 2660 |
kbc += blocks_per_ne00;
|
| 2661 |
kbc -= kbc % blocks_per_ne00;
|
|
@@ -2668,55 +2776,106 @@ static __global__ void mul_mat_q(
|
|
| 2668 |
return;
|
| 2669 |
}
|
| 2670 |
|
| 2671 |
-
|
| 2672 |
-
const int it =
|
| 2673 |
|
| 2674 |
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 2675 |
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
| 2676 |
-
(x,
|
| 2677 |
-
|
| 2678 |
}
|
| 2679 |
|
| 2680 |
|
| 2681 |
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
| 2682 |
static __global__ void mul_mat_q_stream_k_fixup(
|
| 2683 |
-
|
| 2684 |
-
|
|
|
|
| 2685 |
constexpr int mmq_y = get_mmq_y_device();
|
| 2686 |
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
| 2687 |
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
| 2688 |
-
const int64_t blocks_per_ne00 =
|
| 2689 |
|
| 2690 |
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
| 2691 |
|
| 2692 |
-
const int ntx
|
| 2693 |
-
const int nty
|
| 2694 |
-
|
| 2695 |
-
bool any_fixup = false;
|
| 2696 |
|
| 2697 |
-
const int
|
| 2698 |
-
const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
|
| 2699 |
|
| 2700 |
-
|
| 2701 |
-
int64_t
|
| 2702 |
-
|
| 2703 |
-
for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
|
| 2704 |
-
kbc_0 = kbc_stop_0;
|
| 2705 |
-
kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
|
| 2706 |
|
| 2707 |
-
|
| 2708 |
-
|
| 2709 |
|
| 2710 |
-
|
| 2711 |
-
|
| 2712 |
-
|
| 2713 |
-
|
|
|
|
|
|
|
| 2714 |
|
| 2715 |
-
|
| 2716 |
-
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
|
| 2717 |
|
| 2718 |
-
|
| 2719 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2720 |
continue;
|
| 2721 |
}
|
| 2722 |
|
|
@@ -2733,16 +2892,71 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
| 2733 |
sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
| 2734 |
}
|
| 2735 |
}
| 2736 |
}
|
| 2737 |
|
| 2738 |
if (!any_fixup) {
|
| 2739 |
return;
|
| 2740 |
}
|
| 2741 |
|
| 2742 |
-
|
| 2743 |
|
| 2744 |
-
|
| 2745 |
-
|
| 2746 |
|
| 2747 |
#pragma unroll
|
| 2748 |
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
@@ -2760,26 +2974,27 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
|
| 2760 |
continue;
|
| 2761 |
}
|
| 2762 |
|
| 2763 |
-
dst[j*
|
| 2764 |
}
|
| 2765 |
}
|
| 2766 |
}
|
| 2767 |
|
| 2768 |
struct mmq_args {
|
| 2769 |
-
const char * x; const
|
| 2770 |
-
int64_t
|
| 2771 |
-
int64_t
|
| 2772 |
-
int64_t
|
| 2773 |
bool use_stream_k;
|
| 2774 |
};
|
| 2775 |
|
| 2776 |
template<ggml_type type>
|
| 2777 |
-
static
|
| 2778 |
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
| 2779 |
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
| 2780 |
-
const
|
| 2781 |
-
const
|
| 2782 |
-
|
|
|
|
| 2783 |
}
|
| 2784 |
|
| 2785 |
template <ggml_type type, int mmq_x>
|
|
@@ -2791,86 +3006,114 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
| 2791 |
|
| 2792 |
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
|
| 2793 |
|
| 2794 |
-
const int
|
| 2795 |
|
| 2796 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 2797 |
-
static bool
|
| 2798 |
-
if (!
|
| 2799 |
-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 2800 |
-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 2801 |
-
|
| 2802 |
}
|
| 2803 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 2804 |
|
| 2805 |
-
const int nty
|
| 2806 |
-
const int ntx
|
| 2807 |
-
const
|
| 2808 |
|
| 2809 |
if (!args.use_stream_k) {
|
| 2810 |
-
if (args.
|
| 2811 |
constexpr bool need_check = false;
|
| 2812 |
-
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims,
|
| 2813 |
-
(args.x, args.y, args.
|
| 2814 |
} else {
|
| 2815 |
constexpr bool need_check = true;
|
| 2816 |
-
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims,
|
| 2817 |
-
(args.x, args.y, args.
|
| 2818 |
}
|
| 2819 |
return;
|
| 2820 |
}
|
| 2821 |
|
| 2822 |
-
const dim3
|
|
|
|
| 2823 |
|
| 2824 |
ggml_cuda_pool & pool = ctx.pool(id);
|
| 2825 |
-
ggml_cuda_pool_alloc<float> tmp_fixup(pool
|
| 2826 |
|
| 2827 |
-
if (args.
|
| 2828 |
constexpr bool need_check = false;
|
| 2829 |
|
| 2830 |
-
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<
|
| 2831 |
-
(args.x, args.y, args.
|
|
| 2832 |
|
| 2833 |
-
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<
|
| 2834 |
-
(args.
|
|
|
|
| 2835 |
} else {
|
| 2836 |
constexpr bool need_check = true;
|
| 2837 |
|
| 2838 |
-
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<
|
| 2839 |
-
(args.x, args.y, args.
|
|
| 2840 |
|
| 2841 |
-
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<
|
| 2842 |
-
(args.
|
|
|
|
| 2843 |
}
|
| 2844 |
}
|
| 2845 |
|
| 2846 |
template <ggml_type type>
|
| 2847 |
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
| 2848 |
-
const int
|
| 2849 |
-
const int
|
| 2850 |
-
const
|
| 2851 |
|
| 2852 |
const int mmq_x_max = get_mmq_x_max_host(cc);
|
| 2853 |
const int mmq_y = get_mmq_y_host(cc);
|
| 2854 |
-
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
| 2855 |
-
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
|
| 2856 |
|
| 2857 |
int mmq_x_best = 0;
|
| 2858 |
-
int
|
| 2859 |
|
| 2860 |
-
for (int mmq_x = 8; mmq_x <= mmq_x_max &&
|
| 2861 |
const int granularity = mmq_get_granularity_host(mmq_x, cc);
|
| 2862 |
|
| 2863 |
-
if (mmq_x % granularity != 0 ||
|
| 2864 |
continue;
|
| 2865 |
}
|
| 2866 |
|
| 2867 |
-
const int ntiles_x = (args.
|
| 2868 |
-
const int nwaves_xy_tiling = ntiles_x*block_num_y;
|
| 2869 |
-
const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
|
| 2870 |
|
| 2871 |
-
if (
|
| 2872 |
-
mmq_x_best
|
| 2873 |
-
|
| 2874 |
}
|
| 2875 |
}
|
| 2876 |
|
|
@@ -2954,6 +3197,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
|
| 2954 |
|
| 2955 |
// -------------------------------------------------------------------------------------------------------------------------
|
| 2956 |
|
| 2957 |
void ggml_cuda_op_mul_mat_q(
|
| 2958 |
ggml_backend_cuda_context & ctx,
|
| 2959 |
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
|
|
|
| 13 |
#define MMQ_ITER_K 256
|
| 14 |
#define MMQ_NWARPS 8
|
| 15 |
|
| 16 |
+
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
|
| 17 |
+
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
|
| 18 |
+
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
|
| 19 |
+
float * __restrict__ dst, const int stride, const int i_max, const int j_max);
|
| 20 |
|
| 21 |
enum mmq_q8_1_ds_layout {
|
| 22 |
MMQ_Q8_1_DS_LAYOUT_D4,
|
|
|
|
| 234 |
// ------------------------------------------------------------
|
| 235 |
|
| 236 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
|
| 237 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 238 |
|
| 239 |
#ifdef NEW_MMA_AVAILABLE
|
| 240 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 290 |
|
| 291 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 292 |
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
| 293 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 294 |
|
| 295 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
|
| 296 |
const int * x_qs = (const int *) x;
|
|
|
|
| 329 |
}
|
| 330 |
|
| 331 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
|
| 332 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 333 |
|
| 334 |
#ifdef NEW_MMA_AVAILABLE
|
| 335 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 385 |
|
| 386 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 387 |
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
| 388 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 389 |
|
| 390 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
| 391 |
const int * x_qs = (const int *) x;
|
|
|
|
| 424 |
}
|
| 425 |
|
| 426 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
|
| 427 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 428 |
|
| 429 |
#ifdef NEW_MMA_AVAILABLE
|
| 430 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 496 |
}
|
| 497 |
|
| 498 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
|
| 499 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 500 |
|
| 501 |
#ifdef NEW_MMA_AVAILABLE
|
| 502 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 566 |
}
|
| 567 |
|
| 568 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
|
| 569 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 570 |
|
| 571 |
#ifdef NEW_MMA_AVAILABLE
|
| 572 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 622 |
|
| 623 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 624 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
| 625 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 626 |
|
| 627 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
| 628 |
const int * x_qs = (const int *) x;
|
|
|
|
| 652 |
|
| 653 |
template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
|
| 654 |
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
| 655 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 656 |
|
| 657 |
typedef tile<16, 8, int> tile_A;
|
| 658 |
typedef tile< 8, 8, int> tile_B;
|
|
|
|
| 733 |
|
| 734 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 735 |
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
| 736 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 737 |
|
| 738 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
| 739 |
const int * x_qs = (const int *) x;
|
|
|
|
| 763 |
|
| 764 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 765 |
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
| 766 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 767 |
|
| 768 |
typedef tile<16, 8, int> tile_A;
|
| 769 |
typedef tile< 8, 8, int> tile_B;
|
|
|
|
| 840 |
|
| 841 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 842 |
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
| 843 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 844 |
|
| 845 |
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
| 846 |
const int * x_qs = (const int *) x;
|
|
|
|
| 872 |
|
| 873 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 874 |
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
| 875 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 876 |
#ifdef NEW_MMA_AVAILABLE
|
| 877 |
|
| 878 |
typedef tile<16, 4, int> tile_A;
|
|
|
|
| 956 |
}
|
| 957 |
|
| 958 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
| 959 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 960 |
|
| 961 |
#ifdef NEW_MMA_AVAILABLE
|
| 962 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 1012 |
|
| 1013 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1014 |
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
| 1015 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 1016 |
|
| 1017 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
| 1018 |
const int * x_qs = (const int *) x;
|
|
|
|
| 1075 |
|
| 1076 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1077 |
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
| 1078 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 1079 |
#ifdef NEW_MMA_AVAILABLE
|
| 1080 |
|
| 1081 |
typedef tile<16, 4, int> tile_A;
|
|
|
|
| 1202 |
}
|
| 1203 |
|
| 1204 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
|
| 1205 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1206 |
|
| 1207 |
#ifdef NEW_MMA_AVAILABLE
|
| 1208 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 1299 |
|
| 1300 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1301 |
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
|
| 1302 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 1303 |
|
| 1304 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
|
| 1305 |
const int * x_qs = (const int *) x;
|
|
|
|
| 1341 |
}
|
| 1342 |
|
| 1343 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
|
| 1344 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1345 |
|
| 1346 |
#ifdef NEW_MMA_AVAILABLE
|
| 1347 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 1438 |
|
| 1439 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1440 |
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
| 1441 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 1442 |
|
| 1443 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
|
| 1444 |
const int * x_qs = (const int *) x;
|
|
|
|
| 1470 |
}
|
| 1471 |
|
| 1472 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
|
| 1473 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1474 |
|
| 1475 |
#ifdef NEW_MMA_AVAILABLE
|
| 1476 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 1579 |
|
| 1580 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1581 |
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
| 1582 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 1583 |
|
| 1584 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
|
| 1585 |
const int * x_qs = (const int *) x;
|
|
|
|
| 1611 |
}
|
| 1612 |
|
| 1613 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
|
| 1614 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1615 |
|
| 1616 |
#ifdef NEW_MMA_AVAILABLE
|
| 1617 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 1694 |
|
| 1695 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1696 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
| 1697 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 1698 |
|
| 1699 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
|
| 1700 |
const int * x_qs = (const int *) x;
|
|
|
|
| 1727 |
|
| 1728 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1729 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
| 1730 |
+
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
| 1731 |
#ifdef NEW_MMA_AVAILABLE
|
| 1732 |
|
| 1733 |
typedef tile<16, 4, int> tile_A;
|
|
|
|
| 1836 |
}
|
| 1837 |
|
| 1838 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
|
| 1839 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1840 |
|
| 1841 |
#ifdef NEW_MMA_AVAILABLE
|
| 1842 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 1894 |
}
|
| 1895 |
|
| 1896 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
|
| 1897 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1898 |
|
| 1899 |
#ifdef NEW_MMA_AVAILABLE
|
| 1900 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 1952 |
}
|
| 1953 |
|
| 1954 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
|
| 1955 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1956 |
|
| 1957 |
#ifdef NEW_MMA_AVAILABLE
|
| 1958 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 2008 |
}
|
| 2009 |
|
| 2010 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
|
| 2011 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 2012 |
|
| 2013 |
#ifdef NEW_MMA_AVAILABLE
|
| 2014 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 2071 |
}
|
| 2072 |
|
| 2073 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
|
| 2074 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 2075 |
|
| 2076 |
#ifdef NEW_MMA_AVAILABLE
|
| 2077 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 2127 |
}
|
| 2128 |
|
| 2129 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
|
| 2130 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 2131 |
|
| 2132 |
#ifdef NEW_MMA_AVAILABLE
|
| 2133 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 2190 |
}
|
| 2191 |
|
| 2192 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
|
| 2193 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 2194 |
|
| 2195 |
#ifdef NEW_MMA_AVAILABLE
|
| 2196 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 2246 |
}
|
| 2247 |
|
| 2248 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
|
| 2249 |
+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 2250 |
|
| 2251 |
#ifdef NEW_MMA_AVAILABLE
|
| 2252 |
int * x_qs = (int *) x_tile;
|
|
|
|
| 2307 |
|
| 2308 |
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 2309 |
static __device__ __forceinline__ void mmq_write_back_dp4a(
|
| 2310 |
+
const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
|
| 2311 |
+
const int stride, const int i_max, const int j_max) {
|
| 2312 |
#pragma unroll
|
| 2313 |
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
| 2314 |
const int j = j0 + threadIdx.y;
|
|
|
|
| 2325 |
continue;
|
| 2326 |
}
|
| 2327 |
|
| 2328 |
+
dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
| 2329 |
}
|
| 2330 |
}
|
| 2331 |
}
|
| 2332 |
|
| 2333 |
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
| 2334 |
static __device__ __forceinline__ void mmq_write_back_mma(
|
| 2335 |
+
const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
|
| 2336 |
+
const int stride, const int i_max, const int j_max) {
|
| 2337 |
typedef tile<16, 8, int> tile_C;
|
| 2338 |
|
| 2339 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
|
|
| 2363 |
continue;
|
| 2364 |
}
|
| 2365 |
|
| 2366 |
+
dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
|
| 2367 |
}
|
| 2368 |
}
|
| 2369 |
}
|
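The two write-back helpers above differ only in how threads are mapped onto the tile; both scatter each output column j through ids_dst[j], which is what lets the same kernels serve regular matrix multiplication (ids_dst[j] == j) and MoE, where the ids hold the rows of the routed tokens. A minimal CPU-side sketch of that indirection (illustrative only, not ggml code, with hypothetical container types):

    #include <vector>

    // Each column j of the mmq_y x mmq_x result tile is written to destination row
    // ids_dst[j] instead of row j itself.
    static void write_back_tile(const std::vector<float> & sum,     // tile values, column j starts at j*mmq_y
                                const std::vector<int>   & ids_dst, // per-column destination indices
                                std::vector<float>       & dst,     // destination buffer with row stride `stride`
                                int mmq_x, int mmq_y, int stride) {
        for (int j = 0; j < mmq_x; ++j) {
            for (int i = 0; i < mmq_y; ++i) {
                dst[ids_dst[j]*stride + i] = sum[j*mmq_y + i];
            }
        }
    }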
|
|
|
| 2519 |
};
|
| 2520 |
|
| 2521 |
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
|
| 2522 |
+
static __device__ __forceinline__ void mul_mat_q_process_tile(
|
| 2523 |
+
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
|
| 2524 |
+
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
| 2525 |
+
const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
|
| 2526 |
+
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
|
| 2527 |
|
| 2528 |
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
| 2529 |
constexpr int mmq_y = get_mmq_y_device();
|
| 2530 |
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
| 2531 |
|
| 2532 |
+
extern __shared__ int data_mul_mat_q[];
|
| 2533 |
+
int * tile_y = data_mul_mat_q + mmq_x;
|
| 2534 |
int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
|
| 2535 |
|
| 2536 |
#ifdef NEW_MMA_AVAILABLE
|
|
|
|
| 2545 |
|
| 2546 |
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
| 2547 |
|
| 2548 |
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
|
| 2549 |
+
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
|
| 2550 |
|
| 2551 |
{
|
| 2552 |
+
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
|
| 2553 |
#pragma unroll
|
| 2554 |
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
|
| 2555 |
int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
|
|
|
| 2565 |
__syncthreads();
|
| 2566 |
|
| 2567 |
{
|
| 2568 |
+
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
|
| 2569 |
#pragma unroll
|
| 2570 |
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
|
| 2571 |
int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
|
|
|
| 2582 |
}
|
| 2583 |
|
| 2584 |
if (fixup) {
|
| 2585 |
+
write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
|
| 2586 |
} else {
|
| 2587 |
+
write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
|
| 2588 |
}
|
|
|
|
|
|
|
| 2589 |
}
|
| 2590 |
|
| 2591 |
|
|
|
|
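mul_mat_q_process_tile carves three regions out of a single dynamic shared memory allocation: the destination ids, the quantized y tile and the x tile. A host-side sketch of the same offsets (illustrative helper, not part of ggml; pad_to stands in for GGML_PAD and all offsets are counted in ints):

    #include <cstddef>

    static size_t pad_to(size_t n, size_t mult) { return (n + mult - 1) / mult * mult; }

    struct mmq_smem_offsets {
        size_t ids;    // mmq_x destination column indices (ids_dst_shared)
        size_t tile_y; // quantized activations in block_q8_1_mmq layout
        size_t tile_x; // quantized weight tile
    };

    static mmq_smem_offsets mmq_smem_layout(int mmq_x, int nwarps, int warp_size, int qi8_1) {
        mmq_smem_offsets o;
        o.ids    = 0;
        o.tile_y = mmq_x;
        o.tile_x = o.tile_y + pad_to(mmq_x*(warp_size + warp_size/qi8_1), nwarps*warp_size);
        return o;
    }

The padding presumably rounds the y region up to a whole number of nwarps*WARP_SIZE ints so the copy loops can step through it with a fixed per-thread stride.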
| 2604 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
| 2605 |
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
| 2606 |
static __global__ void mul_mat_q(
|
| 2607 |
+
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
|
| 2608 |
+
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
|
| 2609 |
+
const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
|
| 2610 |
+
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
| 2611 |
+
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
| 2612 |
|
| 2613 |
// Skip unused template specializations for faster compilation:
|
| 2614 |
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
|
|
|
|
| 2619 |
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
| 2620 |
constexpr int mmq_y = get_mmq_y_device();
|
| 2621 |
|
| 2622 |
+
const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
|
| 2623 |
+
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
|
| 2624 |
+
|
| 2625 |
+
// Initialize the ids for writing back data with just the index.
|
| 2626 |
+
// For regular matrix multiplications this is never changed.
|
| 2627 |
+
// For MoE the correct indices are loaded from ids_dst.
|
| 2628 |
+
extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
|
| 2629 |
+
#pragma unroll
|
| 2630 |
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
| 2631 |
+
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
| 2632 |
+
|
| 2633 |
+
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
|
| 2634 |
+
break;
|
| 2635 |
+
}
|
| 2636 |
+
|
| 2637 |
+
ids_dst_shared[j] = j;
|
| 2638 |
+
}
|
| 2639 |
+
|
| 2640 |
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
| 2641 |
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
| 2642 |
{
|
| 2643 |
+
const int wt = blockIdx.z / nchannels_y;
|
| 2644 |
+
const int zt = blockIdx.z - wt*nchannels_y;
|
| 2645 |
+
const int jt = blockIdx.y;
|
| 2646 |
+
const int it = blockIdx.x;
|
| 2647 |
+
|
| 2648 |
+
// Defaults for regular matrix multiplication:
|
| 2649 |
+
int col_low = 0;
|
| 2650 |
+
int col_high = ncols_y;
|
| 2651 |
+
int col_diff = ncols_y;
|
| 2652 |
+
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
|
| 2653 |
+
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
|
| 2654 |
+
|
| 2655 |
+
if (ids_dst) {
|
| 2656 |
+
col_low = expert_bounds[zt + 0];
|
| 2657 |
+
col_high = expert_bounds[zt + 1];
|
| 2658 |
+
col_diff = col_high - col_low;
|
| 2659 |
+
|
| 2660 |
+
offset_y = 0;
|
| 2661 |
+
offset_dst = 0;
|
| 2662 |
+
|
| 2663 |
+
if (jt*mmq_x >= col_diff) {
|
| 2664 |
+
return;
|
| 2665 |
+
}
|
| 2666 |
+
|
| 2667 |
+
#pragma unroll
|
| 2668 |
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
| 2669 |
+
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
| 2670 |
+
|
| 2671 |
+
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
|
| 2672 |
+
break;
|
| 2673 |
+
}
|
| 2674 |
+
|
| 2675 |
+
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
| 2676 |
+
}
|
| 2677 |
+
}
|
| 2678 |
+
|
| 2679 |
+
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
| 2680 |
+
offset_dst += it*mmq_y;
|
| 2681 |
+
|
| 2682 |
+
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
| 2683 |
+
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
| 2684 |
+
|
| 2685 |
+
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
| 2686 |
+
|
| 2687 |
constexpr bool fixup = false;
|
| 2688 |
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
| 2689 |
+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
|
| 2690 |
+
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
|
| 2691 |
return;
|
| 2692 |
}
|
| 2693 |
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
| 2694 |
|
| 2695 |
+
const int64_t blocks_per_ne00 = ncols_x / qk;
|
| 2696 |
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
| 2697 |
|
|
|
|
|
|
|
|
|
|
| 2698 |
// kbc == k block continuous, current index in continuous ijk space.
|
| 2699 |
+
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
| 2700 |
+
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
| 2701 |
|
| 2702 |
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
|
| 2703 |
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
|
|
|
|
| 2706 |
int kb0_start = kbc % blocks_per_ne00;
|
| 2707 |
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
|
| 2708 |
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
|
| 2709 |
+
int tmp = kbc;
|
| 2710 |
+
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
| 2711 |
+
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
| 2712 |
+
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
| 2713 |
+
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
| 2714 |
+
const int zt = tmp / (ntx*blocks_per_ne00);
|
| 2715 |
+
tmp -= zt * (ntx*blocks_per_ne00);
|
| 2716 |
+
const int jt = tmp / blocks_per_ne00;
|
| 2717 |
+
|
| 2718 |
+
// Defaults for regular matrix multiplication:
|
| 2719 |
+
int col_low = 0;
|
| 2720 |
+
int col_high = ncols_y;
|
| 2721 |
+
int col_diff = ncols_y;
|
| 2722 |
+
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
|
| 2723 |
+
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
|
| 2724 |
+
|
| 2725 |
+
if (ids_dst) {
|
| 2726 |
+
col_low = expert_bounds[zt + 0];
|
| 2727 |
+
col_high = expert_bounds[zt + 1];
|
| 2728 |
+
col_diff = col_high - col_low;
|
| 2729 |
+
|
| 2730 |
+
offset_y = 0;
|
| 2731 |
+
offset_dst = 0;
|
| 2732 |
+
|
| 2733 |
+
if (jt*mmq_x >= col_diff) {
|
| 2734 |
+
kbc += blocks_per_ne00;
|
| 2735 |
+
kbc -= kbc % blocks_per_ne00;
|
| 2736 |
+
|
| 2737 |
+
kb0_start = 0;
|
| 2738 |
+
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
|
| 2739 |
+
|
| 2740 |
+
continue;
|
| 2741 |
+
}
|
| 2742 |
+
|
| 2743 |
+
#pragma unroll
|
| 2744 |
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
| 2745 |
+
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
| 2746 |
+
|
| 2747 |
+
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
|
| 2748 |
+
break;
|
| 2749 |
+
}
|
| 2750 |
+
|
| 2751 |
+
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
| 2752 |
+
}
|
| 2753 |
+
}
|
| 2754 |
+
|
| 2755 |
+
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
| 2756 |
+
offset_dst += it*mmq_y;
|
| 2757 |
+
|
| 2758 |
+
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
| 2759 |
+
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
| 2760 |
+
|
| 2761 |
+
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
| 2762 |
|
| 2763 |
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 2764 |
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
| 2765 |
+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
|
| 2766 |
+
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
| 2767 |
|
| 2768 |
kbc += blocks_per_ne00;
|
| 2769 |
kbc -= kbc % blocks_per_ne00;
|
|
|
|
| 2776 |
return;
|
| 2777 |
}
|
| 2778 |
|
| 2779 |
+
int tmp = kbc;
|
| 2780 |
+
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
| 2781 |
+
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
| 2782 |
+
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
| 2783 |
+
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
| 2784 |
+
const int zt = tmp / (ntx*blocks_per_ne00);
|
| 2785 |
+
tmp -= zt * (ntx*blocks_per_ne00);
|
| 2786 |
+
const int jt = tmp / blocks_per_ne00;
|
| 2787 |
+
|
| 2788 |
+
// Defaults for regular matrix multiplication:
|
| 2789 |
+
int col_low = 0;
|
| 2790 |
+
int col_high = ncols_y;
|
| 2791 |
+
int col_diff = ncols_y;
|
| 2792 |
+
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
|
| 2793 |
+
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
|
| 2794 |
+
|
| 2795 |
+
if (ids_dst) {
|
| 2796 |
+
col_low = expert_bounds[zt + 0];
|
| 2797 |
+
col_high = expert_bounds[zt + 1];
|
| 2798 |
+
col_diff = col_high - col_low;
|
| 2799 |
+
|
| 2800 |
+
offset_y = 0;
|
| 2801 |
+
offset_dst = 0;
|
| 2802 |
+
|
| 2803 |
+
if (jt*mmq_x >= col_diff) {
|
| 2804 |
+
return;
|
| 2805 |
+
}
|
| 2806 |
+
|
| 2807 |
+
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
|
| 2808 |
+
#pragma unroll
|
| 2809 |
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
| 2810 |
+
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
| 2811 |
+
|
| 2812 |
+
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
|
| 2813 |
+
break;
|
| 2814 |
+
}
|
| 2815 |
+
|
| 2816 |
+
ids_dst_shared[j] = j;
|
| 2817 |
+
}
|
| 2818 |
+
}
|
| 2819 |
+
|
| 2820 |
+
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
| 2821 |
+
offset_dst += it*mmq_y;
|
| 2822 |
+
|
| 2823 |
+
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
|
| 2824 |
+
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
|
| 2825 |
+
|
| 2826 |
+
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
|
| 2827 |
|
| 2828 |
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 2829 |
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
|
| 2830 |
+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
|
| 2831 |
+
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
| 2832 |
}
|
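In the stream-k path each CUDA block walks a contiguous range of k blocks, and the flattened index kbc encodes the row tile it, the sample wt, the channel zt, the column tile jt and the position along ne00. A standalone sketch of that decomposition with made-up sizes (not ggml code), mirroring the repeated divisions above:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical sizes: 1 sample, 4 channels, 2 column tiles, 16 k blocks along ne00
        const int64_t nsamples_y = 1, nchannels_y = 4, ntx = 2, blocks_per_ne00 = 16;

        const int64_t kbc = 210;
        int64_t tmp = kbc;
        const int64_t it  = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); // 210/128 = 1
        tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
        const int64_t wt  = tmp / (nchannels_y*ntx*blocks_per_ne00);            //  82/128 = 0
        tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
        const int64_t zt  = tmp / (ntx*blocks_per_ne00);                        //  82/32  = 2
        tmp -= zt * (ntx*blocks_per_ne00);
        const int64_t jt  = tmp / blocks_per_ne00;                              //  18/16  = 1
        const int64_t kb0 = kbc % blocks_per_ne00;                              // 210%16  = 2

        printf("it=%lld wt=%lld zt=%lld jt=%lld kb0=%lld\n",
               (long long) it, (long long) wt, (long long) zt, (long long) jt, (long long) kb0);
        return 0;
    }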
| 2833 |
|
| 2834 |
|
| 2835 |
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
| 2836 |
static __global__ void mul_mat_q_stream_k_fixup(
|
| 2837 |
+
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
|
| 2838 |
+
const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
|
| 2839 |
+
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
|
| 2840 |
constexpr int mmq_y = get_mmq_y_device();
|
| 2841 |
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
| 2842 |
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
| 2843 |
+
const int64_t blocks_per_ne00 = ncols_x / qk;
|
| 2844 |
|
| 2845 |
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
| 2846 |
|
| 2847 |
+
const int ntx = (ncols_y + mmq_x - 1) / mmq_x;
|
| 2848 |
+
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
|
|
|
|
|
|
|
| 2849 |
|
| 2850 |
+
const int bidx0 = blockIdx.x;
|
|
|
|
| 2851 |
|
| 2852 |
+
// kbc == k block continuous, current index in continuous ijk space.
|
| 2853 |
+
int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
| 2854 |
+
int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
|
|
|
|
|
|
|
|
|
| 2855 |
|
| 2856 |
+
kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
|
| 2857 |
+
kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
|
| 2858 |
|
| 2859 |
+
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
| 2860 |
+
const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
|
| 2861 |
+
const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
|
| 2862 |
+
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
| 2863 |
+
return;
|
| 2864 |
+
}
|
| 2865 |
|
| 2866 |
+
bool any_fixup = false;
|
|
|
|
| 2867 |
|
| 2868 |
+
// Iterate over previous blocks and sum up partial sums written to fixup buffer.
|
| 2869 |
+
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
| 2870 |
+
int64_t bidx = bidx0 - 1;
|
| 2871 |
+
int64_t kbc_stop = kbc0;
|
| 2872 |
+
while(true) {
|
| 2873 |
+
int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
|
| 2874 |
+
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
|
| 2875 |
+
|
| 2876 |
+
if (kbc == kbc_stop) { // Did not have any data.
|
| 2877 |
+
bidx--;
|
| 2878 |
+
kbc_stop = kbc;
|
| 2879 |
continue;
|
| 2880 |
}
|
| 2881 |
|
|
|
|
| 2892 |
sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
|
| 2893 |
}
|
| 2894 |
}
|
| 2895 |
+
|
| 2896 |
+
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
| 2897 |
+
if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
|
| 2898 |
+
break;
|
| 2899 |
+
}
|
| 2900 |
+
bidx--;
|
| 2901 |
+
kbc_stop = kbc;
|
| 2902 |
}
|
| 2903 |
|
| 2904 |
if (!any_fixup) {
|
| 2905 |
return;
|
| 2906 |
}
|
| 2907 |
|
| 2908 |
+
int tmp = kbc0;
|
| 2909 |
+
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
| 2910 |
+
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
|
| 2911 |
+
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
|
| 2912 |
+
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
|
| 2913 |
+
const int zt = tmp / (ntx*blocks_per_ne00);
|
| 2914 |
+
tmp -= zt * (ntx*blocks_per_ne00);
|
| 2915 |
+
const int jt = tmp / blocks_per_ne00;
|
| 2916 |
|
| 2917 |
+
if (!ids_dst) {
|
| 2918 |
+
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
|
| 2919 |
+
dst += offset_dst;
|
| 2920 |
+
|
| 2921 |
+
const int i_max = nrows_x - it*mmq_y - 1;
|
| 2922 |
+
const int j_max = ncols_y - jt*mmq_x - 1;
|
| 2923 |
+
|
| 2924 |
+
#pragma unroll
|
| 2925 |
+
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
| 2926 |
+
const int j = j0 + threadIdx.y;
|
| 2927 |
+
|
| 2928 |
+
if (j > j_max) {
|
| 2929 |
+
return;
|
| 2930 |
+
}
|
| 2931 |
+
|
| 2932 |
+
#pragma unroll
|
| 2933 |
+
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
| 2934 |
+
const int i = i0 + threadIdx.x;
|
| 2935 |
+
|
| 2936 |
+
if (need_check && i > i_max) {
|
| 2937 |
+
continue;
|
| 2938 |
+
}
|
| 2939 |
+
|
| 2940 |
+
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
| 2941 |
+
}
|
| 2942 |
+
}
|
| 2943 |
+
return;
|
| 2944 |
+
}
|
| 2945 |
+
|
| 2946 |
+
__shared__ int ids_dst_shared[mmq_x];
|
| 2947 |
+
const int col_low = expert_bounds[zt + 0];
|
| 2948 |
+
const int col_high = expert_bounds[zt + 1];
|
| 2949 |
+
const int col_diff = col_high - col_low;
|
| 2950 |
+
|
| 2951 |
+
for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
|
| 2952 |
+
ids_dst_shared[j] = ids_dst[col_low + j];
|
| 2953 |
+
}
|
| 2954 |
+
|
| 2955 |
+
const int offset_dst = it*mmq_y;
|
| 2956 |
+
dst += offset_dst;
|
| 2957 |
+
|
| 2958 |
+
const int i_max = nrows_x - it*mmq_y - 1;
|
| 2959 |
+
const int j_max = col_diff - jt*mmq_x - 1;
|
| 2960 |
|
| 2961 |
#pragma unroll
|
| 2962 |
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
|
|
|
| 2974 |
continue;
|
| 2975 |
}
|
| 2976 |
|
| 2977 |
+
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
| 2978 |
}
|
| 2979 |
}
|
| 2980 |
}
|
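Only blocks whose k range begins in the middle of an output tile and reaches that tile's end have partial sums to collect; everything else returns early from the fixup kernel. A host-side toy walk over the same early-return test (hypothetical sizes, not ggml code; the blocks_per_iter alignment is left out for brevity):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t blocks_per_ne00 = 16; // k blocks per output tile
        const int64_t ntiles          = 6;  // tiles in the flattened tile/channel/sample space
        const int64_t nsm             = 5;  // CUDA blocks sharing the work

        const int64_t total = ntiles*blocks_per_ne00;
        for (int64_t bidx = 0; bidx < nsm; ++bidx) {
            const int64_t kbc0      =  bidx     *total / nsm;
            const int64_t kbc0_stop = (bidx + 1)*total / nsm;

            const bool no_data        = kbc0 == kbc0_stop;
            const bool at_tile_start  = kbc0 % blocks_per_ne00 == 0;
            const bool missed_the_end = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 &&
                                        kbc0_stop % blocks_per_ne00 != 0;

            printf("block %lld: kbc=[%lld, %lld) needs_fixup=%d\n", (long long) bidx,
                   (long long) kbc0, (long long) kbc0_stop,
                   (int) !(no_data || at_tile_start || missed_the_end));
        }
        return 0;
    }

With these numbers block 0 starts at a tile boundary and skips the fixup, while blocks 1 through 4 each finish a tile that a previous block started and therefore accumulate its partial result from tmp_last_tile.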
| 2981 |
|
| 2982 |
struct mmq_args {
|
| 2983 |
+
const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
|
| 2984 |
+
int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
|
| 2985 |
+
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
|
| 2986 |
+
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
|
| 2987 |
bool use_stream_k;
|
| 2988 |
};
|
| 2989 |
|
| 2990 |
template<ggml_type type>
|
| 2991 |
+
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
|
| 2992 |
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
| 2993 |
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
| 2994 |
+
const size_t nbs_ids = mmq_x*sizeof(int);
|
| 2995 |
+
const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
| 2996 |
+
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
| 2997 |
+
return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
|
| 2998 |
}
|
| 2999 |
|
| 3000 |
template <ggml_type type, int mmq_x>
|
|
|
|
| 3006 |
|
| 3007 |
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
|
| 3008 |
|
| 3009 |
+
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
|
| 3010 |
|
| 3011 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 3012 |
+
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
| 3013 |
+
if (!shared_memory_limit_raised[id]) {
|
| 3014 |
+
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
|
| 3015 |
+
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
|
| 3016 |
+
shared_memory_limit_raised[id] = true;
|
| 3017 |
}
|
| 3018 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 3019 |
|
| 3020 |
+
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
|
| 3021 |
+
const int ntx = (args.ncols_y + mmq_x - 1) / mmq_x;
|
| 3022 |
+
const int ntzw = args.nchannels_y * args.nsamples_y;
|
| 3023 |
+
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
|
| 3024 |
+
|
| 3025 |
+
GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
|
| 3026 |
+
GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
|
| 3027 |
+
const int channel_ratio = args.nchannels_y / args.nchannels_x;
|
| 3028 |
+
const int sample_ratio = args.nsamples_y / args.nsamples_x;
|
| 3029 |
|
| 3030 |
if (!args.use_stream_k) {
|
| 3031 |
+
if (args.nrows_x % mmq_y == 0) {
|
| 3032 |
constexpr bool need_check = false;
|
| 3033 |
+
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
| 3034 |
+
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
| 3035 |
+
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
|
| 3036 |
+
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
|
| 3037 |
+
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
|
| 3038 |
} else {
|
| 3039 |
constexpr bool need_check = true;
|
| 3040 |
+
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
|
| 3041 |
+
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
|
| 3042 |
+
             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+            channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
        }
        return;
    }

+   const dim3 block_nums_stream_k(nsm, 1, 1);
+   const bool fixup_needed = ntx*nty*ntzw % nsm != 0;

    ggml_cuda_pool & pool = ctx.pool(id);
+   ggml_cuda_pool_alloc<float> tmp_fixup(pool);
+   if (fixup_needed) {
+       tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
+   }

+   if (args.nrows_x % mmq_y == 0) {
        constexpr bool need_check = false;

+       mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+           (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+            args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+            channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+
+       if (!fixup_needed) {
+           return;
+       }

+       mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+           (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
    } else {
        constexpr bool need_check = true;

+       mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+           (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+            args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
+            channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+            sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+
+       if (!fixup_needed) {
+           return;
+       }

+       mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+           (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
+            args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
    }
}

template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+   const int id = ggml_cuda_get_device();
+   const int cc = ggml_cuda_info().devices[id].cc;
+   const size_t smpbo = ggml_cuda_info().devices[id].smpbo;

    const int mmq_x_max = get_mmq_x_max_host(cc);
    const int mmq_y = get_mmq_y_host(cc);

    int mmq_x_best = 0;
+   int ntiles_x_best = INT_MAX;

+   for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
        const int granularity = mmq_get_granularity_host(mmq_x, cc);

+       if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
            continue;
        }

+       const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;

+       if (ntiles_x < ntiles_x_best) {
+           mmq_x_best = mmq_x;
+           ntiles_x_best = ntiles_x;
        }
    }

[...]

// -------------------------------------------------------------------------------------------------------------------------

+void ggml_cuda_mul_mat_q(
+    ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
ggml/src/ggml-cuda/mmvq.cu
CHANGED
@@ -158,7 +158,7 @@ static __global__ void mul_mat_vec_q(
    const int blocks_per_row_x = ncols_x / qk;
    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;

-   // The MUL_MAT_ID code path with ids != nullptr is only
+   // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
    const int channel_dst = blockIdx.y;
    const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
    const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
@@ -507,7 +507,7 @@ void ggml_cuda_mul_mat_vec_q(
    GGML_ASSERT( nb0 == ts_dst);
    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));

-   GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for
+   GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.

    const float * src1_d = (const float *) src1->data;
    const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
@@ -519,7 +519,7 @@ void ggml_cuda_mul_mat_vec_q(
        const int64_t s11 = src1->nb[1] / ts_src1;
        const int64_t s12 = src1->nb[2] / ts_src1;
        const int64_t s13 = src1->nb[3] / ts_src1;
-       quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+       quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
    }

    const int64_t s01 = src0->nb[1] / ts_src0;
ggml/src/ggml-cuda/quantize.cu
CHANGED
@@ -49,29 +49,38 @@ static __global__ void quantize_q8_1(

template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
-
+   const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
+   const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+   const int64_t ne0, const int ne1, const int ne2) {

    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
    constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;

-   const int64_t
+   const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;

-   if (
+   if (i0 >= ne0) {
        return;
    }

-   const
+   const int64_t i1 = blockIdx.y;
+   const int64_t i2 = blockIdx.z % ne2;
+   const int64_t i3 = blockIdx.z / ne2;

-   const int64_t
+   const int64_t i00 = i0;
+   const int64_t i01 = ids ? ids[i1] : i1;
+   const int64_t i02 = i2;
+   const int64_t i03 = i3;
+
+   const float4 * x4 = (const float4 *) x;

    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;

    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
-   const int64_t ib = ib0 + (
-   const int64_t iqs =
+   const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y; // block index in channel
+   const int64_t iqs = i0 % (4*QK8_1); // quant index in block

    // Load 4 floats per thread and calculate max. abs. value between them:
-   const float4 xi =
+   const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
    float amax = fabsf(xi.x);
    amax = fmaxf(amax, fabsf(xi.y));
    amax = fmaxf(amax, fabsf(xi.z));
@@ -87,7 +96,7 @@ static __global__ void quantize_mmq_q8_1(
    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
        sum = xi.x + xi.y + xi.z + xi.w;

-       //
+       // Calculate sums across vals_per_sum/4 threads.
#pragma unroll
        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
@@ -137,9 +146,10 @@ static __global__ void quantize_mmq_q8_1(
}

void quantize_row_q8_1_cuda(
-
-
-
+   const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+   const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+   const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+   GGML_ASSERT(!ids);
    GGML_ASSERT(ne0 % QK8_1 == 0);

    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
@@ -150,9 +160,9 @@ void quantize_row_q8_1_cuda(
}

void quantize_mmq_q8_1_cuda(
-
-
-
+   const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+   const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+   const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
    GGML_ASSERT(ne0 % (4*QK8_1) == 0);

    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
@@ -161,21 +171,18 @@ void quantize_mmq_q8_1_cuda(
    switch (mmq_get_q8_1_ds_layout(type_src0)) {
        case MMQ_Q8_1_DS_LAYOUT_D4:
            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-               <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1,
+               <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
            break;
        case MMQ_Q8_1_DS_LAYOUT_DS4:
            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-               <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1,
+               <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
            break;
        case MMQ_Q8_1_DS_LAYOUT_D2S6:
            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-               <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1,
+               <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
-   GGML_UNUSED(s01);
-   GGML_UNUSED(s02);
-   GGML_UNUSED(s03);
}
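
For reference, the source-element addressing that the new ids argument introduces into quantize_mmq_q8_1 reduces to the helper below: ids gathers the src1 row to quantize for the MoE path, and with ids == nullptr the mapping is the identity. Illustrative sketch only; src_elem_offset is a hypothetical name, and the index math mirrors the kernel above.

#include <cstdint>

// Mirror of the i00..i03 computation in quantize_mmq_q8_1: s01/s02/s03 are source strides
// in elements, and i01 is optionally gathered through ids.
static int64_t src_elem_offset(int64_t i0, int64_t i1, int64_t i2, int64_t i3,
                               const int32_t * ids, int64_t s01, int64_t s02, int64_t s03) {
    const int64_t i00 = i0;
    const int64_t i01 = ids ? ids[i1] : i1;
    const int64_t i02 = i2;
    const int64_t i03 = i3;
    return i03*s03 + i02*s02 + i01*s01 + i00;
}
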
ggml/src/ggml-cuda/quantize.cuh
CHANGED
@@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");

typedef void (*quantize_cuda_t)(
-
-
+   const float * x, const int32_t * ids, void * vy,
+   ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+   int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);

void quantize_row_q8_1_cuda(
-
-
+   const float * x, const int32_t * ids, void * vy,
+   ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+   int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);

void quantize_mmq_q8_1_cuda(
-
-
+   const float * x, const int32_t * ids, void * vy,
+   ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+   int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
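
Both quantize entry points now match the quantize_cuda_t typedef above, so a caller can switch between them through a single function pointer. Illustrative sketch only; pick_quantizer and use_mmq_layout are hypothetical names, not part of ggml.

#include "quantize.cuh" // declares quantize_cuda_t and both quantizer wrappers

// Select the q8_1 quantization wrapper at runtime via the shared signature.
static quantize_cuda_t pick_quantizer(bool use_mmq_layout) {
    return use_mmq_layout ? quantize_mmq_q8_1_cuda : quantize_row_q8_1_cuda;
}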