Spaces:
Sleeping
Sleeping
CUDA: add dynamic shared mem to softmax, refactor general usage (llama/14497)
Browse files- ggml/src/ggml-cuda/common.cuh +14 -0
- ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- ggml/src/ggml-cuda/mmq.cuh +2 -8
- ggml/src/ggml-cuda/softmax.cu +38 -40
ggml/src/ggml-cuda/common.cuh
CHANGED
|
@@ -175,6 +175,20 @@ static const char * cu_get_error_str(CUresult err) {
|
|
| 175 |
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
| 176 |
#endif
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
| 179 |
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
| 180 |
#else
|
|
|
|
| 175 |
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
| 176 |
#endif
|
| 177 |
|
| 178 |
+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 179 |
+
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
| 180 |
+
do { \
|
| 181 |
+
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
|
| 182 |
+
const int id = ggml_cuda_get_device(); \
|
| 183 |
+
if (!shared_memory_limit_raised[id]) { \
|
| 184 |
+
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
| 185 |
+
shared_memory_limit_raised[id] = true; \
|
| 186 |
+
} \
|
| 187 |
+
} while (0)
|
| 188 |
+
#else
|
| 189 |
+
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
|
| 190 |
+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 191 |
+
|
| 192 |
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
| 193 |
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
| 194 |
#else
|
ggml/src/ggml-cuda/cross-entropy-loss.cu
CHANGED
|
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 123 |
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
| 124 |
|
| 125 |
if (nbytes_shared <= smpbo) {
|
| 126 |
-
|
| 127 |
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
| 128 |
-
if (!shared_memory_limit_raised[id]) {
|
| 129 |
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
| 130 |
-
shared_memory_limit_raised[id] = true;
|
| 131 |
-
}
|
| 132 |
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 133 |
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
| 134 |
} else {
|
| 135 |
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
|
|
| 175 |
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
| 176 |
|
| 177 |
if (nbytes_shared <= smpbo) {
|
| 178 |
-
|
| 179 |
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
| 180 |
-
if (!shared_memory_limit_raised[id]) {
|
| 181 |
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
| 182 |
-
shared_memory_limit_raised[id] = true;
|
| 183 |
-
}
|
| 184 |
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 185 |
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
| 186 |
} else {
|
| 187 |
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
|
|
|
| 123 |
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
| 124 |
|
| 125 |
if (nbytes_shared <= smpbo) {
|
| 126 |
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
| 128 |
} else {
|
| 129 |
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
|
|
| 169 |
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
| 170 |
|
| 171 |
if (nbytes_shared <= smpbo) {
|
| 172 |
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
| 174 |
} else {
|
| 175 |
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
ggml/src/ggml-cuda/mmq.cuh
CHANGED
|
@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|
| 3016 |
|
| 3017 |
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
|
| 3018 |
|
| 3019 |
-
|
| 3020 |
-
|
| 3021 |
-
if (!shared_memory_limit_raised[id]) {
|
| 3022 |
-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
|
| 3023 |
-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
|
| 3024 |
-
shared_memory_limit_raised[id] = true;
|
| 3025 |
-
}
|
| 3026 |
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
| 3027 |
|
| 3028 |
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
|
| 3029 |
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
|
|
|
|
| 3016 |
|
| 3017 |
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
|
| 3018 |
|
| 3019 |
+
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
|
| 3020 |
+
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3021 |
|
| 3022 |
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
|
| 3023 |
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
|
ggml/src/ggml-cuda/softmax.cu
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
#include "ggml.h"
|
| 3 |
#include "softmax.cuh"
|
| 4 |
#include <cstdint>
|
|
|
|
| 5 |
|
| 6 |
template <typename T>
|
| 7 |
static __device__ __forceinline__ float t2f32(T val) {
|
|
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
|
|
| 181 |
}
|
| 182 |
}
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
template<typename T>
|
| 185 |
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
| 186 |
int nth = WARP_SIZE;
|
|
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
|
| 193 |
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
| 194 |
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
break;
|
| 203 |
-
case 64:
|
| 204 |
-
soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 205 |
-
(x, mask, dst, params);
|
| 206 |
-
break;
|
| 207 |
-
case 128:
|
| 208 |
-
soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 209 |
-
(x, mask, dst, params);
|
| 210 |
-
break;
|
| 211 |
-
case 256:
|
| 212 |
-
soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 213 |
-
(x, mask, dst, params);
|
| 214 |
-
break;
|
| 215 |
-
case 512:
|
| 216 |
-
soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 217 |
-
(x, mask, dst, params);
|
| 218 |
-
break;
|
| 219 |
-
case 1024:
|
| 220 |
-
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 221 |
-
(x, mask, dst, params);
|
| 222 |
-
break;
|
| 223 |
-
case 2048:
|
| 224 |
-
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 225 |
-
(x, mask, dst, params);
|
| 226 |
-
break;
|
| 227 |
-
case 4096:
|
| 228 |
-
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 229 |
-
(x, mask, dst, params);
|
| 230 |
-
break;
|
| 231 |
-
default:
|
| 232 |
-
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 233 |
-
(x, mask, dst, params);
|
| 234 |
-
break;
|
| 235 |
-
}
|
| 236 |
} else {
|
| 237 |
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
|
| 238 |
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
|
|
|
|
| 2 |
#include "ggml.h"
|
| 3 |
#include "softmax.cuh"
|
| 4 |
#include <cstdint>
|
| 5 |
+
#include <utility>
|
| 6 |
|
| 7 |
template <typename T>
|
| 8 |
static __device__ __forceinline__ float t2f32(T val) {
|
|
|
|
| 182 |
}
|
| 183 |
}
|
| 184 |
|
| 185 |
+
template<int... Ns, typename T>
|
| 186 |
+
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
|
| 187 |
+
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
|
| 188 |
+
{
|
| 189 |
+
const int id = ggml_cuda_get_device();
|
| 190 |
+
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
| 191 |
+
|
| 192 |
+
auto launch_kernel = [=](auto I) -> bool {
|
| 193 |
+
constexpr int ncols = decltype(I)::value;
|
| 194 |
+
constexpr int block = (ncols > 1024 ? 1024 : ncols);
|
| 195 |
+
|
| 196 |
+
if (p.ncols == ncols) {
|
| 197 |
+
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
|
| 198 |
+
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 199 |
+
(x, mask, dst, p);
|
| 200 |
+
return true;
|
| 201 |
+
}
|
| 202 |
+
return false;
|
| 203 |
+
};
|
| 204 |
+
|
| 205 |
+
// unary fold over launch_kernel
|
| 206 |
+
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
|
| 207 |
+
return;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
//default case
|
| 211 |
+
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
|
| 212 |
+
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
template<typename T>
|
| 217 |
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
| 218 |
int nth = WARP_SIZE;
|
|
|
|
| 225 |
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
| 226 |
|
| 227 |
|
| 228 |
+
const int id = ggml_cuda_get_device();
|
| 229 |
+
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
if (nbytes_shared <= smpbo) {
|
| 233 |
+
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
} else {
|
| 235 |
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
|
| 236 |
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
|