JohannesGaessler committed

Commit 961ef57 · 1 Parent(s): 4d90c3d

CUDA: add BF16 support (llama/11093)

ggml/src/ggml-cuda/convert.cu CHANGED

@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }
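Note: the new case routes BF16 tensors through the same unary-conversion path already used for F16. For context, a minimal sketch of what a convert_unary-style kernel looks like (a hypothetical illustration, not the actual convert.cu implementation, which also handles launch configuration and non-contiguous layouts):

#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Sketch: widen each element of a contiguous array of T (half or nv_bfloat16)
// to float. float(x[i]) uses the conversion operators that cuda_fp16.h and
// cuda_bf16.h provide, which is why the same template serves both types.
template <typename T>
static __global__ void convert_unary_sketch(const T * __restrict__ x, float * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    y[i] = float(x[i]);
}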
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED

@@ -1728,7 +1728,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);

-    bool use_mul_mat_vec   = src0->type == GGML_TYPE_F16
+    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)

@@ -2869,6 +2869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_BF16:
 #ifdef GGML_USE_MUSA
             if (a->type == GGML_TYPE_Q3_K) {
                 return false;
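Note: the widened predicate means the custom mat-vec kernel now covers BF16 weights whenever the activation is a single F32 column and the row size is even (the kernel processes columns in pairs, hence ne[0] % 2 == 0); the supports_op change advertises this to the backend scheduler. A hedged usage sketch of a graph that now stays on the CUDA backend without first converting the weights (ctx and the shapes are illustrative):

// Sketch only: BF16 weight matrix times an F32 vector hits the new path.
struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_BF16, 4096, 4096); // BF16 weights
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 1);    // ne[1] == 1 -> mat-vec path
struct ggml_tensor * y = ggml_mul_mat(ctx, w, x);                             // F32 result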
ggml/src/ggml-cuda/mmv.cu CHANGED

@@ -1,9 +1,9 @@
 #include "common.cuh"
 #include "mmv.cuh"

-template <typename type_acc, int block_size>
+template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
-        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
     const int64_t row     = blockIdx.x;
     const int64_t channel = blockIdx.z;
@@ -13,7 +13,6 @@ static __global__ void mul_mat_vec(
     y   += channel *stride_channel_y;
     dst += channel *stride_channel_dst;

-    const half2  * x2 = (const half2  *) x;
     const float2 * y2 = (const float2 *) y;

     extern __shared__ char data_mmv[];
@@ -28,28 +27,44 @@ static __global__ void mul_mat_vec(

     float sumf;

-    if (std::is_same<type_acc, float>::value) {
-        sumf = 0.0f;
-
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmpx = __half22float2(x2[col2]);
-            const float2 tmpy = y2[col2];
-            sumf += tmpx.x * tmpy.x;
-            sumf += tmpx.y * tmpy.y;
-        }
-    } else {
+    if constexpr (std::is_same<T, half>::value) {
+        const half2 * x2 = (const half2 *) x;
+
+        if (std::is_same<type_acc, float>::value) {
+            sumf = 0.0f;
+
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+                const float2 tmpy = y2[col2];
+                sumf += tmpx.x * tmpy.x;
+                sumf += tmpx.y * tmpy.y;
+            }
+        } else {
 #ifdef FP16_AVAILABLE
-        half2 sumh2 = make_half2(0.0f, 0.0f);
-
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmp = y2[col2];
-            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
-        }
-
-        sumf = __low2float(sumh2) + __high2float(sumh2);
+            half2 sumh2 = make_half2(0.0f, 0.0f);
+
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmp = y2[col2];
+                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+            }
+
+            sumf = __low2float(sumh2) + __high2float(sumh2);
 #else
-        NO_DEVICE_CODE;
+            NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+        const int * x2 = (const int *) x;
+        sumf = 0.0f;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int    tmpx = x2[col2];
+            const float2 tmpy = y2[col2];
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+        }
+    } else {
+        static_assert(std::is_same<T, void>::value, "unsupported type");
     }

     sumf = warp_reduce_sum(sumf);
@@ -71,9 +86,9 @@ static __global__ void mul_mat_vec(
     dst[row] = sumf;
 }

-template <typename type_acc>
+template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         cudaStream_t stream) {
@@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case  32: {
-            mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  64: {
-            mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  96: {
-            mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case 128: {
-            mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case 160: {
-            mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case 192: {
-            mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case 224: {
-            mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case 256: {
-            mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         default: {
@@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda(
     }
 }

+template<typename T>
 static void mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
     }
 }

 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);

@@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;

-    const half  * src0_d = (const half  *) src0->data;
     const float * src1_d = (const float *) src1->data;
     float       * dst_d  = (float       *) dst->data;

@@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
     const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);

-    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
 }

 void ggml_cuda_op_mul_mat_vec(
@@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec(
         const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
         const int64_t src1_padded_row_size, cudaStream_t stream) {

-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);

@@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t channel_stride_y   = 0;
     const int64_t channel_stride_dst = 0;

-    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }

     GGML_UNUSED(ctx);
     GGML_UNUSED(src1);
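Note: two details of the new BF16 branch are worth calling out. The x pointer is read through const int * so each thread fetches a pair of bf16 values with a single 32-bit load, matching the access pattern of the half2 path; and the branch accumulates in FP32 for both precision settings, unlike the F16 path, which keeps a packed half2 accumulator under GGML_PREC_DEFAULT. The per-element widening is cheap because bf16 is the high half of an IEEE-754 float. A hedged sketch of the equivalent bit manipulation (the committed code instead relies on nv_bfloat16's float conversion operator):

#include <cstdint>

// Sketch: widening bf16 -> fp32 is a 16-bit shift, because bf16 keeps the
// sign, exponent, and top 7 mantissa bits of an IEEE-754 float.
static __device__ __forceinline__ float bf16_bits_to_float(const uint16_t bits) {
    return __uint_as_float((uint32_t) bits << 16); // bf16 bits become the high half
}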
ggml/src/ggml-cuda/vendors/cuda.h CHANGED

@@ -3,6 +3,7 @@
 #include <cuda_runtime.h>
 #include <cuda.h>
 #include <cublas_v2.h>
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>

 #if CUDART_VERSION < 11020
ggml/src/ggml-cuda/vendors/hip.h CHANGED

@@ -3,6 +3,7 @@
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"

@@ -121,6 +122,8 @@
 #define __has_builtin(x) 0
 #endif

+typedef hip_bfloat16 nv_bfloat16;
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
ggml/src/ggml-cuda/vendors/musa.h CHANGED

@@ -3,6 +3,7 @@
 #include <musa_runtime.h>
 #include <musa.h>
 #include <mublas.h>
+#include <musa_bf16.h>
 #include <musa_fp16.h>
 #define CUBLAS_COMPUTE_16F CUDA_R_16F
 #define CUBLAS_COMPUTE_32F CUDA_R_32F

@@ -132,3 +133,5 @@
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
 #define cudaStreamEndCapture musaStreamEndCapture
+
+typedef mt_bfloat16 nv_bfloat16;
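Note: the HIP and MUSA headers close the loop by aliasing their native bf16 types to the nv_bfloat16 name the CUDA code is written against, so mmv.cu and convert.cu compile unchanged on all three backends. A hedged illustration of the overall pattern (the guard macros here are illustrative; in the real tree the per-vendor header is selected by the build system, not by an #if chain like this):

// Sketch of the vendor-alias pattern; macro names are hypothetical.
#if defined(USE_HIP_BACKEND)
    #include <hip/hip_bfloat16.h>
    typedef hip_bfloat16 nv_bfloat16;   // AMD
#elif defined(USE_MUSA_BACKEND)
    #include <musa_bf16.h>
    typedef mt_bfloat16 nv_bfloat16;    // Moore Threads
#else
    #include <cuda_bf16.h>              // NVIDIA: nv_bfloat16 is native
#endif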