Commit 1d24833 · 1 parent: 67ec576
CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (llama/15131)
Files changed:
- ggml/src/ggml-cuda/common.cuh +10 -2
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +6 -6
- ggml/src/ggml-cuda/fattn.cu +2 -2
- ggml/src/ggml-cuda/ggml-cuda.cu +20 -11
- ggml/src/ggml-cuda/mma.cuh +88 -22
- ggml/src/ggml-cuda/mmf.cu +431 -0
- ggml/src/ggml-cuda/mmf.cuh +5 -0
- ggml/src/ggml-cuda/mmq.cu +1 -1
- ggml/src/ggml-cuda/mmq.cuh +132 -132
- ggml/src/ggml-cuda/mmvf.cu +510 -0
- ggml/src/ggml-cuda/mmvf.cuh +11 -0
- ggml/src/ggml-cuda/vendors/hip.h +1 -0
- ggml/src/ggml-cuda/vendors/musa.h +2 -1
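
Reading the new mmf.cu kernel further below, the ne11 <= 16 cutoff in the title lines up with the tile decomposition the kernel uses (an interpretation of the new code, not part of the commit message):

    rows_per_block = MMF_ROWS_PER_BLOCK = 32, tile_A = 16x8  ->  ntA = 32 / 16 = 2 row tiles per block
    cols_per_block <= 16,                     tile_B =  8x8  ->  ntB = ceil(16 / 8) = 2 column tiles per block

Larger src1 batches (ne11 > 16) keep going through the existing MMQ / cuBLAS paths.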
ggml/src/ggml-cuda/common.cuh
CHANGED

@@ -233,9 +233,13 @@ typedef float2 dfloat2;
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define
+#define TURING_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define AMPERE_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

@@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) {
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static bool
+static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool ampere_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static bool cp_async_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
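
The new TURING_MMA_AVAILABLE / AMPERE_MMA_AVAILABLE pair follows the same convention as CP_ASYNC_AVAILABLE above: a compile-time guard for device code plus a cc-based helper for host-side dispatch. A minimal sketch of the intended usage pattern, with a hypothetical kernel name (not code from this commit):

// Hypothetical illustration only; GGML_UNUSED and NO_DEVICE_CODE are the existing ggml helpers.
static __global__ void example_ampere_kernel(float * dst) {
#ifdef AMPERE_MMA_AVAILABLE
    dst[threadIdx.x] = 0.0f; // a real kernel would issue tf32/bf16 mma instructions here
#else
    GGML_UNUSED(dst);
    NO_DEVICE_CODE; // compiled-out stub for architectures below Ampere
#endif // AMPERE_MMA_AVAILABLE
}

// Host side: gate the launch on the runtime compute capability.
// if (ampere_mma_available(cc)) { example_ampere_kernel<<<grid, block, 0, stream>>>(dst_d); }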
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED

@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float * const __restrict__ KQ_max,
         float * const __restrict__ KQ_rowsum,
         const int kb0) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
 #ifdef CP_ASYNC_AVAILABLE

@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>

@@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     typedef fattn_mma_f16_config<DKQ, DV> c;

@@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>

@@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(
+#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {

@@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
ggml/src/ggml-cuda/fattn.cu
CHANGED

@@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 =
+    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {

@@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !
+    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
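
To make the renamed batch-size-1 heuristic concrete, here is one illustrative evaluation (example numbers, not taken from the commit): an Ada Lovelace GPU, F16 K/V, GQA ratio 8, Q->ne[3] == 1 and K->ne[1] == 4096 gives

    mma_faster_for_rtx4000 = (1 > 1) || (8*K->ne[2] > 4*K->ne[2] && 4096 >= 8192)  = false
    mma_faster_for_bs1     = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion
                             && (cc < GGML_CUDA_CC_ADA_LOVELACE || false)          = false

so a single-token (Q->ne[1] == 1) decode still takes the vector kernel on that configuration.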
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED

@@ -22,8 +22,9 @@
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"

@@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
-    bool
+    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    bool use_mul_mat_f = !ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32

@@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             }
 
             const int cc = ggml_cuda_info().devices[id].cc;
+            const int warp_size = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-
+            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
+        const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-
+        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }

@@ -2048,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
     //TODO update for generic tensor parallelism
-    const int cc
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
 
-    if (!split &&
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {

@@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst,
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {

@@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         if (ggml_is_quantized(src0->type)) {
             ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
         } else {
-
+            ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
         }
         return;
     }

@@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
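
Taken together, the single-GPU dispatch in ggml_cuda_mul_mat now tries the floating-point kernels before the quantized ones. A simplified reading of the chain above (a sketch, not the literal source):

// Sketch of the non-split dispatch order after this change:
if        (use_mul_mat_vec_f) {        // F32/F16/BF16 src0, very thin src1
    ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
} else if (use_mul_mat_f) {            // new tensor-core GEMM, ne11 <= 16
    ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
} else if (use_mul_mat_vec_q) {        // quantized src0, thin src1
    ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
} else if (use_mul_mat_q) {            // quantized src0, larger batches (MMQ path)
} else {
    // batched or per-op cuBLAS fallback
}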
ggml/src/ggml-cuda/mma.cuh
CHANGED

@@ -23,13 +23,13 @@
 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
 
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // defined(
+#endif // defined(TURING_MMA_AVAILABLE)
     return ret;
 }
 

@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162> {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;

@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"

@@ -217,13 +249,13 @@ namespace ggml_cuda_mma {
            : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"

@@ -232,13 +264,13 @@ namespace ggml_cuda_mma {
 #else
         load_generic(xs0, stride);
         GGML_UNUSED(t);
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(
+#if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"

@@ -246,13 +278,13 @@ namespace ggml_cuda_mma {
            : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"

@@ -263,12 +295,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(xs0);
         GGML_UNUSED(stride);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])

@@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])

@@ -317,12 +349,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;

@@ -344,12 +376,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;

@@ -380,12 +412,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;

@@ -407,12 +456,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int * Dxi = (int *) D.x;
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int * Dxi = (int *) D.x;

@@ -443,7 +509,7 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif //
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
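
A minimal sketch of how the new FP32 (tf32) and BF16 overloads are consumed; the helper below is hypothetical and simply mirrors the inner loop of mmf.cu further down:

using namespace ggml_cuda_mma;

// T is float (tf32 path), half2, or nv_bfloat162; the accumulator tile is FP32 in all cases.
template <typename T>
static __device__ __forceinline__ void accumulate_block(
        tile<16, 8, float> & C, const tile<16, 8, T> & A, const tile<8, 8, T> & B) {
    mma(C, A, B); // overload resolution picks the matching PTX mma instruction for the input type
}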
ggml/src/ggml-cuda/mmf.cu
ADDED

@@ -0,0 +1,431 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "mma.cuh"
+#include "mmf.cuh"
+
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T> tile_A;
+    typedef tile< 8, 8, T> tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0 = blockIdx.x * rows_per_block;
+    const int channel_dst = blockIdx.y;
+    const int channel_x = channel_dst / channel_ratio;
+    const int channel_y = channel_dst;
+    const int sample_dst = blockIdx.z;
+    const int sample_x = sample_dst / sample_ratio;
+    const int sample_y = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+#pragma unroll
+        for (int itB = 0; itB < ntB; ++itB) {
+            if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+                }
+            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+            } else {
+                static_assert(std::is_same_v<T, void>, "unsupported type");
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                tile_B B;
+                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                for (int itA = 0; itA < ntA; ++itA) {
+                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                }
+            }
+        }
+    }
+
+    float * buf_iw = (float *) data_mmv;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+        dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+    }
+#else
+    NO_DEVICE_CODE;
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
+    GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
+    GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
+    GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+template <typename T, int cols_per_block>
+static void mul_mat_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    typedef tile<16, 8, T> tile_A;
+    typedef tile< 8, 8, T> tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    GGML_ASSERT(!ids && "mul_mat_id not implemented");
+
+    GGML_ASSERT(ncols_x % 2 == 0);
+    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio = nsamples_dst / nsamples_x;
+
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t nwarps_best = 1;
+    int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+    int64_t max_block_size = 256;
+    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+        if (niter < niter_best) {
+            niter_best = niter;
+            nwarps_best = nwarps;
+        }
+    }
+
+    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+    const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+    const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(warp_size, nwarps_best, 1);
+    switch (nwarps_best) {
+        case 1: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 2: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 3: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 4: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 5: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 6: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 7: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 8: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1: {
+            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 2: {
+            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 3: {
+            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 4: {
+            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 5: {
+            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 6: {
+            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 7: {
+            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 8: {
+            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 9: {
+            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 10: {
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 11: {
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 12: {
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 13: {
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 14: {
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 15: {
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 16: {
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT( src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT( nb00 == ts_src0);
+    GGML_ASSERT( nb10 == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT( nb0 == ts_dst);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float * src1_d = (const float *) src1->data;
+    const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
+    float * dst_d = (float *) dst->data;
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1 = dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2 = dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3 = dst->nb[3] / ts_dst;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst = ids ? ne2 : ne1;
+    const int64_t nchannels_y = ids ? ne11 : ne12;
+    const int64_t nchannels_dst = ids ? ne1 : ne2;
+    const int64_t stride_channel_dst = ids ? s1 : s2;
+    const int64_t stride_channel_y = ids ? s11 : s12;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            constexpr int vals_per_T = 1;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+        } break;
+        case GGML_TYPE_F16: {
+            const half2 * src0_d = (const half2 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
+        return false;
+    }
+    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
+        return false;
+    }
+    if (ne11 > 16) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            return ampere_mma_available(cc);
| 424 |
+
case GGML_TYPE_F16:
|
| 425 |
+
return turing_mma_available(cc);
|
| 426 |
+
case GGML_TYPE_BF16:
|
| 427 |
+
return ampere_mma_available(cc);
|
| 428 |
+
default:
|
| 429 |
+
return false;
|
| 430 |
+
}
|
| 431 |
+
}
|
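As a quick editorial sketch (not part of the commit), the gating in ggml_cuda_should_use_mmf above works out to the following alignment rules; the warp size of 32 and the element sizes are the usual NVIDIA/ggml values:

// Sketch only: src0_ne[0] % (warp_size * (4/ggml_type_size(type))) == 0 with
// warp_size == 32 means a weight row must be a multiple of 32 elements for F32
// (4-byte elements) and 64 elements for F16/BF16 (2-byte elements). On top of
// that, src0_ne[1] must be a multiple of MMF_ROWS_PER_BLOCK and the batch size
// ne11 may be at most 16 for this kernel to be eligible.
static_assert(32 * (4 / 4) == 32, "F32: rows must be a multiple of 32 elements");
static_assert(32 * (4 / 2) == 64, "F16/BF16: rows must be a multiple of 64 elements");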
ggml/src/ggml-cuda/mmf.cuh
ADDED
@@ -0,0 +1,5 @@
#include "common.cuh"

void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11);
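For orientation, here is a minimal, hypothetical host-side probe of the entry point declared above; it assumes the file is built inside the ggml CUDA backend, and the shape, cc value and warp size are illustrative assumptions rather than values taken from the commit:

#include <cstdint>
#include <cstdio>
#include "ggml.h"

// Declared in mmf.cuh, implemented in mmf.cu.
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11);

int main() {
    const int64_t src0_ne[4] = {4096, 4096, 1, 1}; // hypothetical 4096 x 4096 F16 weight
    const int cc        = 800;                     // e.g. an Ampere-class NVIDIA GPU
    const int warp_size = 32;
    const bool use_mmf = ggml_cuda_should_use_mmf(GGML_TYPE_F16, cc, warp_size, src0_ne, /*ne11=*/16);
    std::printf("FP16 GEMM with ne11 == 16 would use the new mmf kernel: %s\n", use_mmf ? "yes" : "no");
    return 0;
}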
ggml/src/ggml-cuda/mmq.cu
CHANGED
@@ -310,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc)) {
+    if (turing_mma_available(cc)) {
         return true;
     }
 
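The cc value these gates (and turing_mma_available/ampere_mma_available) compare against is ggml's integer encoding of the device's compute capability. Assuming the usual 100*major + 10*minor convention for NVIDIA devices (Turing = 750, Ampere = 800), it can be derived from the CUDA runtime roughly like this (a sketch, not code from the commit):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
        return 1;
    }
    const int cc = 100*prop.major + 10*prop.minor; // assumed ggml-style encoding
    std::printf("compute capability %d.%d -> cc = %d (Turing-or-newer MMA: %s)\n",
                prop.major, prop.minor, cc, cc >= 750 ? "yes" : "no");
    return 0;
}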
ggml/src/ggml-cuda/mmq.cuh
CHANGED
|
@@ -92,7 +92,7 @@ struct tile_x_sizes {
|
|
| 92 |
};
|
| 93 |
|
| 94 |
static int get_mmq_x_max_host(const int cc) {
|
| 95 |
-
return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
|
| 96 |
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
| 97 |
#ifdef GGML_CUDA_FORCE_MMQ
|
| 98 |
128 : 64;
|
|
@@ -102,9 +102,9 @@ static int get_mmq_x_max_host(const int cc) {
|
|
| 102 |
}
|
| 103 |
|
| 104 |
static constexpr __device__ int get_mmq_x_max_device() {
|
| 105 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
| 106 |
return 128;
|
| 107 |
-
#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
| 108 |
|
| 109 |
#if defined(GGML_USE_HIP)
|
| 110 |
return 64;
|
|
@@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
|
| 121 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
| 122 |
|
| 123 |
#endif // defined(GGML_USE_HIP)
|
| 124 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
| 125 |
}
|
| 126 |
|
| 127 |
static int get_mmq_y_host(const int cc) {
|
|
@@ -233,7 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
| 233 |
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
|
| 234 |
if (amd_mfma_available(cc)) {
|
| 235 |
return mmq_x >= 128 ? 32 : 16;
|
| 236 |
-
} else if (new_mma_available(cc) && mmq_x >= 48) {
|
| 237 |
return 16;
|
| 238 |
} else {
|
| 239 |
return 8;
|
|
@@ -244,7 +244,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) {
|
|
| 244 |
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
|
| 245 |
return mmq_x >= 128 ? 32 : 16;
|
| 246 |
}
|
| 247 |
-
#elif defined(NEW_MMA_AVAILABLE)
|
| 248 |
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
|
| 249 |
return mmq_x >= 48 ? 16 : 8;
|
| 250 |
}
|
|
@@ -279,14 +279,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 279 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 280 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 281 |
|
| 282 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 283 |
int * x_qs = (int *) x_tile;
|
| 284 |
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 285 |
#else
|
| 286 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
|
| 287 |
int * x_qs = (int *) x_tile;
|
| 288 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 289 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 290 |
|
| 291 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
|
| 292 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -305,12 +305,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 305 |
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
|
| 306 |
const int qs0 = get_int_b2(bxi->qs, kqsx);
|
| 307 |
|
| 308 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 309 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
|
| 310 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
|
| 311 |
#else
|
| 312 |
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
| 313 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 314 |
}
|
| 315 |
|
| 316 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
|
|
@@ -327,11 +327,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 327 |
|
| 328 |
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
|
| 329 |
|
| 330 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 331 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 332 |
#else
|
| 333 |
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
|
| 334 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 335 |
}
|
| 336 |
}
|
| 337 |
|
|
@@ -382,14 +382,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 382 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 383 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 384 |
|
| 385 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 386 |
int * x_qs = (int *) x_tile;
|
| 387 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 388 |
#else
|
| 389 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
| 390 |
int * x_qs = (int *) x_tile;
|
| 391 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 392 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 393 |
|
| 394 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
|
| 395 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -408,12 +408,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 408 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
| 409 |
const int qs0 = get_int_b4(bxi->qs, kqsx);
|
| 410 |
|
| 411 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 412 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
| 413 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
|
| 414 |
#else
|
| 415 |
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
| 416 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 417 |
}
|
| 418 |
|
| 419 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
|
|
@@ -430,11 +430,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 430 |
|
| 431 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
|
| 432 |
|
| 433 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 434 |
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
| 435 |
#else
|
| 436 |
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
|
| 437 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 438 |
}
|
| 439 |
}
|
| 440 |
|
|
@@ -485,14 +485,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 485 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 486 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 487 |
|
| 488 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 489 |
int * x_qs = (int *) x_tile;
|
| 490 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 491 |
#else
|
| 492 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
|
| 493 |
int * x_qs = (int *) x_tile;
|
| 494 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 495 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 496 |
|
| 497 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
|
| 498 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -527,13 +527,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 527 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 528 |
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
|
| 529 |
|
| 530 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 531 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
| 532 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
| 533 |
#else
|
| 534 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
| 535 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
| 536 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 537 |
}
|
| 538 |
|
| 539 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
|
|
@@ -550,11 +550,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 550 |
|
| 551 |
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
|
| 552 |
|
| 553 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 554 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 555 |
#else
|
| 556 |
x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
|
| 557 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 558 |
}
|
| 559 |
}
|
| 560 |
|
|
@@ -563,14 +563,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 563 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 564 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 565 |
|
| 566 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 567 |
int * x_qs = (int *) x_tile;
|
| 568 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 569 |
#else
|
| 570 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
| 571 |
int * x_qs = (int *) x_tile;
|
| 572 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 573 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 574 |
|
| 575 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
|
| 576 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -603,13 +603,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 603 |
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
| 604 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 605 |
|
| 606 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 607 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
| 608 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
| 609 |
#else
|
| 610 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
| 611 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
| 612 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 613 |
}
|
| 614 |
|
| 615 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
|
|
@@ -626,11 +626,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 626 |
|
| 627 |
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
|
| 628 |
|
| 629 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 630 |
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
| 631 |
#else
|
| 632 |
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
|
| 633 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 634 |
}
|
| 635 |
}
|
| 636 |
|
|
@@ -639,14 +639,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 639 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 640 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 641 |
|
| 642 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 643 |
int * x_qs = (int *) x_tile;
|
| 644 |
float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
|
| 645 |
#else
|
| 646 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
| 647 |
int * x_qs = (int *) x_tile;
|
| 648 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 649 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 650 |
|
| 651 |
// MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
|
| 652 |
constexpr int threads_per_row = 32;
|
|
@@ -665,13 +665,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 665 |
|
| 666 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
| 667 |
|
| 668 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 669 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
| 670 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
| 671 |
#else
|
| 672 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
| 673 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
| 674 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 675 |
}
|
| 676 |
|
| 677 |
constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
|
|
@@ -688,11 +688,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 688 |
|
| 689 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
|
| 690 |
|
| 691 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 692 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 693 |
#else
|
| 694 |
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
|
| 695 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 696 |
}
|
| 697 |
}
|
| 698 |
|
|
@@ -701,14 +701,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 701 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 702 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 703 |
|
| 704 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 705 |
int * x_qs = (int *) x_tile;
|
| 706 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 707 |
#else
|
| 708 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
|
| 709 |
int * x_qs = (int *) x_tile;
|
| 710 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 711 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 712 |
|
| 713 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
|
| 714 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -730,13 +730,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 730 |
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
| 731 |
const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
|
| 732 |
|
| 733 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 734 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
|
| 735 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
|
| 736 |
#else
|
| 737 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
| 738 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
|
| 739 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 740 |
}
|
| 741 |
|
| 742 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
|
|
@@ -753,11 +753,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 753 |
|
| 754 |
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
|
| 755 |
|
| 756 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 757 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
| 758 |
#else
|
| 759 |
x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
| 760 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 761 |
}
|
| 762 |
}
|
| 763 |
|
|
@@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
| 1178 |
}
|
| 1179 |
}
|
| 1180 |
}
|
| 1181 |
-
#elif defined(NEW_MMA_AVAILABLE)
|
| 1182 |
|
| 1183 |
typedef tile<16, 4, int> tile_A;
|
| 1184 |
typedef tile<16, 8, int> tile_A_8;
|
|
@@ -1264,14 +1264,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1264 |
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1265 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1266 |
|
| 1267 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1268 |
int * x_qs = (int *) x_tile;
|
| 1269 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 1270 |
#else
|
| 1271 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
| 1272 |
int * x_qs = (int *) x_tile;
|
| 1273 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1274 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1275 |
|
| 1276 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
|
| 1277 |
constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
|
|
@@ -1295,11 +1295,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1295 |
|
| 1296 |
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
|
| 1297 |
|
| 1298 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1299 |
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
|
| 1300 |
#else
|
| 1301 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
| 1302 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1303 |
}
|
| 1304 |
|
| 1305 |
const int sc_m = bxi->scales[kqsx];
|
|
@@ -1310,11 +1310,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1310 |
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
| 1311 |
#endif // FAST_FP16_AVAILABLE
|
| 1312 |
|
| 1313 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1314 |
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
|
| 1315 |
#else
|
| 1316 |
x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
|
| 1317 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1318 |
}
|
| 1319 |
}
|
| 1320 |
|
|
@@ -1452,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
| 1452 |
}
|
| 1453 |
}
|
| 1454 |
}
|
| 1455 |
-
#elif defined(NEW_MMA_AVAILABLE)
|
| 1456 |
|
| 1457 |
typedef tile<16, 4, int> tile_A;
|
| 1458 |
typedef tile<16, 8, int> tile_A_8;
|
|
@@ -1582,7 +1582,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1582 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1583 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 1584 |
|
| 1585 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1586 |
int * x_qs = (int *) x_tile;
|
| 1587 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 1588 |
#else
|
|
@@ -1590,7 +1590,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1590 |
int * x_qs = (int *) x_tile;
|
| 1591 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1592 |
int * x_sc = (int *) (x_df + txs.dm);
|
| 1593 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1594 |
|
| 1595 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
|
| 1596 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -1618,11 +1618,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1618 |
|
| 1619 |
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
|
| 1620 |
|
| 1621 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1622 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
|
| 1623 |
#else
|
| 1624 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
| 1625 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1626 |
}
|
| 1627 |
}
|
| 1628 |
|
|
@@ -1649,7 +1649,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1649 |
|
| 1650 |
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
|
| 1651 |
|
| 1652 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1653 |
const int8_t * sc8 = (const int8_t *) ≻
|
| 1654 |
const float d = bxi->d;
|
| 1655 |
|
|
@@ -1659,10 +1659,10 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1659 |
}
|
| 1660 |
#else
|
| 1661 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
|
| 1662 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1663 |
}
|
| 1664 |
|
| 1665 |
-
#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
|
| 1666 |
#pragma unroll
|
| 1667 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
|
| 1668 |
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
|
|
@@ -1675,7 +1675,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1675 |
|
| 1676 |
x_df[i] = bxi->d;
|
| 1677 |
}
|
| 1678 |
-
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
|
| 1679 |
}
|
| 1680 |
|
| 1681 |
template <int mmq_x, int mmq_y>
|
|
@@ -1728,7 +1728,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1728 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1729 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 1730 |
|
| 1731 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1732 |
int * x_qs = (int *) x_tile;
|
| 1733 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 1734 |
#else
|
|
@@ -1736,7 +1736,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1736 |
int * x_qs = (int *) x_tile;
|
| 1737 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1738 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1739 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1740 |
|
| 1741 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
|
| 1742 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -1753,15 +1753,15 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1753 |
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
|
| 1754 |
const int qs0 = get_int_b4(bxi->qs, txi);
|
| 1755 |
|
| 1756 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1757 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
| 1758 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
|
| 1759 |
#else
|
| 1760 |
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
| 1761 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1762 |
}
|
| 1763 |
|
| 1764 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1765 |
constexpr int rows_per_warp = warp_size / 2;
|
| 1766 |
#pragma unroll
|
| 1767 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
@@ -1829,7 +1829,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1829 |
|
| 1830 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
| 1831 |
}
|
| 1832 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1833 |
}
|
| 1834 |
|
| 1835 |
template <int mmq_x, int mmq_y>
|
|
@@ -1872,7 +1872,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1872 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1873 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 1874 |
|
| 1875 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1876 |
int * x_qs = (int *) x_tile;
|
| 1877 |
half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
|
| 1878 |
#else
|
|
@@ -1880,7 +1880,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1880 |
int * x_qs = (int *) x_tile;
|
| 1881 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1882 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1883 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1884 |
|
| 1885 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
|
| 1886 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -1908,16 +1908,16 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1908 |
const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
|
| 1909 |
const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
|
| 1910 |
|
| 1911 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1912 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
|
| 1913 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
|
| 1914 |
#else
|
| 1915 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
|
| 1916 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
|
| 1917 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1918 |
}
|
| 1919 |
|
| 1920 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1921 |
constexpr int rows_per_warp = warp_size / 2;
|
| 1922 |
#pragma unroll
|
| 1923 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
@@ -1986,7 +1986,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 1986 |
|
| 1987 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
| 1988 |
}
|
| 1989 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 1990 |
}
|
| 1991 |
|
| 1992 |
template <int mmq_x, int mmq_y>
|
|
@@ -2029,7 +2029,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2029 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2030 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2031 |
|
| 2032 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2033 |
int * x_qs = (int *) x_tile;
|
| 2034 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2035 |
int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
|
|
@@ -2038,7 +2038,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2038 |
int * x_qs = (int *) x_tile;
|
| 2039 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2040 |
int * x_sc = (int *) (x_df + txs.dm);
|
| 2041 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2042 |
|
| 2043 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
|
| 2044 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2065,13 +2065,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2065 |
const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
|
| 2066 |
const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
|
| 2067 |
|
| 2068 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2069 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 2070 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 2071 |
#else
|
| 2072 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 2073 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 2074 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2075 |
}
|
| 2076 |
|
| 2077 |
#pragma unroll
|
|
@@ -2084,11 +2084,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2084 |
|
| 2085 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
|
| 2086 |
|
| 2087 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2088 |
x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
|
| 2089 |
#else
|
| 2090 |
x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
|
| 2091 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2092 |
}
|
| 2093 |
|
| 2094 |
constexpr int rows_per_warp = warp_size / 4;
|
|
@@ -2102,11 +2102,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2102 |
|
| 2103 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
|
| 2104 |
|
| 2105 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2106 |
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
|
| 2107 |
#else
|
| 2108 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
|
| 2109 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2110 |
}
|
| 2111 |
}
|
| 2112 |
|
|
@@ -2199,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 2199 |
}
|
| 2200 |
}
|
| 2201 |
}
|
| 2202 |
-
#elif defined(NEW_MMA_AVAILABLE)
|
| 2203 |
|
| 2204 |
typedef tile<16, 4, int> tile_A;
|
| 2205 |
typedef tile< 8, 4, int> tile_B;
|
|
@@ -2311,14 +2311,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2311 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2312 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2313 |
|
| 2314 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2315 |
int * x_qs = (int *) x_tile;
|
| 2316 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2317 |
#else
|
| 2318 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
|
| 2319 |
int * x_qs = (int *) x_tile;
|
| 2320 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2321 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2322 |
|
| 2323 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
|
| 2324 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2340,13 +2340,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2340 |
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
| 2341 |
const int k0 = kbx * (2 * QI4_NL) + kqsx;
|
| 2342 |
|
| 2343 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2344 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
| 2345 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
|
| 2346 |
#else
|
| 2347 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
| 2348 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
|
| 2349 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2350 |
}
|
| 2351 |
|
| 2352 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
|
|
@@ -2363,11 +2363,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2363 |
|
| 2364 |
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
|
| 2365 |
|
| 2366 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2367 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
|
| 2368 |
#else
|
| 2369 |
x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
|
| 2370 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2371 |
}
|
| 2372 |
}
|
| 2373 |
|
|
@@ -2376,14 +2376,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2376 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2377 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2378 |
|
| 2379 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2380 |
int * x_qs = (int *) x_tile;
|
| 2381 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2382 |
#else
|
| 2383 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
|
| 2384 |
int * x_qs = (int *) x_tile;
|
| 2385 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2386 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2387 |
|
| 2388 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
|
| 2389 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2414,22 +2414,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2414 |
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
| 2415 |
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
| 2416 |
|
| 2417 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2418 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
| 2419 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
|
| 2420 |
#else
|
| 2421 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
|
| 2422 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
|
| 2423 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2424 |
}
|
| 2425 |
|
| 2426 |
const int ls = aux32 >> 28;
|
| 2427 |
const float d = bxi->d;
|
| 2428 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2429 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
|
| 2430 |
#else
|
| 2431 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
|
| 2432 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2433 |
}
|
| 2434 |
}
|
| 2435 |
|
|
@@ -2438,14 +2438,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2438 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2439 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2440 |
|
| 2441 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2442 |
int * x_qs = (int *) x_tile;
|
| 2443 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2444 |
#else
|
| 2445 |
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
| 2446 |
int * x_qs = (int *) x_tile;
|
| 2447 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2448 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2449 |
|
| 2450 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
|
| 2451 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2472,24 +2472,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2472 |
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
| 2473 |
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
| 2474 |
|
| 2475 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2476 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2477 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2478 |
#else
|
| 2479 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2480 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2481 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2482 |
}
|
| 2483 |
|
| 2484 |
const int ls = bxi->scales[kqsx];
|
| 2485 |
const float d = bxi->d;
|
| 2486 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2487 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2488 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2489 |
#else
|
| 2490 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2491 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2492 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2493 |
}
|
| 2494 |
}
|
| 2495 |
|
|
@@ -2498,14 +2498,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2498 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2499 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2500 |
|
| 2501 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2502 |
int * x_qs = (int *) x_tile;
|
| 2503 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2504 |
#else
|
| 2505 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
|
| 2506 |
int * x_qs = (int *) x_tile;
|
| 2507 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2508 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2509 |
|
| 2510 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
|
| 2511 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2539,24 +2539,24 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2539 |
const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
|
| 2540 |
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
|
| 2541 |
|
| 2542 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2543 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2544 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2545 |
#else
|
| 2546 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2547 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2548 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2549 |
}
|
| 2550 |
|
| 2551 |
const int ls = bxi->scales[kqsx];
|
| 2552 |
const float d = bxi->d;
|
| 2553 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2554 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2555 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2556 |
#else
|
| 2557 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2558 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2559 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2560 |
}
|
| 2561 |
}
|
| 2562 |
|
|
@@ -2565,14 +2565,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2565 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2566 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2567 |
|
| 2568 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2569 |
int * x_qs = (int *) x_tile;
|
| 2570 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2571 |
#else
|
| 2572 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
|
| 2573 |
int * x_qs = (int *) x_tile;
|
| 2574 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2575 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2576 |
|
| 2577 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
|
| 2578 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2601,22 +2601,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2601 |
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
| 2602 |
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
| 2603 |
|
| 2604 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2605 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2606 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2607 |
#else
|
| 2608 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2609 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2610 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2611 |
}
|
| 2612 |
|
| 2613 |
const int ls = aux32 >> 28;
|
| 2614 |
const float d = bxi->d;
|
| 2615 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2616 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
|
| 2617 |
#else
|
| 2618 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
|
| 2619 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2620 |
}
|
| 2621 |
}
|
| 2622 |
|
|
@@ -2625,14 +2625,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2625 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2626 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2627 |
|
| 2628 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2629 |
int * x_qs = (int *) x_tile;
|
| 2630 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2631 |
#else
|
| 2632 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
| 2633 |
int * x_qs = (int *) x_tile;
|
| 2634 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2635 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2636 |
|
| 2637 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
|
| 2638 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2668,22 +2668,22 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2668 |
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
| 2669 |
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
| 2670 |
|
| 2671 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2672 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
|
| 2673 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
|
| 2674 |
#else
|
| 2675 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
|
| 2676 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
|
| 2677 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2678 |
}
|
| 2679 |
|
| 2680 |
const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
|
| 2681 |
const float d = bxi->d;
|
| 2682 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2683 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
|
| 2684 |
#else
|
| 2685 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
|
| 2686 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2687 |
}
|
| 2688 |
}
|
| 2689 |
|
|
@@ -2692,14 +2692,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2692 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2693 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2694 |
|
| 2695 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2696 |
int * x_qs = (int *) x_tile;
|
| 2697 |
half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2698 |
#else
|
| 2699 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
| 2700 |
int * x_qs = (int *) x_tile;
|
| 2701 |
half2 * x_ds = (half2 *) (x_qs + txs.qs);
|
| 2702 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2703 |
|
| 2704 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
|
| 2705 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2727,23 +2727,23 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2727 |
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
|
| 2728 |
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
|
| 2729 |
|
| 2730 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2731 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
|
| 2732 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
|
| 2733 |
#else
|
| 2734 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
|
| 2735 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
|
| 2736 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2737 |
}
|
| 2738 |
|
| 2739 |
const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
|
| 2740 |
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
|
| 2741 |
|
| 2742 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2743 |
x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
|
| 2744 |
#else
|
| 2745 |
x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
|
| 2746 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2747 |
}
|
| 2748 |
}
|
| 2749 |
|
|
@@ -2752,14 +2752,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2752 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2753 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2754 |
|
| 2755 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2756 |
int * x_qs = (int *) x_tile;
|
| 2757 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2758 |
#else
|
| 2759 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
|
| 2760 |
int * x_qs = (int *) x_tile;
|
| 2761 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2762 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2763 |
|
| 2764 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
|
| 2765 |
constexpr int nrows = warp_size / threads_per_row;
|
|
@@ -2779,13 +2779,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2779 |
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
| 2780 |
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
|
| 2781 |
|
| 2782 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2783 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
| 2784 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
| 2785 |
#else
|
| 2786 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
| 2787 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
|
| 2788 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2789 |
}
|
| 2790 |
|
| 2791 |
constexpr int rows_per_warp = warp_size / 8;
|
|
@@ -2804,11 +2804,11 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
|
| 2804 |
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
|
| 2805 |
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
|
| 2806 |
|
| 2807 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2808 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
|
| 2809 |
#else
|
| 2810 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
|
| 2811 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 2812 |
}
|
| 2813 |
}
|
| 2814 |
|
|
@@ -2859,9 +2859,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
| 2859 |
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 2860 |
|
| 2861 |
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
|
| 2862 |
-
#if defined(NEW_MMA_AVAILABLE)
|
| 2863 |
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
|
| 2864 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
| 2865 |
|
| 2866 |
#pragma unroll
|
| 2867 |
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
@@ -3061,13 +3061,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
|
|
| 3061 |
int * tile_y = data_mul_mat_q + mmq_x;
|
| 3062 |
int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
|
| 3063 |
|
| 3064 |
-
#if defined(AMD_MFMA_AVAILABLE) || defined(
|
| 3065 |
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
|
| 3066 |
constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
|
| 3067 |
#else
|
| 3068 |
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
|
| 3069 |
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
|
| 3070 |
-
#endif // defined(AMD_MFMA_AVAILABLE) || defined(
|
| 3071 |
|
| 3072 |
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
| 3073 |
|
|
@@ -3534,7 +3534,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
|
|
| 3534 |
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
| 3535 |
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
| 3536 |
const size_t nbs_ids = mmq_x*sizeof(int);
|
| 3537 |
-
const size_t nbs_x = (
|
| 3538 |
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
| 3539 |
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
| 3540 |
}
|
|
|
|
| 92 |
};
|
| 93 |
|
| 94 |
static int get_mmq_x_max_host(const int cc) {
|
| 95 |
+
return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
|
| 96 |
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
|
| 97 |
#ifdef GGML_CUDA_FORCE_MMQ
|
| 98 |
128 : 64;
|
|
|
|
| 102 |
}
|
| 103 |
|
| 104 |
static constexpr __device__ int get_mmq_x_max_device() {
|
| 105 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 106 |
return 128;
|
| 107 |
+
#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 108 |
|
| 109 |
#if defined(GGML_USE_HIP)
|
| 110 |
return 64;
|
|
|
|
| 121 |
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
| 122 |
|
| 123 |
#endif // defined(GGML_USE_HIP)
|
| 124 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 125 |
}
|
| 126 |
|
| 127 |
static int get_mmq_y_host(const int cc) {
|
|
|
|
| 233 |
static int mmq_get_granularity_host(const int mmq_x, const int cc) {
|
| 234 |
if (amd_mfma_available(cc)) {
|
| 235 |
return mmq_x >= 128 ? 32 : 16;
|
| 236 |
+
} else if (turing_mma_available(cc) && mmq_x >= 48) {
|
| 237 |
return 16;
|
| 238 |
} else {
|
| 239 |
return 8;
|
|
|
|
| 244 |
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
|
| 245 |
return mmq_x >= 128 ? 32 : 16;
|
| 246 |
}
|
| 247 |
+
#elif defined(TURING_MMA_AVAILABLE)
|
| 248 |
static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
|
| 249 |
return mmq_x >= 48 ? 16 : 8;
|
| 250 |
}
|
|
|
|
| 279 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 280 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 281 |
|
| 282 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 283 |
int * x_qs = (int *) x_tile;
|
| 284 |
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 285 |
#else
|
| 286 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
|
| 287 |
int * x_qs = (int *) x_tile;
|
| 288 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 289 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 290 |
|
| 291 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
|
| 292 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 305 |
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
|
| 306 |
const int qs0 = get_int_b2(bxi->qs, kqsx);
|
| 307 |
|
| 308 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 309 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
|
| 310 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
|
| 311 |
#else
|
| 312 |
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
| 313 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 314 |
}
|
| 315 |
|
| 316 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
|
|
|
|
| 327 |
|
| 328 |
const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
|
| 329 |
|
| 330 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 331 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 332 |
#else
|
| 333 |
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
|
| 334 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 335 |
}
|
| 336 |
}
|
| 337 |
|
|
|
|
| 382 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 383 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 384 |
|
| 385 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 386 |
int * x_qs = (int *) x_tile;
|
| 387 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 388 |
#else
|
| 389 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
| 390 |
int * x_qs = (int *) x_tile;
|
| 391 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 392 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 393 |
|
| 394 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
|
| 395 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 408 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
| 409 |
const int qs0 = get_int_b4(bxi->qs, kqsx);
|
| 410 |
|
| 411 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 412 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
| 413 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
|
| 414 |
#else
|
| 415 |
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
| 416 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 417 |
}
|
| 418 |
|
| 419 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
|
|
|
|
| 430 |
|
| 431 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
|
| 432 |
|
| 433 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 434 |
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
| 435 |
#else
|
| 436 |
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
|
| 437 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 438 |
}
|
| 439 |
}
|
| 440 |
|
|
|
|
| 485 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 486 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 487 |
|
| 488 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 489 |
int * x_qs = (int *) x_tile;
|
| 490 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 491 |
#else
|
| 492 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
|
| 493 |
int * x_qs = (int *) x_tile;
|
| 494 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 495 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 496 |
|
| 497 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
|
| 498 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 527 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 528 |
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
|
| 529 |
|
| 530 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 531 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
| 532 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
| 533 |
#else
|
| 534 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
| 535 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
| 536 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 537 |
}
|
| 538 |
|
| 539 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
|
|
|
|
| 550 |
|
| 551 |
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
|
| 552 |
|
| 553 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 554 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 555 |
#else
|
| 556 |
x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
|
| 557 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 558 |
}
|
| 559 |
}
|
| 560 |
|
|
|
|
| 563 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 564 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 565 |
|
| 566 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 567 |
int * x_qs = (int *) x_tile;
|
| 568 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 569 |
#else
|
| 570 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
| 571 |
int * x_qs = (int *) x_tile;
|
| 572 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 573 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 574 |
|
| 575 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
|
| 576 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 603 |
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
| 604 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 605 |
|
| 606 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 607 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
| 608 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
| 609 |
#else
|
| 610 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
| 611 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
| 612 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 613 |
}
|
| 614 |
|
| 615 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
|
|
|
|
| 626 |
|
| 627 |
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
|
| 628 |
|
| 629 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 630 |
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
| 631 |
#else
|
| 632 |
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
|
| 633 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 634 |
}
|
| 635 |
}
|
| 636 |
|
|
|
|
| 639 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 640 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 641 |
|
| 642 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 643 |
int * x_qs = (int *) x_tile;
|
| 644 |
float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
|
| 645 |
#else
|
| 646 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
| 647 |
int * x_qs = (int *) x_tile;
|
| 648 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 649 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 650 |
|
| 651 |
// MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
|
| 652 |
constexpr int threads_per_row = 32;
|
|
|
|
| 665 |
|
| 666 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
| 667 |
|
| 668 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 669 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
| 670 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
| 671 |
#else
|
| 672 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
|
| 673 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
|
| 674 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 675 |
}
|
| 676 |
|
| 677 |
constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
|
|
|
|
| 688 |
|
| 689 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
|
| 690 |
|
| 691 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 692 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 693 |
#else
|
| 694 |
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
|
| 695 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 696 |
}
|
| 697 |
}
|
| 698 |
|
|
|
|
| 701 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 702 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 703 |
|
| 704 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 705 |
int * x_qs = (int *) x_tile;
|
| 706 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 707 |
#else
|
| 708 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
|
| 709 |
int * x_qs = (int *) x_tile;
|
| 710 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 711 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 712 |
|
| 713 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
|
| 714 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 730 |
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
|
| 731 |
const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
|
| 732 |
|
| 733 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 734 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
|
| 735 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
|
| 736 |
#else
|
| 737 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
| 738 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
|
| 739 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 740 |
}
|
| 741 |
|
| 742 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
|
|
|
|
| 753 |
|
| 754 |
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
|
| 755 |
|
| 756 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 757 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
| 758 |
#else
|
| 759 |
x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
|
| 760 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 761 |
}
|
| 762 |
}
|
| 763 |
|
|
|
|
| 1178 |
}
|
| 1179 |
}
|
| 1180 |
}
|
| 1181 |
+
#elif defined(TURING_MMA_AVAILABLE)
|
| 1182 |
|
| 1183 |
typedef tile<16, 4, int> tile_A;
|
| 1184 |
typedef tile<16, 8, int> tile_A_8;
|
|
|
|
| 1264 |
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
| 1265 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1266 |
|
| 1267 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1268 |
int * x_qs = (int *) x_tile;
|
| 1269 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 1270 |
#else
|
| 1271 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
| 1272 |
int * x_qs = (int *) x_tile;
|
| 1273 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1274 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1275 |
|
| 1276 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
|
| 1277 |
constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
|
|
|
|
| 1295 |
|
| 1296 |
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
|
| 1297 |
|
| 1298 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1299 |
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
|
| 1300 |
#else
|
| 1301 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
| 1302 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1303 |
}
|
| 1304 |
|
| 1305 |
const int sc_m = bxi->scales[kqsx];
|
|
|
|
| 1310 |
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
| 1311 |
#endif // FAST_FP16_AVAILABLE
|
| 1312 |
|
| 1313 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1314 |
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
|
| 1315 |
#else
|
| 1316 |
x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
|
| 1317 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1318 |
}
|
| 1319 |
}
|
| 1320 |
|
|
|
|
| 1452 |
}
|
| 1453 |
}
|
| 1454 |
}
|
| 1455 |
+
#elif defined(TURING_MMA_AVAILABLE)
|
| 1456 |
|
| 1457 |
typedef tile<16, 4, int> tile_A;
|
| 1458 |
typedef tile<16, 8, int> tile_A_8;
|
|
|
|
| 1582 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1583 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 1584 |
|
| 1585 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1586 |
int * x_qs = (int *) x_tile;
|
| 1587 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 1588 |
#else
|
|
|
|
| 1590 |
int * x_qs = (int *) x_tile;
|
| 1591 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1592 |
int * x_sc = (int *) (x_df + txs.dm);
|
| 1593 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1594 |
|
| 1595 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
|
| 1596 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 1618 |
|
| 1619 |
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
|
| 1620 |
|
| 1621 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1622 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
|
| 1623 |
#else
|
| 1624 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
| 1625 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1626 |
}
|
| 1627 |
}
|
| 1628 |
|
|
|
|
| 1649 |
|
| 1650 |
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
|
| 1651 |
|
| 1652 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1653 |
const int8_t * sc8 = (const int8_t *) ≻
|
| 1654 |
const float d = bxi->d;
|
| 1655 |
|
|
|
|
| 1659 |
}
|
| 1660 |
#else
|
| 1661 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
|
| 1662 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1663 |
}
|
| 1664 |
|
| 1665 |
+
#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
| 1666 |
#pragma unroll
|
| 1667 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
|
| 1668 |
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
|
|
|
|
| 1675 |
|
| 1676 |
x_df[i] = bxi->d;
|
| 1677 |
}
|
| 1678 |
+
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
|
| 1679 |
}
|
| 1680 |
|
| 1681 |
template <int mmq_x, int mmq_y>
|
|
|
|
| 1728 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1729 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 1730 |
|
| 1731 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1732 |
int * x_qs = (int *) x_tile;
|
| 1733 |
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
| 1734 |
#else
|
|
|
|
| 1736 |
int * x_qs = (int *) x_tile;
|
| 1737 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1738 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1739 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1740 |
|
| 1741 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
|
| 1742 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 1753 |
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
|
| 1754 |
const int qs0 = get_int_b4(bxi->qs, txi);
|
| 1755 |
|
| 1756 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1757 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
| 1758 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
|
| 1759 |
#else
|
| 1760 |
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
| 1761 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1762 |
}
|
| 1763 |
|
| 1764 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1765 |
constexpr int rows_per_warp = warp_size / 2;
|
| 1766 |
#pragma unroll
|
| 1767 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
|
|
| 1829 |
|
| 1830 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
| 1831 |
}
|
| 1832 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1833 |
}
|
| 1834 |
|
| 1835 |
template <int mmq_x, int mmq_y>
|
|
|
|
| 1872 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 1873 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 1874 |
|
| 1875 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1876 |
int * x_qs = (int *) x_tile;
|
| 1877 |
half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
|
| 1878 |
#else
|
|
|
|
| 1880 |
int * x_qs = (int *) x_tile;
|
| 1881 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1882 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1883 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1884 |
|
| 1885 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
|
| 1886 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 1908 |
const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
|
| 1909 |
const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
|
| 1910 |
|
| 1911 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1912 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
|
| 1913 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
|
| 1914 |
#else
|
| 1915 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
|
| 1916 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
|
| 1917 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1918 |
}
|
| 1919 |
|
| 1920 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1921 |
constexpr int rows_per_warp = warp_size / 2;
|
| 1922 |
#pragma unroll
|
| 1923 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
|
|
| 1986 |
|
| 1987 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
| 1988 |
}
|
| 1989 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 1990 |
}
|
| 1991 |
|
| 1992 |
template <int mmq_x, int mmq_y>
|
|
|
|
| 2029 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2030 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2031 |
|
| 2032 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2033 |
int * x_qs = (int *) x_tile;
|
| 2034 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2035 |
int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
|
|
|
|
| 2038 |
int * x_qs = (int *) x_tile;
|
| 2039 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2040 |
int * x_sc = (int *) (x_df + txs.dm);
|
| 2041 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2042 |
|
| 2043 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
|
| 2044 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2065 |
const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
|
| 2066 |
const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
|
| 2067 |
|
| 2068 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2069 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 2070 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 2071 |
#else
|
| 2072 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 2073 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 2074 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2075 |
}
|
| 2076 |
|
| 2077 |
#pragma unroll
|
|
|
|
| 2084 |
|
| 2085 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
|
| 2086 |
|
| 2087 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2088 |
x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
|
| 2089 |
#else
|
| 2090 |
x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
|
| 2091 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2092 |
}
|
| 2093 |
|
| 2094 |
constexpr int rows_per_warp = warp_size / 4;
|
|
|
|
| 2102 |
|
| 2103 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
|
| 2104 |
|
| 2105 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2106 |
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
|
| 2107 |
#else
|
| 2108 |
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
|
| 2109 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2110 |
}
|
| 2111 |
}
|
| 2112 |
|
|
|
|
| 2199 |
}
|
| 2200 |
}
|
| 2201 |
}
|
| 2202 |
+
#elif defined(TURING_MMA_AVAILABLE)
|
| 2203 |
|
| 2204 |
typedef tile<16, 4, int> tile_A;
|
| 2205 |
typedef tile< 8, 4, int> tile_B;
|
|
|
|
| 2311 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2312 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2313 |
|
| 2314 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2315 |
int * x_qs = (int *) x_tile;
|
| 2316 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2317 |
#else
|
| 2318 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
|
| 2319 |
int * x_qs = (int *) x_tile;
|
| 2320 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2321 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2322 |
|
| 2323 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
|
| 2324 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2340 |
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
| 2341 |
const int k0 = kbx * (2 * QI4_NL) + kqsx;
|
| 2342 |
|
| 2343 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2344 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
| 2345 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
|
| 2346 |
#else
|
| 2347 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
| 2348 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
|
| 2349 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2350 |
}
|
| 2351 |
|
| 2352 |
constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
|
|
|
|
| 2363 |
|
| 2364 |
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
|
| 2365 |
|
| 2366 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2367 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
|
| 2368 |
#else
|
| 2369 |
x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
|
| 2370 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2371 |
}
|
| 2372 |
}
|
| 2373 |
|
|
|
|
| 2376 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2377 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2378 |
|
| 2379 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2380 |
int * x_qs = (int *) x_tile;
|
| 2381 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2382 |
#else
|
| 2383 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
|
| 2384 |
int * x_qs = (int *) x_tile;
|
| 2385 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2386 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2387 |
|
| 2388 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
|
| 2389 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2414 |
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
| 2415 |
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
| 2416 |
|
| 2417 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2418 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
| 2419 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
|
| 2420 |
#else
|
| 2421 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
|
| 2422 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
|
| 2423 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2424 |
}
|
| 2425 |
|
| 2426 |
const int ls = aux32 >> 28;
|
| 2427 |
const float d = bxi->d;
|
| 2428 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2429 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
|
| 2430 |
#else
|
| 2431 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
|
| 2432 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2433 |
}
|
| 2434 |
}
|
| 2435 |
|
|
|
|
| 2438 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2439 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2440 |
|
| 2441 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2442 |
int * x_qs = (int *) x_tile;
|
| 2443 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2444 |
#else
|
| 2445 |
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
| 2446 |
int * x_qs = (int *) x_tile;
|
| 2447 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2448 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2449 |
|
| 2450 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
|
| 2451 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2472 |
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
| 2473 |
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
| 2474 |
|
| 2475 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2476 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2477 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2478 |
#else
|
| 2479 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2480 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2481 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2482 |
}
|
| 2483 |
|
| 2484 |
const int ls = bxi->scales[kqsx];
|
| 2485 |
const float d = bxi->d;
|
| 2486 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2487 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2488 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2489 |
#else
|
| 2490 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2491 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2492 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2493 |
}
|
| 2494 |
}
|
| 2495 |
|
|
|
|
| 2498 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2499 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2500 |
|
| 2501 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2502 |
int * x_qs = (int *) x_tile;
|
| 2503 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2504 |
#else
|
| 2505 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
|
| 2506 |
int * x_qs = (int *) x_tile;
|
| 2507 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2508 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2509 |
|
| 2510 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
|
| 2511 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2539 |
const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
|
| 2540 |
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
|
| 2541 |
|
| 2542 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2543 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2544 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2545 |
#else
|
| 2546 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2547 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2548 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2549 |
}
|
| 2550 |
|
| 2551 |
const int ls = bxi->scales[kqsx];
|
| 2552 |
const float d = bxi->d;
|
| 2553 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2554 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2555 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2556 |
#else
|
| 2557 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2558 |
x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2559 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2560 |
}
|
| 2561 |
}
|
| 2562 |
|
|
|
|
| 2565 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2566 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2567 |
|
| 2568 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2569 |
int * x_qs = (int *) x_tile;
|
| 2570 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2571 |
#else
|
| 2572 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
|
| 2573 |
int * x_qs = (int *) x_tile;
|
| 2574 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2575 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2576 |
|
| 2577 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
|
| 2578 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2601 |
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
| 2602 |
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
| 2603 |
|
| 2604 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2605 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2606 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2607 |
#else
|
| 2608 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2609 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2610 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2611 |
}
|
| 2612 |
|
| 2613 |
const int ls = aux32 >> 28;
|
| 2614 |
const float d = bxi->d;
|
| 2615 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2616 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
|
| 2617 |
#else
|
| 2618 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
|
| 2619 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2620 |
}
|
| 2621 |
}
|
| 2622 |
|
|
|
|
| 2625 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2626 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2627 |
|
| 2628 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2629 |
int * x_qs = (int *) x_tile;
|
| 2630 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2631 |
#else
|
| 2632 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
| 2633 |
int * x_qs = (int *) x_tile;
|
| 2634 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2635 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2636 |
|
| 2637 |
constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
|
| 2638 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2668 |
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
| 2669 |
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
| 2670 |
|
| 2671 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2672 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
|
| 2673 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
|
| 2674 |
#else
|
| 2675 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
|
| 2676 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
|
| 2677 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2678 |
}
|
| 2679 |
|
| 2680 |
const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
|
| 2681 |
const float d = bxi->d;
|
| 2682 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2683 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
|
| 2684 |
#else
|
| 2685 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
|
| 2686 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2687 |
}
|
| 2688 |
}
|
| 2689 |
|
|
|
|
| 2692 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2693 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2694 |
|
| 2695 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2696 |
int * x_qs = (int *) x_tile;
|
| 2697 |
half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2698 |
#else
|
| 2699 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
| 2700 |
int * x_qs = (int *) x_tile;
|
| 2701 |
half2 * x_ds = (half2 *) (x_qs + txs.qs);
|
| 2702 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2703 |
|
| 2704 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
|
| 2705 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2727 |
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
|
| 2728 |
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
|
| 2729 |
|
| 2730 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2731 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
|
| 2732 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
|
| 2733 |
#else
|
| 2734 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
|
| 2735 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
|
| 2736 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2737 |
}
|
| 2738 |
|
| 2739 |
const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
|
| 2740 |
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
|
| 2741 |
|
| 2742 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2743 |
x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
|
| 2744 |
#else
|
| 2745 |
x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
|
| 2746 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2747 |
}
|
| 2748 |
}
|
| 2749 |
|
|
|
|
| 2752 |
constexpr int nwarps = mmq_get_nwarps_device();
|
| 2753 |
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 2754 |
|
| 2755 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2756 |
int * x_qs = (int *) x_tile;
|
| 2757 |
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
| 2758 |
#else
|
| 2759 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
|
| 2760 |
int * x_qs = (int *) x_tile;
|
| 2761 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2762 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2763 |
|
| 2764 |
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
|
| 2765 |
constexpr int nrows = warp_size / threads_per_row;
|
|
|
|
| 2779 |
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
|
| 2780 |
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
|
| 2781 |
|
| 2782 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2783 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
| 2784 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
| 2785 |
#else
|
| 2786 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
|
| 2787 |
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
|
| 2788 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2789 |
}
|
| 2790 |
|
| 2791 |
constexpr int rows_per_warp = warp_size / 8;
|
|
|
|
| 2804 |
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
|
| 2805 |
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
|
| 2806 |
|
| 2807 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2808 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
|
| 2809 |
#else
|
| 2810 |
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
|
| 2811 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2812 |
}
|
| 2813 |
}
|
| 2814 |
|
|
|
|
| 2859 |
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
| 2860 |
|
| 2861 |
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
|
| 2862 |
+
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
| 2863 |
static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
|
| 2864 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 2865 |
|
| 2866 |
#pragma unroll
|
| 2867 |
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
|
|
| 3061 |
int * tile_y = data_mul_mat_q + mmq_x;
|
| 3062 |
int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
|
| 3063 |
|
| 3064 |
+
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 3065 |
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
|
| 3066 |
constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
|
| 3067 |
#else
|
| 3068 |
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
|
| 3069 |
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
|
| 3070 |
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
|
| 3071 |
|
| 3072 |
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
| 3073 |
|
|
|
|
| 3534 |
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
|
| 3535 |
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
| 3536 |
const size_t nbs_ids = mmq_x*sizeof(int);
|
| 3537 |
+
const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
| 3538 |
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
|
| 3539 |
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
|
| 3540 |
}
|
ggml/src/ggml-cuda/mmvf.cu
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
+#include "ggml.h"
+#include "common.cuh"
+#include "mmvf.cuh"
+
|
| 5 |
+
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
| 6 |
+
static __global__ void mul_mat_vec_f(
|
| 7 |
+
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
| 8 |
+
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
| 9 |
+
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
| 10 |
+
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
| 11 |
+
const int row = blockIdx.x;
|
| 12 |
+
const int channel_dst = blockIdx.y;
|
| 13 |
+
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
| 14 |
+
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
| 15 |
+
const int sample_dst = blockIdx.z;
|
| 16 |
+
const int sample_x = sample_dst / sample_ratio;
|
| 17 |
+
const int sample_y = sample_dst;
|
| 18 |
+
const int tid = threadIdx.x;
|
| 19 |
+
|
| 20 |
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
| 21 |
+
|
| 22 |
+
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
| 23 |
+
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
| 24 |
+
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
| 25 |
+
|
| 26 |
+
const float2 * y2 = (const float2 *) y;
|
| 27 |
+
|
| 28 |
+
extern __shared__ char data_mmv[];
|
| 29 |
+
float * buf_iw = (float *) data_mmv;
|
| 30 |
+
|
| 31 |
+
if (block_size > warp_size) {
|
| 32 |
+
if (tid < warp_size) {
|
| 33 |
+
buf_iw[tid] = 0.0f;
|
| 34 |
+
}
|
| 35 |
+
__syncthreads();
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
float sumf[ncols_dst] = {0.0f};
|
| 39 |
+
|
| 40 |
+
if constexpr (std::is_same_v<T, float>) {
|
| 41 |
+
const float2 * x2 = (const float2 *) x;
|
| 42 |
+
|
| 43 |
+
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
| 44 |
+
const float2 tmpx = x2[col2];
|
| 45 |
+
|
| 46 |
+
#pragma unroll
|
| 47 |
+
for (int j = 0; j < ncols_dst; ++j) {
|
| 48 |
+
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
| 49 |
+
sumf[j] += tmpx.x*tmpy.x;
|
| 50 |
+
sumf[j] += tmpx.y*tmpy.y;
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
} else if constexpr (std::is_same_v<T, half>) {
|
| 54 |
+
const half2 * x2 = (const half2 *) x;
|
| 55 |
+
|
| 56 |
+
if (std::is_same_v<type_acc, float>) {
|
| 57 |
+
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
| 58 |
+
const float2 tmpx = __half22float2(x2[col2]);
|
| 59 |
+
|
| 60 |
+
#pragma unroll
|
| 61 |
+
for (int j = 0; j < ncols_dst; ++j) {
|
| 62 |
+
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
| 63 |
+
sumf[j] += tmpx.x * tmpy.x;
|
| 64 |
+
sumf[j] += tmpx.y * tmpy.y;
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
} else {
|
| 68 |
+
#ifdef FP16_AVAILABLE
|
| 69 |
+
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
|
| 70 |
+
|
| 71 |
+
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
| 72 |
+
const half2 tmpx = x2[col2];
|
| 73 |
+
|
| 74 |
+
#pragma unroll
|
| 75 |
+
for (int j = 0; j < ncols_dst; ++j) {
|
| 76 |
+
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
| 77 |
+
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
#pragma unroll
|
| 82 |
+
for (int j = 0; j < ncols_dst; ++j) {
|
| 83 |
+
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
|
| 84 |
+
}
|
| 85 |
+
#else
|
| 86 |
+
NO_DEVICE_CODE;
|
| 87 |
+
#endif // FP16_AVAILABLE
|
| 88 |
+
}
|
| 89 |
+
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
| 90 |
+
const int * x2 = (const int *) x;
|
| 91 |
+
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
| 92 |
+
const int tmpx = x2[col2];
|
| 93 |
+
#pragma unroll
|
| 94 |
+
for (int j = 0; j < ncols_dst; ++j) {
|
| 95 |
+
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
| 96 |
+
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
| 97 |
+
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
} else {
|
| 101 |
+
static_assert(std::is_same_v<T, void>, "unsupported type");
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
#pragma unroll
|
| 105 |
+
for (int j = 0; j < ncols_dst; ++j) {
|
| 106 |
+
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
| 107 |
+
|
| 108 |
+
if (block_size > warp_size) {
|
| 109 |
+
buf_iw[tid/warp_size] = sumf[j];
|
| 110 |
+
__syncthreads();
|
| 111 |
+
if (tid < warp_size) {
|
| 112 |
+
sumf[j] = buf_iw[tid];
|
| 113 |
+
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
| 114 |
+
}
|
| 115 |
+
if (j < ncols_dst) {
|
| 116 |
+
__syncthreads();
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
if (tid >= ncols_dst) {
|
| 122 |
+
return;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
dst[tid*stride_col_dst + row] = sumf[tid];
|
| 126 |
+
}
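The tail of the kernel above is a standard two-stage block reduction: each warp first reduces its partial sums with warp_reduce_sum, then, when the block is larger than one warp, one partial per warp is staged in the shared buffer and reduced again by the first warp. A self-contained sketch of the same pattern, written directly with __shfl_xor_sync and an assumed block size instead of the ggml helpers:

// Minimal sketch: block-wide sum reduction in two stages (warp shuffle + shared memory).
// BLOCK_SIZE and the kernel name are assumptions for illustration, not part of ggml.
template <int BLOCK_SIZE>
__global__ void block_sum_sketch(const float * x, float * out, int n) {
    float sum = 0.0f;
    for (int i = threadIdx.x; i < n; i += BLOCK_SIZE) {
        sum += x[i];
    }

    // Stage 1: reduce within each warp via shuffles.
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, 32);
    }

    // Stage 2: one partial per warp goes to shared memory, warp 0 reduces those partials.
    __shared__ float buf[BLOCK_SIZE/32];
    if (threadIdx.x % 32 == 0) {
        buf[threadIdx.x / 32] = sum;
    }
    __syncthreads();

    if (threadIdx.x < 32) {
        sum = threadIdx.x < BLOCK_SIZE/32 ? buf[threadIdx.x] : 0.0f;
#pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, 32);
        }
        if (threadIdx.x == 0) {
            *out = sum;
        }
    }
}

Launched as block_sum_sketch<256><<<1, 256>>>(d_x, d_out, n), it gives the same result as a single-warp reduction while scaling to the larger block sizes selected by the launcher below.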
|
| 127 |
+
|
| 128 |
+
template <typename T, typename type_acc, int ncols_dst>
|
| 129 |
+
static void launch_mul_mat_vec_f_cuda(
|
| 130 |
+
const T * x, const float * y, const int32_t * ids, float * dst,
|
| 131 |
+
const int64_t ncols, const int64_t nrows,
|
| 132 |
+
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
| 133 |
+
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
| 134 |
+
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
| 135 |
+
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
| 136 |
+
cudaStream_t stream) {
|
| 137 |
+
GGML_ASSERT(ncols % 2 == 0);
|
| 138 |
+
GGML_ASSERT(stride_row % 2 == 0);
|
| 139 |
+
GGML_ASSERT(stride_col_y % 2 == 0);
|
| 140 |
+
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
| 141 |
+
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
| 142 |
+
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
| 143 |
+
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
| 144 |
+
|
| 145 |
+
const int device = ggml_cuda_get_device();
|
| 146 |
+
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
| 147 |
+
|
| 148 |
+
int64_t block_size_best = warp_size;
|
| 149 |
+
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
|
| 150 |
+
int64_t max_block_size = 256;
|
| 151 |
+
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
|
| 152 |
+
max_block_size = 128;
|
| 153 |
+
}
|
| 154 |
+
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
|
| 155 |
+
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
|
| 156 |
+
if (niter < niter_best) {
|
| 157 |
+
niter_best = niter;
|
| 158 |
+
block_size_best = block_size;
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
const int nbytes_shared = warp_size*sizeof(float);
|
| 163 |
+
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
| 164 |
+
const dim3 block_dims(block_size_best, 1, 1);
|
| 165 |
+
switch (block_size_best) {
|
| 166 |
+
case 32: {
|
| 167 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 168 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 169 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 170 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 171 |
+
} break;
|
| 172 |
+
case 64: {
|
| 173 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 174 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 175 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 176 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 177 |
+
} break;
|
| 178 |
+
case 96: {
|
| 179 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 180 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 181 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 182 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 183 |
+
} break;
|
| 184 |
+
case 128: {
|
| 185 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 186 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 187 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 188 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 189 |
+
} break;
|
| 190 |
+
case 160: {
|
| 191 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 192 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 193 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 194 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 195 |
+
} break;
|
| 196 |
+
case 192: {
|
| 197 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 198 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 199 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 200 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 201 |
+
} break;
|
| 202 |
+
case 224: {
|
| 203 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 204 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 205 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 206 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 207 |
+
} break;
|
| 208 |
+
case 256: {
|
| 209 |
+
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
| 210 |
+
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
| 211 |
+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
| 212 |
+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
| 213 |
+
} break;
|
| 214 |
+
default: {
|
| 215 |
+
GGML_ABORT("fatal error");
|
| 216 |
+
} break;
|
| 217 |
+
}
|
| 218 |
+
}
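The block-size selection above amounts to a small search: among block sizes that are multiples of the warp size (capped at 256, or 128 on the listed AMD targets), pick the one that minimizes how many float2 loads per thread are needed to cover ncols. A standalone host-side sketch of the same search, assuming a warp size of 32 and a 256-thread cap:

#include <cstdint>
#include <cstdio>

// Sketch only: mirrors the niter_best loop above with assumed warp size and cap.
static int64_t pick_block_size(int64_t ncols, int64_t warp_size = 32, int64_t max_block_size = 256) {
    int64_t best_bs    = warp_size;
    int64_t best_niter = (ncols + 2*warp_size - 1) / (2*warp_size);
    for (int64_t bs = 2*warp_size; bs <= max_block_size; bs += warp_size) {
        const int64_t niter = (ncols + 2*bs - 1) / (2*bs);
        if (niter < best_niter) {
            best_niter = niter;
            best_bs    = bs;
        }
    }
    return best_bs;
}

int main() {
    // e.g. ncols = 4096 -> 256 threads, 8 iterations of float2 loads per thread
    printf("%lld\n", (long long) pick_block_size(4096));
    return 0;
}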
|
| 219 |
+
|
| 220 |
+
template <typename T, typename type_acc>
|
| 221 |
+
static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
| 222 |
+
const T * x, const float * y, const int32_t * ids, float * dst,
|
| 223 |
+
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
| 224 |
+
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
| 225 |
+
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
| 226 |
+
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
| 227 |
+
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
| 228 |
+
cudaStream_t stream) {
|
| 229 |
+
switch (ncols_dst) {
|
| 230 |
+
case 1:
|
| 231 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
| 232 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 233 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 234 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 235 |
+
break;
|
| 236 |
+
case 2:
|
| 237 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
| 238 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 239 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 240 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 241 |
+
break;
|
| 242 |
+
case 3:
|
| 243 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
| 244 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 245 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 246 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 247 |
+
break;
|
| 248 |
+
case 4:
|
| 249 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
| 250 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 251 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 252 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 253 |
+
break;
|
| 254 |
+
case 5:
|
| 255 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
| 256 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 257 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 258 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 259 |
+
break;
|
| 260 |
+
case 6:
|
| 261 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
| 262 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 263 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 264 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 265 |
+
break;
|
| 266 |
+
case 7:
|
| 267 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
| 268 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 269 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 270 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 271 |
+
break;
|
| 272 |
+
case 8:
|
| 273 |
+
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
| 274 |
+
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
| 275 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 276 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 277 |
+
break;
|
| 278 |
+
default:
|
| 279 |
+
GGML_ABORT("fatal error");
|
| 280 |
+
break;
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
template<typename T>
|
| 285 |
+
static void mul_mat_vec_f_cuda(
|
| 286 |
+
const T * x, const float * y, const int32_t * ids, float * dst,
|
| 287 |
+
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
| 288 |
+
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
| 289 |
+
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
| 290 |
+
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
| 291 |
+
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
| 292 |
+
enum ggml_prec prec, cudaStream_t stream) {
|
| 293 |
+
if constexpr(std::is_same_v<T, half>) {
|
| 294 |
+
if (prec == GGML_PREC_DEFAULT) {
|
| 295 |
+
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
| 296 |
+
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
| 297 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 298 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 299 |
+
return;
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
| 303 |
+
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
| 304 |
+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
| 305 |
+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
| 309 |
+
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
| 310 |
+
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
| 311 |
+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
| 312 |
+
|
| 313 |
+
GGML_TENSOR_BINARY_OP_LOCALS;
|
| 314 |
+
|
| 315 |
+
const size_t ts_src0 = ggml_type_size(src0->type);
|
| 316 |
+
const size_t ts_src1 = ggml_type_size(src1->type);
|
| 317 |
+
const size_t ts_dst = ggml_type_size(dst->type);
|
| 318 |
+
|
| 319 |
+
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
|
| 320 |
+
GGML_ASSERT(ne13 == ne3);
|
| 321 |
+
|
| 322 |
+
GGML_ASSERT( nb00 == ts_src0);
|
| 323 |
+
GGML_ASSERT( nb10 == ts_src1);
|
| 324 |
+
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
|
| 325 |
+
GGML_ASSERT( nb0 == ts_dst);
|
| 326 |
+
|
| 327 |
+
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
| 328 |
+
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
| 329 |
+
|
| 330 |
+
const float * src1_d = (const float *) src1->data;
|
| 331 |
+
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
| 332 |
+
float * dst_d = (float *) dst->data;
|
| 333 |
+
|
| 334 |
+
const int64_t s01 = src0->nb[1] / ts_src0;
|
| 335 |
+
const int64_t s11 = src1->nb[1] / ts_src1;
|
| 336 |
+
const int64_t s1 = dst->nb[1] / ts_dst;
|
| 337 |
+
const int64_t s02 = src0->nb[2] / ts_src0;
|
| 338 |
+
const int64_t s12 = src1->nb[2] / ts_src1;
|
| 339 |
+
const int64_t s2 = dst->nb[2] / ts_dst;
|
| 340 |
+
const int64_t s03 = src0->nb[3] / ts_src0;
|
| 341 |
+
const int64_t s13 = src1->nb[3] / ts_src1;
|
| 342 |
+
const int64_t s3 = dst->nb[3] / ts_dst;
|
| 343 |
+
|
| 344 |
+
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
|
| 345 |
+
const int64_t ncols_dst = ids ? ne2 : ne1;
|
| 346 |
+
const int64_t nchannels_y = ids ? ne11 : ne12;
|
| 347 |
+
const int64_t nchannels_dst = ids ? ne1 : ne2;
|
| 348 |
+
const int64_t stride_channel_dst = ids ? s1 : s2;
|
| 349 |
+
const int64_t stride_channel_y = ids ? s11 : s12;
|
| 350 |
+
|
| 351 |
+
GGML_ASSERT(!ids || ncols_dst == 1);
|
| 352 |
+
|
| 353 |
+
switch (src0->type) {
|
| 354 |
+
case GGML_TYPE_F32: {
|
| 355 |
+
const float * src0_d = (const float *) src0->data;
|
| 356 |
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
| 357 |
+
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
| 358 |
+
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
| 359 |
+
} break;
|
| 360 |
+
case GGML_TYPE_F16: {
|
| 361 |
+
const half * src0_d = (const half *) src0->data;
|
| 362 |
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
| 363 |
+
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
| 364 |
+
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
| 365 |
+
} break;
|
| 366 |
+
case GGML_TYPE_BF16: {
|
| 367 |
+
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
| 368 |
+
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
| 369 |
+
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
| 370 |
+
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
| 371 |
+
} break;
|
| 372 |
+
default:
|
| 373 |
+
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
+void ggml_cuda_op_mul_mat_vec_f(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne0  =  dst->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+
+    // ggml_cuda_op provides single, contiguous matrices
+    const int64_t stride_row         = ne00;
+    const int64_t stride_col_y       = ne10;
+    const int64_t stride_col_dst     = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
+    const int64_t nchannels_x        = 1;
+    const int64_t nchannels_y        = 1;
+    const int64_t nchannels_dst      = 1;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_dst       = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
+
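The stride setup in ggml_cuda_op_mul_mat_vec_f relies on ggml_cuda_op handing each device one contiguous 2D slice, so all channel/sample counts collapse to 1 and their strides to 0; the only asymmetry is the dst column stride, which differs between the main device (full ne0-row buffer) and the others (row_diff rows). A minimal sketch of that one rule, with hypothetical sizes:

#include <cstdint>

// Illustrative only: the dst column stride picked above. The main device allocates the
// full output (ne0 rows per column) while secondary devices only hold their row slice
// (row_diff = row_high - row_low rows).
static int64_t dst_col_stride(int device, int main_device, int64_t ne0, int64_t row_diff) {
    return device == main_device ? ne0 : row_diff;
}

// Hypothetical example: ne0 = 4096 output rows split across two GPUs, GPU 1 owning rows [2048, 4096):
//   dst_col_stride(0, 0, 4096, 2048) == 4096  // main device, full-size buffer
//   dst_col_stride(1, 0, 4096, 2048) == 2048  // secondary device, partial buffer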
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % 2 != 0) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                if (ampere_mma_available(cc)) {
+                    return ne11 <= 3;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    return ne11 <= 4;
+                }
+                return ne11 <= 3;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp32_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_F16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (fp16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp16_mma_hardware_available(cc)) {
+                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+                        return ne11 <= 5;
+                    }
+                    return ne11 <= 2;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_BF16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (bf16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (bf16_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        default:
+            return false;
+    }
+}
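Together with the new GEMM path from this commit, ggml_cuda_should_use_mmvf acts as the batch-size cutoff below which the matrix-vector kernels are still preferred. A hedged caller sketch follows (hypothetical helper, not the actual dispatch logic in ggml-cuda.cu):

#include "common.cuh"
#include "mmvf.cuh"

// Hypothetical helper: ne11 = src1->ne[1] is the number of src1 columns, i.e. the effective
// batch size. Per the thresholds above, NVIDIA F32 stays on mmvf for ne11 <= 3 (or <= 4 on
// Turing), while F16/BF16 with tensor cores only stay on mmvf for small batches and
// sufficiently small src0 matrices.
static bool want_mmvf(const ggml_tensor * src0, const ggml_tensor * src1, const int cc) {
    return ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
}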
ggml/src/ggml-cuda/mmvf.cuh
ADDED
@@ -0,0 +1,11 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec_f(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
ggml/src/ggml-cuda/vendors/hip.h
CHANGED
@@ -200,6 +200,7 @@
 #endif

 typedef hip_bfloat16 nv_bfloat16;
+typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix

 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
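The HIP header only gains a size-compatible stand-in here because HIP's bfloat16.h defines no 2-wide BF16 type; presumably the new kernels read BF16 data in pairs through nv_bfloat162, so the alias must at least exist for compilation. A CUDA-only sketch of that access pattern, illustrative and assuming the standard <cuda_bf16.h> definitions rather than anything from this commit:

#include <cuda_bf16.h>

// Illustrative CUDA-only sketch: loading two BF16 values through the 2-wide type gives a
// single 4-byte load; this is the kind of use that forces even the HIP build to provide
// *some* nv_bfloat162 alias, although short2 is only a size/alignment stand-in there.
static __device__ float2 load_bf16_pair(const __nv_bfloat16 * x, const int i) {
    const __nv_bfloat162 v = *(const __nv_bfloat162 *) &x[2*i];
    return make_float2(__bfloat162float(v.x), __bfloat162float(v.y));
}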
ggml/src/ggml-cuda/vendors/musa.h
CHANGED
@@ -137,4 +137,5 @@
 #define cudaStreamEndCapture musaStreamEndCapture
 #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor

-typedef
+typedef __mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat162 nv_bfloat162;