Spaces:
Running
Running
Commit
·
961ef57
1
Parent(s):
4d90c3d
CUDA: add BF16 support (llama/11093)
Browse files
ggml/src/ggml-cuda/convert.cu
CHANGED
|
@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|
| 680 |
return dequantize_row_iq3_s_cuda;
|
| 681 |
case GGML_TYPE_F16:
|
| 682 |
return convert_unary_cuda<half>;
|
|
|
|
|
|
|
| 683 |
default:
|
| 684 |
return nullptr;
|
| 685 |
}
|
|
|
|
| 680 |
return dequantize_row_iq3_s_cuda;
|
| 681 |
case GGML_TYPE_F16:
|
| 682 |
return convert_unary_cuda<half>;
|
| 683 |
+
case GGML_TYPE_BF16:
|
| 684 |
+
return convert_unary_cuda<nv_bfloat16>;
|
| 685 |
default:
|
| 686 |
return nullptr;
|
| 687 |
}
|
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -1728,7 +1728,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
| 1728 |
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 1729 |
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
| 1730 |
|
| 1731 |
-
bool use_mul_mat_vec = src0->type == GGML_TYPE_F16
|
| 1732 |
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
| 1733 |
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
| 1734 |
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
|
@@ -2869,6 +2869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 2869 |
case GGML_TYPE_IQ3_XXS:
|
| 2870 |
case GGML_TYPE_IQ4_NL:
|
| 2871 |
case GGML_TYPE_IQ4_XS:
|
|
|
|
| 2872 |
#ifdef GGML_USE_MUSA
|
| 2873 |
if (a->type == GGML_TYPE_Q3_K) {
|
| 2874 |
return false;
|
|
|
|
| 1728 |
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 1729 |
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
| 1730 |
|
| 1731 |
+
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
|
| 1732 |
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
| 1733 |
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
| 1734 |
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
|
|
|
| 2869 |
case GGML_TYPE_IQ3_XXS:
|
| 2870 |
case GGML_TYPE_IQ4_NL:
|
| 2871 |
case GGML_TYPE_IQ4_XS:
|
| 2872 |
+
case GGML_TYPE_BF16:
|
| 2873 |
#ifdef GGML_USE_MUSA
|
| 2874 |
if (a->type == GGML_TYPE_Q3_K) {
|
| 2875 |
return false;
|
ggml/src/ggml-cuda/mmv.cu
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
#include "common.cuh"
|
| 2 |
#include "mmv.cuh"
|
| 3 |
|
| 4 |
-
template <typename type_acc, int block_size>
|
| 5 |
static __global__ void mul_mat_vec(
|
| 6 |
-
const
|
| 7 |
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
|
| 8 |
const int64_t row = blockIdx.x;
|
| 9 |
const int64_t channel = blockIdx.z;
|
|
@@ -13,7 +13,6 @@ static __global__ void mul_mat_vec(
|
|
| 13 |
y += channel *stride_channel_y;
|
| 14 |
dst += channel *stride_channel_dst;
|
| 15 |
|
| 16 |
-
const half2 * x2 = (const half2 *) x;
|
| 17 |
const float2 * y2 = (const float2 *) y;
|
| 18 |
|
| 19 |
extern __shared__ char data_mmv[];
|
|
@@ -28,28 +27,44 @@ static __global__ void mul_mat_vec(
|
|
| 28 |
|
| 29 |
float sumf;
|
| 30 |
|
| 31 |
-
if (std::is_same<
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
#ifdef FP16_AVAILABLE
|
| 42 |
-
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
-
|
| 50 |
#else
|
| 51 |
-
|
| 52 |
#endif // FP16_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
}
|
| 54 |
|
| 55 |
sumf = warp_reduce_sum(sumf);
|
|
@@ -71,9 +86,9 @@ static __global__ void mul_mat_vec(
|
|
| 71 |
dst[row] = sumf;
|
| 72 |
}
|
| 73 |
|
| 74 |
-
template <typename type_acc>
|
| 75 |
static void launch_mul_mat_vec_cuda(
|
| 76 |
-
const
|
| 77 |
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
| 78 |
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
| 79 |
cudaStream_t stream) {
|
|
@@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda(
|
|
| 97 |
const dim3 block_dims(block_size_best, 1, 1);
|
| 98 |
switch (block_size_best) {
|
| 99 |
case 32: {
|
| 100 |
-
mul_mat_vec<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
| 101 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 102 |
} break;
|
| 103 |
case 64: {
|
| 104 |
-
mul_mat_vec<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
| 105 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 106 |
} break;
|
| 107 |
case 96: {
|
| 108 |
-
mul_mat_vec<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
| 109 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 110 |
} break;
|
| 111 |
case 128: {
|
| 112 |
-
mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
| 113 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 114 |
} break;
|
| 115 |
case 160: {
|
| 116 |
-
mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
| 117 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 118 |
} break;
|
| 119 |
case 192: {
|
| 120 |
-
mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
| 121 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 122 |
} break;
|
| 123 |
case 224: {
|
| 124 |
-
mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
| 125 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 126 |
} break;
|
| 127 |
case 256: {
|
| 128 |
-
mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
| 129 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 130 |
} break;
|
| 131 |
default: {
|
|
@@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda(
|
|
| 134 |
}
|
| 135 |
}
|
| 136 |
|
|
|
|
| 137 |
static void mul_mat_vec_cuda(
|
| 138 |
-
const
|
| 139 |
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
| 140 |
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
| 141 |
enum ggml_prec prec, cudaStream_t stream) {
|
| 142 |
switch (prec) {
|
| 143 |
case GGML_PREC_DEFAULT: {
|
| 144 |
-
launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
| 145 |
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
| 146 |
} break;
|
| 147 |
case GGML_PREC_F32: {
|
| 148 |
-
launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
| 149 |
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
| 150 |
} break;
|
| 151 |
}
|
| 152 |
}
|
| 153 |
|
| 154 |
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 155 |
-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
| 156 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 157 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 158 |
|
|
@@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
| 164 |
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
| 165 |
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
| 166 |
|
| 167 |
-
const half * src0_d = (const half *) src0->data;
|
| 168 |
const float * src1_d = (const float *) src1->data;
|
| 169 |
float * dst_d = (float *) dst->data;
|
| 170 |
|
|
@@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
| 181 |
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
|
| 182 |
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
|
| 183 |
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
}
|
| 186 |
|
| 187 |
void ggml_cuda_op_mul_mat_vec(
|
|
@@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec(
|
|
| 190 |
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
| 191 |
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
| 192 |
|
| 193 |
-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
| 194 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 195 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 196 |
|
|
@@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec(
|
|
| 211 |
const int64_t channel_stride_y = 0;
|
| 212 |
const int64_t channel_stride_dst = 0;
|
| 213 |
|
| 214 |
-
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
GGML_UNUSED(ctx);
|
| 218 |
GGML_UNUSED(src1);
|
|
|
|
| 1 |
#include "common.cuh"
|
| 2 |
#include "mmv.cuh"
|
| 3 |
|
| 4 |
+
template <typename T, typename type_acc, int block_size>
|
| 5 |
static __global__ void mul_mat_vec(
|
| 6 |
+
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
|
| 7 |
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
|
| 8 |
const int64_t row = blockIdx.x;
|
| 9 |
const int64_t channel = blockIdx.z;
|
|
|
|
| 13 |
y += channel *stride_channel_y;
|
| 14 |
dst += channel *stride_channel_dst;
|
| 15 |
|
|
|
|
| 16 |
const float2 * y2 = (const float2 *) y;
|
| 17 |
|
| 18 |
extern __shared__ char data_mmv[];
|
|
|
|
| 27 |
|
| 28 |
float sumf;
|
| 29 |
|
| 30 |
+
if constexpr (std::is_same<T, half>::value) {
|
| 31 |
+
const half2 * x2 = (const half2 *) x;
|
| 32 |
|
| 33 |
+
if (std::is_same<type_acc, float>::value) {
|
| 34 |
+
sumf = 0.0f;
|
| 35 |
+
|
| 36 |
+
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
| 37 |
+
const float2 tmpx = __half22float2(x2[col2]);
|
| 38 |
+
const float2 tmpy = y2[col2];
|
| 39 |
+
sumf += tmpx.x * tmpy.x;
|
| 40 |
+
sumf += tmpx.y * tmpy.y;
|
| 41 |
+
}
|
| 42 |
+
} else {
|
| 43 |
#ifdef FP16_AVAILABLE
|
| 44 |
+
half2 sumh2 = make_half2(0.0f, 0.0f);
|
| 45 |
|
| 46 |
+
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
| 47 |
+
const float2 tmp = y2[col2];
|
| 48 |
+
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
|
| 49 |
+
}
|
| 50 |
|
| 51 |
+
sumf = __low2float(sumh2) + __high2float(sumh2);
|
| 52 |
#else
|
| 53 |
+
NO_DEVICE_CODE;
|
| 54 |
#endif // FP16_AVAILABLE
|
| 55 |
+
}
|
| 56 |
+
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
| 57 |
+
const int * x2 = (const int *) x;
|
| 58 |
+
sumf = 0.0f;
|
| 59 |
+
|
| 60 |
+
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
| 61 |
+
const int tmpx = x2[col2];
|
| 62 |
+
const float2 tmpy = y2[col2];
|
| 63 |
+
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
| 64 |
+
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
| 65 |
+
}
|
| 66 |
+
} else {
|
| 67 |
+
static_assert(std::is_same<T, void>::value, "unsupported type");
|
| 68 |
}
|
| 69 |
|
| 70 |
sumf = warp_reduce_sum(sumf);
|
|
|
|
| 86 |
dst[row] = sumf;
|
| 87 |
}
|
| 88 |
|
| 89 |
+
template <typename T, typename type_acc>
|
| 90 |
static void launch_mul_mat_vec_cuda(
|
| 91 |
+
const T * x, const float * y, float * dst,
|
| 92 |
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
| 93 |
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
| 94 |
cudaStream_t stream) {
|
|
|
|
| 112 |
const dim3 block_dims(block_size_best, 1, 1);
|
| 113 |
switch (block_size_best) {
|
| 114 |
case 32: {
|
| 115 |
+
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
| 116 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 117 |
} break;
|
| 118 |
case 64: {
|
| 119 |
+
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
| 120 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 121 |
} break;
|
| 122 |
case 96: {
|
| 123 |
+
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
| 124 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 125 |
} break;
|
| 126 |
case 128: {
|
| 127 |
+
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
| 128 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 129 |
} break;
|
| 130 |
case 160: {
|
| 131 |
+
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
| 132 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 133 |
} break;
|
| 134 |
case 192: {
|
| 135 |
+
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
| 136 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 137 |
} break;
|
| 138 |
case 224: {
|
| 139 |
+
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
| 140 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 141 |
} break;
|
| 142 |
case 256: {
|
| 143 |
+
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
| 144 |
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
|
| 145 |
} break;
|
| 146 |
default: {
|
|
|
|
| 149 |
}
|
| 150 |
}
|
| 151 |
|
| 152 |
+
template<typename T>
|
| 153 |
static void mul_mat_vec_cuda(
|
| 154 |
+
const T * x, const float * y, float * dst,
|
| 155 |
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
|
| 156 |
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
| 157 |
enum ggml_prec prec, cudaStream_t stream) {
|
| 158 |
switch (prec) {
|
| 159 |
case GGML_PREC_DEFAULT: {
|
| 160 |
+
launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
| 161 |
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
| 162 |
} break;
|
| 163 |
case GGML_PREC_F32: {
|
| 164 |
+
launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
|
| 165 |
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
|
| 166 |
} break;
|
| 167 |
}
|
| 168 |
}
|
| 169 |
|
| 170 |
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
|
|
|
| 171 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 172 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 173 |
|
|
|
|
| 179 |
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
| 180 |
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
| 181 |
|
|
|
|
| 182 |
const float * src1_d = (const float *) src1->data;
|
| 183 |
float * dst_d = (float *) dst->data;
|
| 184 |
|
|
|
|
| 195 |
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
|
| 196 |
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
|
| 197 |
|
| 198 |
+
switch (src0->type) {
|
| 199 |
+
case GGML_TYPE_F16: {
|
| 200 |
+
const half * src0_d = (const half *) src0->data;
|
| 201 |
+
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
|
| 202 |
+
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
| 203 |
+
} break;
|
| 204 |
+
case GGML_TYPE_BF16: {
|
| 205 |
+
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
| 206 |
+
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
|
| 207 |
+
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
|
| 208 |
+
} break;
|
| 209 |
+
default:
|
| 210 |
+
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
| 211 |
+
}
|
| 212 |
}
|
| 213 |
|
| 214 |
void ggml_cuda_op_mul_mat_vec(
|
|
|
|
| 217 |
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
| 218 |
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
| 219 |
|
|
|
|
| 220 |
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
| 221 |
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
| 222 |
|
|
|
|
| 237 |
const int64_t channel_stride_y = 0;
|
| 238 |
const int64_t channel_stride_dst = 0;
|
| 239 |
|
| 240 |
+
switch (src0->type) {
|
| 241 |
+
case GGML_TYPE_F16: {
|
| 242 |
+
const half * src0_d = (const half *) src0_dd_i;
|
| 243 |
+
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
| 244 |
+
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
| 245 |
+
} break;
|
| 246 |
+
case GGML_TYPE_BF16: {
|
| 247 |
+
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
| 248 |
+
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
|
| 249 |
+
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
|
| 250 |
+
} break;
|
| 251 |
+
default:
|
| 252 |
+
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
| 253 |
+
}
|
| 254 |
|
| 255 |
GGML_UNUSED(ctx);
|
| 256 |
GGML_UNUSED(src1);
|
ggml/src/ggml-cuda/vendors/cuda.h
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
#include <cuda_runtime.h>
|
| 4 |
#include <cuda.h>
|
| 5 |
#include <cublas_v2.h>
|
|
|
|
| 6 |
#include <cuda_fp16.h>
|
| 7 |
|
| 8 |
#if CUDART_VERSION < 11020
|
|
|
|
| 3 |
#include <cuda_runtime.h>
|
| 4 |
#include <cuda.h>
|
| 5 |
#include <cublas_v2.h>
|
| 6 |
+
#include <cuda_bf16.h>
|
| 7 |
#include <cuda_fp16.h>
|
| 8 |
|
| 9 |
#if CUDART_VERSION < 11020
|
ggml/src/ggml-cuda/vendors/hip.h
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
#include <hip/hip_runtime.h>
|
| 4 |
#include <hipblas/hipblas.h>
|
| 5 |
#include <hip/hip_fp16.h>
|
|
|
|
| 6 |
#ifdef __HIP_PLATFORM_AMD__
|
| 7 |
// for rocblas_initialize()
|
| 8 |
#include "rocblas/rocblas.h"
|
|
@@ -121,6 +122,8 @@
|
|
| 121 |
#define __has_builtin(x) 0
|
| 122 |
#endif
|
| 123 |
|
|
|
|
|
|
|
| 124 |
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
| 125 |
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 126 |
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
|
|
|
| 3 |
#include <hip/hip_runtime.h>
|
| 4 |
#include <hipblas/hipblas.h>
|
| 5 |
#include <hip/hip_fp16.h>
|
| 6 |
+
#include <hip/hip_bfloat16.h>
|
| 7 |
#ifdef __HIP_PLATFORM_AMD__
|
| 8 |
// for rocblas_initialize()
|
| 9 |
#include "rocblas/rocblas.h"
|
|
|
|
| 122 |
#define __has_builtin(x) 0
|
| 123 |
#endif
|
| 124 |
|
| 125 |
+
typedef hip_bfloat16 nv_bfloat16;
|
| 126 |
+
|
| 127 |
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
| 128 |
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
| 129 |
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
ggml/src/ggml-cuda/vendors/musa.h
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
#include <musa_runtime.h>
|
| 4 |
#include <musa.h>
|
| 5 |
#include <mublas.h>
|
|
|
|
| 6 |
#include <musa_fp16.h>
|
| 7 |
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
| 8 |
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
|
@@ -132,3 +133,5 @@
|
|
| 132 |
#define cudaKernelNodeParams musaKernelNodeParams
|
| 133 |
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 134 |
#define cudaStreamEndCapture musaStreamEndCapture
|
|
|
|
|
|
|
|
|
| 3 |
#include <musa_runtime.h>
|
| 4 |
#include <musa.h>
|
| 5 |
#include <mublas.h>
|
| 6 |
+
#include <musa_bf16.h>
|
| 7 |
#include <musa_fp16.h>
|
| 8 |
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
| 9 |
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
|
|
|
| 133 |
#define cudaKernelNodeParams musaKernelNodeParams
|
| 134 |
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
| 135 |
#define cudaStreamEndCapture musaStreamEndCapture
|
| 136 |
+
|
| 137 |
+
typedef mt_bfloat16 nv_bfloat16;
|