Spaces:
Sleeping
Sleeping
Commit
·
f328957
1
Parent(s):
9ed1962
CUDA: use mma PTX instructions for FlashAttention (llama/11583)
Browse files* CUDA: use mma PTX instructions for FlashAttention
* __shfl_sync workaround for movmatrix
* add __shfl_sync to HIP
Co-authored-by: Diego Devesa <[email protected]>
- ggml/include/ggml.h +1 -1
- ggml/src/ggml-cuda/CMakeLists.txt +1 -1
- ggml/src/ggml-cuda/common.cuh +4 -2
- ggml/src/ggml-cuda/fattn-common.cuh +154 -25
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +637 -0
- ggml/src/ggml-cuda/fattn-tile-f16.cu +18 -6
- ggml/src/ggml-cuda/fattn-tile-f32.cu +13 -6
- ggml/src/ggml-cuda/fattn-vec-f16.cuh +8 -1
- ggml/src/ggml-cuda/fattn-vec-f32.cuh +7 -1
- ggml/src/ggml-cuda/fattn-wmma-f16.cu +648 -0
- ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -541
- ggml/src/ggml-cuda/fattn.cu +50 -124
- ggml/src/ggml-cuda/mma.cuh +286 -49
- ggml/src/ggml-cuda/mmq.cu +1 -1
- ggml/src/ggml-cuda/mmq.cuh +176 -173
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +0 -10
- ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +0 -9
- ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +0 -10
- ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +0 -10
- ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +0 -8
- ggml/src/ggml-cuda/template-instances/generate_cu_files.py +8 -16
- ggml/src/ggml-cuda/vendors/hip.h +1 -0
- ggml/src/ggml-hip/CMakeLists.txt +1 -1
- ggml/src/ggml-musa/CMakeLists.txt +1 -1
ggml/include/ggml.h
CHANGED
|
@@ -1775,7 +1775,7 @@ extern "C" {
|
|
| 1775 |
struct ggml_tensor * a,
|
| 1776 |
int k);
|
| 1777 |
|
| 1778 |
-
#define GGML_KQ_MASK_PAD
|
| 1779 |
|
| 1780 |
// q: [n_embd, n_batch, n_head, 1]
|
| 1781 |
// k: [n_embd, n_kv, n_head_kv, 1]
|
|
|
|
| 1775 |
struct ggml_tensor * a,
|
| 1776 |
int k);
|
| 1777 |
|
| 1778 |
+
#define GGML_KQ_MASK_PAD 64
|
| 1779 |
|
| 1780 |
// q: [n_embd, n_batch, n_head, 1]
|
| 1781 |
// k: [n_embd, n_kv, n_head_kv, 1]
|
ggml/src/ggml-cuda/CMakeLists.txt
CHANGED
|
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
|
|
| 28 |
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
| 29 |
|
| 30 |
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
| 31 |
-
file(GLOB SRCS "template-instances/fattn-
|
| 32 |
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
| 33 |
file(GLOB SRCS "template-instances/mmq*.cu")
|
| 34 |
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
|
|
|
| 28 |
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
| 29 |
|
| 30 |
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
| 31 |
+
file(GLOB SRCS "template-instances/fattn-mma*.cu")
|
| 32 |
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
| 33 |
file(GLOB SRCS "template-instances/mmq*.cu")
|
| 34 |
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
ggml/src/ggml-cuda/common.cuh
CHANGED
|
@@ -148,7 +148,7 @@ typedef float2 dfloat2;
|
|
| 148 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
| 149 |
|
| 150 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
| 151 |
-
#define
|
| 152 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
| 153 |
|
| 154 |
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
|
@@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) {
|
|
| 159 |
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
| 160 |
}
|
| 161 |
|
|
|
|
| 162 |
static constexpr bool fp16_mma_available(const int cc) {
|
| 163 |
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
| 164 |
}
|
| 165 |
|
| 166 |
-
|
|
|
|
| 167 |
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
| 168 |
}
|
| 169 |
|
|
|
|
| 148 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
| 149 |
|
| 150 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
| 151 |
+
#define NEW_MMA_AVAILABLE
|
| 152 |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
| 153 |
|
| 154 |
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
|
|
|
| 159 |
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
| 160 |
}
|
| 161 |
|
| 162 |
+
// Any FP16 tensor cores are available.
|
| 163 |
static constexpr bool fp16_mma_available(const int cc) {
|
| 164 |
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
| 165 |
}
|
| 166 |
|
| 167 |
+
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
| 168 |
+
static constexpr bool new_mma_available(const int cc) {
|
| 169 |
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
| 170 |
}
|
| 171 |
|
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
|
@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
|
| 516 |
nullptr;
|
| 517 |
}
|
| 518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
template<int D, int parallel_blocks> // D == head size
|
| 520 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 521 |
__launch_bounds__(D, 1)
|
|
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
|
|
| 581 |
}
|
| 582 |
}
|
| 583 |
|
| 584 |
-
|
|
|
|
| 585 |
void launch_fattn(
|
| 586 |
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
| 587 |
-
const int nwarps, const
|
| 588 |
) {
|
| 589 |
const ggml_tensor * Q = dst->src[0];
|
| 590 |
const ggml_tensor * K = dst->src[1];
|
|
@@ -603,20 +702,23 @@ void launch_fattn(
|
|
| 603 |
|
| 604 |
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
| 605 |
|
|
|
|
|
|
|
| 606 |
ggml_cuda_pool & pool = ctx.pool();
|
| 607 |
cudaStream_t main_stream = ctx.stream();
|
|
|
|
| 608 |
|
| 609 |
ggml_cuda_pool_alloc<half> K_f16(pool);
|
| 610 |
ggml_cuda_pool_alloc<half> V_f16(pool);
|
| 611 |
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
| 612 |
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
| 613 |
|
| 614 |
-
char * K_data = (char *) K->data;
|
| 615 |
size_t nb11 = K->nb[1];
|
| 616 |
size_t nb12 = K->nb[2];
|
| 617 |
size_t nb13 = K->nb[3];
|
| 618 |
|
| 619 |
-
char * V_data = (char *) V->data;
|
| 620 |
size_t nb21 = V->nb[1];
|
| 621 |
size_t nb22 = V->nb[2];
|
| 622 |
size_t nb23 = V->nb[3];
|
|
@@ -649,39 +751,60 @@ void launch_fattn(
|
|
| 649 |
nb23 = nb23*bs*sizeof(half)/ts;
|
| 650 |
}
|
| 651 |
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
| 655 |
-
}
|
| 656 |
|
| 657 |
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
| 658 |
-
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
|
| 661 |
float scale = 1.0f;
|
| 662 |
float max_bias = 0.0f;
|
| 663 |
float logit_softcap = 0.0f;
|
| 664 |
|
| 665 |
-
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
| 666 |
-
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
| 667 |
-
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
|
| 668 |
|
| 669 |
if (logit_softcap != 0.0f) {
|
| 670 |
scale /= logit_softcap;
|
| 671 |
}
|
| 672 |
|
| 673 |
const uint32_t n_head = Q->ne[2];
|
| 674 |
-
const uint32_t n_head_log2 = 1u << (
|
| 675 |
|
| 676 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 677 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 678 |
|
| 679 |
-
fattn_kernel<<<blocks_num, block_dim,
|
| 680 |
(const char *) Q->data,
|
| 681 |
K_data,
|
| 682 |
V_data,
|
| 683 |
mask ? ((const char *) mask->data) : nullptr,
|
| 684 |
-
(parallel_blocks)
|
| 685 |
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
| 686 |
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 687 |
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
@@ -693,16 +816,22 @@ void launch_fattn(
|
|
| 693 |
);
|
| 694 |
CUDA_CHECK(cudaGetLastError());
|
| 695 |
|
| 696 |
-
if (
|
| 697 |
-
|
| 698 |
-
|
|
|
|
| 699 |
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
|
|
|
| 707 |
CUDA_CHECK(cudaGetLastError());
|
| 708 |
}
|
|
|
|
| 516 |
nullptr;
|
| 517 |
}
|
| 518 |
|
| 519 |
+
template<int D, int ncols, int KQ_stride> // D == head size
|
| 520 |
+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 521 |
+
__launch_bounds__(D, 1)
|
| 522 |
+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 523 |
+
static __global__ void flash_attn_stream_k_fixup(
|
| 524 |
+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
| 525 |
+
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
| 526 |
+
|
| 527 |
+
const int iter_k = ne11 / KQ_stride;
|
| 528 |
+
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
| 529 |
+
|
| 530 |
+
const int bidx0 = blockIdx.x;
|
| 531 |
+
|
| 532 |
+
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
| 533 |
+
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
| 534 |
+
|
| 535 |
+
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
| 536 |
+
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
| 537 |
+
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
|
| 538 |
+
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
| 539 |
+
return;
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
const int channel = kbc0 / (iter_k*iter_j);
|
| 543 |
+
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
| 544 |
+
|
| 545 |
+
dst += jt*ncols*ne02*D + channel*D;
|
| 546 |
+
|
| 547 |
+
// Load the partial result that needs a fixup:
|
| 548 |
+
float dst_val[ncols] = {0.0f};
|
| 549 |
+
float max_val[ncols] = {0.0f};
|
| 550 |
+
float rowsum[ncols] = {0.0f};
|
| 551 |
+
#pragma unroll
|
| 552 |
+
for (int j = 0; j < ncols; ++j) {
|
| 553 |
+
if (jt*ncols + j >= ne01) {
|
| 554 |
+
break;
|
| 555 |
+
}
|
| 556 |
+
dst_val[j] = dst[j*ne02*D + threadIdx.x];
|
| 557 |
+
|
| 558 |
+
const float2 tmp = dst_fixup[bidx0*ncols + j];
|
| 559 |
+
max_val[j] = tmp.x;
|
| 560 |
+
rowsum[j] = tmp.y;
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
// Iterate over previous blocks and compute the combined results.
|
| 564 |
+
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
| 565 |
+
int bidx = bidx0 - 1;
|
| 566 |
+
int kbc_stop = kbc0;
|
| 567 |
+
while(true) {
|
| 568 |
+
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
|
| 569 |
+
if (kbc == kbc_stop) { // Did not have any data.
|
| 570 |
+
bidx--;
|
| 571 |
+
kbc_stop = kbc;
|
| 572 |
+
continue;
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
#pragma unroll
|
| 576 |
+
for (int j = 0; j < ncols; ++j) {
|
| 577 |
+
if (jt*ncols + j >= ne01) {
|
| 578 |
+
break;
|
| 579 |
+
}
|
| 580 |
+
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
|
| 581 |
+
|
| 582 |
+
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
|
| 583 |
+
|
| 584 |
+
// Scale the current and new value accumulators depending on the max. values.
|
| 585 |
+
const float max_val_new = fmaxf(max_val[j], tmp.x);
|
| 586 |
+
|
| 587 |
+
const float diff_val = max_val[j] - max_val_new;
|
| 588 |
+
const float diff_add = tmp.x - max_val_new;
|
| 589 |
+
|
| 590 |
+
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
| 591 |
+
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
| 592 |
+
|
| 593 |
+
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
|
| 594 |
+
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
|
| 595 |
+
|
| 596 |
+
max_val[j] = max_val_new;
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
| 600 |
+
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
| 601 |
+
break;
|
| 602 |
+
}
|
| 603 |
+
bidx--;
|
| 604 |
+
kbc_stop = kbc;
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
// Write back final result:
|
| 608 |
+
#pragma unroll
|
| 609 |
+
for (int j = 0; j < ncols; ++j) {
|
| 610 |
+
if (jt*ncols + j >= ne01) {
|
| 611 |
+
return;
|
| 612 |
+
}
|
| 613 |
+
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
|
| 614 |
+
}
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
template<int D, int parallel_blocks> // D == head size
|
| 618 |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 619 |
__launch_bounds__(D, 1)
|
|
|
|
| 679 |
}
|
| 680 |
}
|
| 681 |
|
| 682 |
+
// parallel_blocks == 0 is stream-k decomposition
|
| 683 |
+
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
|
| 684 |
void launch_fattn(
|
| 685 |
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
| 686 |
+
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
|
| 687 |
) {
|
| 688 |
const ggml_tensor * Q = dst->src[0];
|
| 689 |
const ggml_tensor * K = dst->src[1];
|
|
|
|
| 702 |
|
| 703 |
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
| 704 |
|
| 705 |
+
GGML_ASSERT(Q->ne[3] == 1);
|
| 706 |
+
|
| 707 |
ggml_cuda_pool & pool = ctx.pool();
|
| 708 |
cudaStream_t main_stream = ctx.stream();
|
| 709 |
+
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
| 710 |
|
| 711 |
ggml_cuda_pool_alloc<half> K_f16(pool);
|
| 712 |
ggml_cuda_pool_alloc<half> V_f16(pool);
|
| 713 |
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
| 714 |
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
| 715 |
|
| 716 |
+
const char * K_data = (const char *) K->data;
|
| 717 |
size_t nb11 = K->nb[1];
|
| 718 |
size_t nb12 = K->nb[2];
|
| 719 |
size_t nb13 = K->nb[3];
|
| 720 |
|
| 721 |
+
const char * V_data = (const char *) V->data;
|
| 722 |
size_t nb21 = V->nb[1];
|
| 723 |
size_t nb22 = V->nb[2];
|
| 724 |
size_t nb23 = V->nb[3];
|
|
|
|
| 751 |
nb23 = nb23*bs*sizeof(half)/ts;
|
| 752 |
}
|
| 753 |
|
| 754 |
+
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
|
| 755 |
+
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
|
|
|
|
|
|
|
| 756 |
|
| 757 |
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
| 758 |
+
dim3 blocks_num;
|
| 759 |
+
if (parallel_blocks == 0) {
|
| 760 |
+
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
| 761 |
+
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
|
| 762 |
+
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
|
| 763 |
+
const bool short_context = K->ne[1] < 4096;
|
| 764 |
+
|
| 765 |
+
const int nblocks_stream_k = 2*nsm;
|
| 766 |
+
|
| 767 |
+
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
|
| 768 |
+
blocks_num.y = 1;
|
| 769 |
+
blocks_num.z = 1;
|
| 770 |
+
|
| 771 |
+
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
|
| 772 |
+
} else {
|
| 773 |
+
blocks_num.x = parallel_blocks*ntiles_x;
|
| 774 |
+
blocks_num.y = Q->ne[2];
|
| 775 |
+
blocks_num.z = Q->ne[3];
|
| 776 |
+
|
| 777 |
+
if (parallel_blocks > 1) {
|
| 778 |
+
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
| 779 |
+
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
| 780 |
+
}
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
|
| 784 |
float scale = 1.0f;
|
| 785 |
float max_bias = 0.0f;
|
| 786 |
float logit_softcap = 0.0f;
|
| 787 |
|
| 788 |
+
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
|
| 789 |
+
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
| 790 |
+
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 791 |
|
| 792 |
if (logit_softcap != 0.0f) {
|
| 793 |
scale /= logit_softcap;
|
| 794 |
}
|
| 795 |
|
| 796 |
const uint32_t n_head = Q->ne[2];
|
| 797 |
+
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
|
| 798 |
|
| 799 |
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 800 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 801 |
|
| 802 |
+
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
| 803 |
(const char *) Q->data,
|
| 804 |
K_data,
|
| 805 |
V_data,
|
| 806 |
mask ? ((const char *) mask->data) : nullptr,
|
| 807 |
+
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
| 808 |
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
| 809 |
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 810 |
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
|
|
| 816 |
);
|
| 817 |
CUDA_CHECK(cudaGetLastError());
|
| 818 |
|
| 819 |
+
if constexpr (parallel_blocks == 0) {
|
| 820 |
+
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
| 821 |
+
const dim3 block_dim_combine(D, 1, 1);
|
| 822 |
+
const dim3 blocks_num_combine = blocks_num;
|
| 823 |
|
| 824 |
+
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
|
| 825 |
+
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
| 826 |
+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
| 827 |
+
}
|
| 828 |
+
} else if constexpr (parallel_blocks > 1) {
|
| 829 |
+
const dim3 block_dim_combine(D, 1, 1);
|
| 830 |
+
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
| 831 |
|
| 832 |
+
flash_attn_combine_results<D, parallel_blocks>
|
| 833 |
+
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
| 834 |
+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
| 835 |
+
}
|
| 836 |
CUDA_CHECK(cudaGetLastError());
|
| 837 |
}
|
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ADDED
|
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "common.cuh"
|
| 2 |
+
#include "mma.cuh"
|
| 3 |
+
#include "fattn-common.cuh"
|
| 4 |
+
|
| 5 |
+
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
| 6 |
+
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
| 7 |
+
const float2 * const __restrict__ Q_f2,
|
| 8 |
+
const half2 * const __restrict__ K_h2,
|
| 9 |
+
const half2 * const __restrict__ V_h2,
|
| 10 |
+
const half * const __restrict__ maskh,
|
| 11 |
+
float2 * const __restrict__ dstk,
|
| 12 |
+
float2 * const __restrict__ dstk_fixup,
|
| 13 |
+
const float scale,
|
| 14 |
+
const float slope,
|
| 15 |
+
const float logit_softcap,
|
| 16 |
+
const int ne00,
|
| 17 |
+
const int ne01,
|
| 18 |
+
const int ne02,
|
| 19 |
+
const int ne03,
|
| 20 |
+
const int ne10,
|
| 21 |
+
const int ne11,
|
| 22 |
+
const int ne12,
|
| 23 |
+
const int ne13,
|
| 24 |
+
const int ne31,
|
| 25 |
+
const int nb31,
|
| 26 |
+
const int nb01,
|
| 27 |
+
const int nb02,
|
| 28 |
+
const int nb03,
|
| 29 |
+
const int nb11,
|
| 30 |
+
const int nb12,
|
| 31 |
+
const int nb13,
|
| 32 |
+
const int nb21,
|
| 33 |
+
const int nb22,
|
| 34 |
+
const int nb23,
|
| 35 |
+
const int ne0,
|
| 36 |
+
const int ne1,
|
| 37 |
+
const int ne2,
|
| 38 |
+
const int ne3,
|
| 39 |
+
const int jt,
|
| 40 |
+
const int kb0_start,
|
| 41 |
+
const int kb0_stop) {
|
| 42 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 43 |
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 44 |
+
|
| 45 |
+
typedef mma_A_I16K8<half2> mma_A;
|
| 46 |
+
typedef mma_B_J8K8<half2> mma_B;
|
| 47 |
+
typedef mma_C_I16J8<float> mma_C_KQ;
|
| 48 |
+
typedef mma_C_I16J8<half2> mma_C_VKQ;
|
| 49 |
+
|
| 50 |
+
static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
|
| 51 |
+
constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
|
| 52 |
+
|
| 53 |
+
static_assert(D % nwarps == 0, "bad D");
|
| 54 |
+
static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
|
| 55 |
+
|
| 56 |
+
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
| 57 |
+
extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
|
| 58 |
+
|
| 59 |
+
const int stride_Q = nb01 / sizeof(float2);
|
| 60 |
+
const int stride_KV = nb11 / sizeof(half2);
|
| 61 |
+
const int stride_mask = nb31 / sizeof(half);
|
| 62 |
+
|
| 63 |
+
mma_B Q_B[D/(2*mma_B::K)];
|
| 64 |
+
mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
|
| 65 |
+
|
| 66 |
+
float2 KQ_rowsum = {0.0f, 0.0f};
|
| 67 |
+
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
|
| 68 |
+
float2 KQ_max_scale = {0.0f, 0.0f};
|
| 69 |
+
|
| 70 |
+
// Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
|
| 71 |
+
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
| 72 |
+
const half2 scale_h2 = make_half2(scale, scale);
|
| 73 |
+
#pragma unroll
|
| 74 |
+
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 75 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
| 76 |
+
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
| 77 |
+
const int stride_j = WARP_SIZE / stride_k;
|
| 78 |
+
|
| 79 |
+
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
|
| 80 |
+
break;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
#pragma unroll
|
| 84 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
|
| 85 |
+
const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 86 |
+
|
| 87 |
+
if (jt*ncols + j < ne01) {
|
| 88 |
+
#pragma unroll
|
| 89 |
+
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 90 |
+
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 91 |
+
|
| 92 |
+
const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
|
| 93 |
+
tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
| 94 |
+
}
|
| 95 |
+
} else {
|
| 96 |
+
#pragma unroll
|
| 97 |
+
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 98 |
+
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 99 |
+
|
| 100 |
+
tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
__syncthreads();
|
| 107 |
+
|
| 108 |
+
{
|
| 109 |
+
const int j0 = (threadIdx.y / np) * mma_B::J;
|
| 110 |
+
|
| 111 |
+
#pragma unroll
|
| 112 |
+
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
|
| 113 |
+
Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
__syncthreads();
|
| 118 |
+
|
| 119 |
+
// Iterate over ne11 == previous tokens:
|
| 120 |
+
for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
|
| 121 |
+
const int k_VKQ_0 = kb0*KQ_stride;
|
| 122 |
+
mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
|
| 123 |
+
|
| 124 |
+
// Load K data into tile with decreasing granularity for D for better memory bandwidth:
|
| 125 |
+
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
|
| 126 |
+
#pragma unroll
|
| 127 |
+
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 128 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
| 129 |
+
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
| 130 |
+
const int stride_i = WARP_SIZE / stride_k;
|
| 131 |
+
|
| 132 |
+
#pragma unroll
|
| 133 |
+
for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
|
| 134 |
+
const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 135 |
+
|
| 136 |
+
#pragma unroll
|
| 137 |
+
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
|
| 138 |
+
const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 139 |
+
|
| 140 |
+
tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
__syncthreads();
|
| 146 |
+
|
| 147 |
+
// Calculate tile of KQ:
|
| 148 |
+
#pragma unroll
|
| 149 |
+
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
|
| 150 |
+
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
|
| 151 |
+
#pragma unroll
|
| 152 |
+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
|
| 153 |
+
mma_A K_A;
|
| 154 |
+
K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
|
| 155 |
+
KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
__syncthreads();
|
| 160 |
+
|
| 161 |
+
if (use_logit_softcap) {
|
| 162 |
+
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
|
| 163 |
+
#pragma unroll
|
| 164 |
+
for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
|
| 165 |
+
#pragma unroll
|
| 166 |
+
for (int l = 0; l < mma_C_KQ::ne; ++l) {
|
| 167 |
+
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
if (maskh) {
|
| 173 |
+
static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
|
| 174 |
+
static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
|
| 175 |
+
#pragma unroll
|
| 176 |
+
for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
|
| 177 |
+
const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
|
| 178 |
+
#pragma unroll
|
| 179 |
+
for (int l = 0; l < mma_C_KQ::ne; ++l) {
|
| 180 |
+
const int i = i0 + mma_C_KQ::get_i(l);
|
| 181 |
+
const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
|
| 182 |
+
|
| 183 |
+
KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
// Calculate softmax for each KQ column using the current max. value.
|
| 189 |
+
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 190 |
+
float2 KQ_max_new = KQ_max;
|
| 191 |
+
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
|
| 192 |
+
#pragma unroll
|
| 193 |
+
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
|
| 194 |
+
#pragma unroll
|
| 195 |
+
for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
|
| 196 |
+
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
|
| 197 |
+
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
|
| 202 |
+
#pragma unroll
|
| 203 |
+
for (int offset = 16; offset > 2; offset >>= 1) {
|
| 204 |
+
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
|
| 205 |
+
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
{
|
| 209 |
+
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
|
| 210 |
+
KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
|
| 211 |
+
if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
|
| 212 |
+
KQ_max_scale.x = 0.0f;
|
| 213 |
+
}
|
| 214 |
+
if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
|
| 215 |
+
KQ_max_scale.y = 0.0f;
|
| 216 |
+
}
|
| 217 |
+
KQ_max = KQ_max_new;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
|
| 221 |
+
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
|
| 222 |
+
#pragma unroll
|
| 223 |
+
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
|
| 224 |
+
#pragma unroll
|
| 225 |
+
for (int l = 0; l < mma_C_KQ::ne; ++l) {
|
| 226 |
+
const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
|
| 227 |
+
const float diff = KQ_C[k].x[l] - KQ_max_l;
|
| 228 |
+
KQ_C[k].x[l] = expf(diff);
|
| 229 |
+
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
| 230 |
+
KQ_C[k].x[l] = 0.0f;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
if (l % 2 == 0) {
|
| 234 |
+
KQ_rowsum_add.x += KQ_C[k].x[l];
|
| 235 |
+
} else {
|
| 236 |
+
KQ_rowsum_add.y += KQ_C[k].x[l];
|
| 237 |
+
}
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 242 |
+
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
|
| 243 |
+
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
|
| 244 |
+
|
| 245 |
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
|
| 246 |
+
#pragma unroll
|
| 247 |
+
for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
|
| 248 |
+
#pragma unroll
|
| 249 |
+
for (int l = 0; l < mma_C_VKQ::ne; ++l) {
|
| 250 |
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
// Convert KQ C tiles into B tiles for VKQ calculation:
|
| 255 |
+
mma_B B[KQ_stride/(np*2*mma_B::K)];
|
| 256 |
+
static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
|
| 257 |
+
#pragma unroll
|
| 258 |
+
for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
|
| 259 |
+
B[k] = KQ_C[k].to_mma_B();
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
// Load V data into tile with decreasing granularity for D for better memory bandwidth:
|
| 263 |
+
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
|
| 264 |
+
#pragma unroll
|
| 265 |
+
for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 266 |
+
const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
|
| 267 |
+
const int i0_stop = D/2 - (D/2) % (1*stride_i);
|
| 268 |
+
const int stride_k = WARP_SIZE / stride_i;
|
| 269 |
+
|
| 270 |
+
#pragma unroll
|
| 271 |
+
for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
|
| 272 |
+
const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
|
| 273 |
+
|
| 274 |
+
#pragma unroll
|
| 275 |
+
for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
|
| 276 |
+
const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
|
| 277 |
+
|
| 278 |
+
tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
|
| 279 |
+
}
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
__syncthreads();
|
| 284 |
+
|
| 285 |
+
// Calculate VKQ tile:
|
| 286 |
+
#pragma unroll
|
| 287 |
+
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
|
| 288 |
+
static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
|
| 289 |
+
#pragma unroll
|
| 290 |
+
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
|
| 291 |
+
const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
|
| 292 |
+
|
| 293 |
+
mma_A A;
|
| 294 |
+
A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
|
| 295 |
+
VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
|
| 296 |
+
}
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
__syncthreads();
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
// Finally, sum up partial KQ rowsums.
|
| 303 |
+
// The partial sums are spread across 8 threads each, does not need full reduce.
|
| 304 |
+
#pragma unroll
|
| 305 |
+
for (int offset = 16; offset > 2; offset >>= 1) {
|
| 306 |
+
KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
|
| 307 |
+
KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
// Write VKQ accumulators to shared memory in column-major format.
|
| 311 |
+
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
| 312 |
+
// Also for np > 1 the combination is done via these values in shared memory.
|
| 313 |
+
const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
|
| 314 |
+
#pragma unroll
|
| 315 |
+
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
|
| 316 |
+
const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
|
| 317 |
+
|
| 318 |
+
#pragma unroll
|
| 319 |
+
for (int l = 0; l < mma_B::ne; ++l) {
|
| 320 |
+
const int k = k0 + mma_B::get_k(l);
|
| 321 |
+
|
| 322 |
+
tile_KV[j_cwd*D2_padded + k] = B.x[l];
|
| 323 |
+
}
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
|
| 327 |
+
const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
|
| 328 |
+
const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
|
| 329 |
+
|
| 330 |
+
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
|
| 331 |
+
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
| 332 |
+
((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
__syncthreads();
|
| 336 |
+
|
| 337 |
+
static_assert(np == 1 || np == 2 || np == 4, "bad np");
|
| 338 |
+
if (np == 1) {
|
| 339 |
+
// No combination is needed, the meta data can be directly written from registers to VRAM.
|
| 340 |
+
if (needs_fixup && threadIdx.x < mma_B::J) {
|
| 341 |
+
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
| 342 |
+
dstk_fixup_meta[j_cwm] = KQ_cmr;
|
| 343 |
+
}
|
| 344 |
+
if (is_fixup && threadIdx.x < mma_B::J) {
|
| 345 |
+
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
| 346 |
+
dstk_fixup_meta[j_cwm] = KQ_cmr;
|
| 347 |
+
}
|
| 348 |
+
} else if (threadIdx.y % np == 0) {
|
| 349 |
+
// Combine the meta data for parallel warps via shared memory.
|
| 350 |
+
// Warps with threadIdx.y % np != 0 must NOT return early.
|
| 351 |
+
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
| 352 |
+
|
| 353 |
+
float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
|
| 354 |
+
|
| 355 |
+
float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
|
| 356 |
+
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
|
| 357 |
+
KQ_cm = meta_j[0];
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
|
| 361 |
+
#pragma unroll
|
| 362 |
+
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
|
| 363 |
+
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
|
| 367 |
+
float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
|
| 368 |
+
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
|
| 369 |
+
KQ_crs = KQ_cms*meta_j[1];
|
| 370 |
+
}
|
| 371 |
+
#pragma unroll
|
| 372 |
+
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
|
| 373 |
+
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
// Write back combined meta data:
|
| 377 |
+
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
|
| 378 |
+
meta_j[0] = KQ_cmn; // Combined max. KQ values.
|
| 379 |
+
meta_j[1] = KQ_crs; // Combined KQ rowsums.
|
| 380 |
+
meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
|
| 381 |
+
}
|
| 382 |
+
if (needs_fixup && threadIdx.x < mma_B::J) {
|
| 383 |
+
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
| 384 |
+
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
| 385 |
+
}
|
| 386 |
+
if (is_fixup && threadIdx.x < mma_B::J) {
|
| 387 |
+
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
| 388 |
+
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
| 389 |
+
}
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
if (np > 1) {
|
| 393 |
+
__syncthreads();
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
if (np == 1 || threadIdx.y % np == 0) {
|
| 397 |
+
// The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
|
| 398 |
+
// The values after that are for the partial results of the individual blocks.
|
| 399 |
+
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
|
| 400 |
+
|
| 401 |
+
#pragma unroll
|
| 402 |
+
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 403 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
| 404 |
+
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
| 405 |
+
const int stride_j = WARP_SIZE / stride_k;
|
| 406 |
+
|
| 407 |
+
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
|
| 408 |
+
break;
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
#pragma unroll
|
| 412 |
+
for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
|
| 413 |
+
const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 414 |
+
const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
|
| 415 |
+
|
| 416 |
+
if (!is_fixup && jt*ncols + j_dst >= ne01) {
|
| 417 |
+
continue;
|
| 418 |
+
}
|
| 419 |
+
const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
|
| 420 |
+
#pragma unroll
|
| 421 |
+
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 422 |
+
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 423 |
+
|
| 424 |
+
float2 dstk_val = make_float2(0.0f, 0.0f);
|
| 425 |
+
#pragma unroll
|
| 426 |
+
for (int ip = 0; ip < np; ++ip) {
|
| 427 |
+
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
|
| 428 |
+
const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
|
| 429 |
+
dstk_val.x += dstk_val_add.x*KQ_crs;
|
| 430 |
+
dstk_val.y += dstk_val_add.y*KQ_crs;
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
if (!needs_fixup && !is_fixup) {
|
| 434 |
+
const float KQ_rowsum_j = meta_j[1];
|
| 435 |
+
dstk_val.x /= KQ_rowsum_j;
|
| 436 |
+
dstk_val.y /= KQ_rowsum_j;
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
if (is_fixup) {
|
| 440 |
+
dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
|
| 441 |
+
} else {
|
| 442 |
+
dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
if (np > 1) {
|
| 450 |
+
__syncthreads();
|
| 451 |
+
}
|
| 452 |
+
#else
|
| 453 |
+
NO_DEVICE_CODE;
|
| 454 |
+
#endif // NEW_MMA_AVAILABLE
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
|
| 458 |
+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 459 |
+
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
| 460 |
+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 461 |
+
static __global__ void flash_attn_ext_f16(
|
| 462 |
+
const char * __restrict__ Q,
|
| 463 |
+
const char * __restrict__ K,
|
| 464 |
+
const char * __restrict__ V,
|
| 465 |
+
const char * __restrict__ mask,
|
| 466 |
+
float * __restrict__ dst,
|
| 467 |
+
float2 * __restrict__ dst_meta,
|
| 468 |
+
const float scale,
|
| 469 |
+
const float max_bias,
|
| 470 |
+
const float m0,
|
| 471 |
+
const float m1,
|
| 472 |
+
const uint32_t n_head_log2,
|
| 473 |
+
const float logit_softcap,
|
| 474 |
+
const int ne00,
|
| 475 |
+
const int ne01,
|
| 476 |
+
const int ne02,
|
| 477 |
+
const int ne03,
|
| 478 |
+
const int ne10,
|
| 479 |
+
const int ne11,
|
| 480 |
+
const int ne12,
|
| 481 |
+
const int ne13,
|
| 482 |
+
const int ne31,
|
| 483 |
+
const int nb31,
|
| 484 |
+
const int nb01,
|
| 485 |
+
const int nb02,
|
| 486 |
+
const int nb03,
|
| 487 |
+
const int nb11,
|
| 488 |
+
const int nb12,
|
| 489 |
+
const int nb13,
|
| 490 |
+
const int nb21,
|
| 491 |
+
const int nb22,
|
| 492 |
+
const int nb23,
|
| 493 |
+
const int ne0,
|
| 494 |
+
const int ne1,
|
| 495 |
+
const int ne2,
|
| 496 |
+
const int ne3) {
|
| 497 |
+
// Skip unused kernel variants for faster compilation:
|
| 498 |
+
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
| 499 |
+
NO_DEVICE_CODE;
|
| 500 |
+
return;
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
|
| 504 |
+
|
| 505 |
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 506 |
+
|
| 507 |
+
const int iter_k = ne11 / KQ_stride;
|
| 508 |
+
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
| 509 |
+
|
| 510 |
+
// kbc == k block continuous, current index in continuous ijk space.
|
| 511 |
+
int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
| 512 |
+
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
| 513 |
+
|
| 514 |
+
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
| 515 |
+
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
| 516 |
+
// In the most general case >2 seams can fall into the same tile.
|
| 517 |
+
|
| 518 |
+
// kb0 == k start index when in the output tile.
|
| 519 |
+
int kb0_start = kbc % iter_k;
|
| 520 |
+
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
| 521 |
+
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
| 522 |
+
const int channel = kbc / (iter_k*iter_j);
|
| 523 |
+
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
| 524 |
+
|
| 525 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
|
| 526 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
|
| 527 |
+
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
|
| 528 |
+
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
|
| 529 |
+
float2 * dstk = ((float2 *) dst) + channel*(D/2);
|
| 530 |
+
|
| 531 |
+
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
|
| 532 |
+
|
| 533 |
+
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 534 |
+
if (kb0_start == 0) {
|
| 535 |
+
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
| 536 |
+
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
|
| 537 |
+
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
|
| 538 |
+
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
|
| 539 |
+
jt, kb0_start, kb0_stop);
|
| 540 |
+
} else {
|
| 541 |
+
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
| 542 |
+
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
|
| 543 |
+
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
|
| 544 |
+
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
|
| 545 |
+
jt, kb0_start, kb0_stop);
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
kbc += iter_k;
|
| 549 |
+
kbc -= kbc % iter_k;
|
| 550 |
+
|
| 551 |
+
kb0_start = 0;
|
| 552 |
+
kb0_stop = min(iter_k, kbc_stop - kbc);
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
if (kbc >= kbc_stop) {
|
| 556 |
+
return;
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
const int channel = kbc / (iter_k*iter_j);
|
| 560 |
+
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
| 561 |
+
|
| 562 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
|
| 563 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
|
| 564 |
+
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
|
| 565 |
+
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
|
| 566 |
+
float2 * dstk = ((float2 *) dst) + channel*(D/2);
|
| 567 |
+
|
| 568 |
+
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
|
| 569 |
+
|
| 570 |
+
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 571 |
+
constexpr bool needs_fixup = false;
|
| 572 |
+
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
|
| 573 |
+
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
|
| 574 |
+
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
|
| 575 |
+
jt, kb0_start, kb0_stop);
|
| 576 |
+
}
|
| 577 |
+
|
| 578 |
+
template <int D, int cols_per_block>
|
| 579 |
+
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 580 |
+
typedef mma_A_I16K8<half2> mma_A;
|
| 581 |
+
typedef mma_B_J8K8<half2> mma_B;
|
| 582 |
+
|
| 583 |
+
static_assert(D % mma_B::K == 0, "bad D");
|
| 584 |
+
static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
|
| 585 |
+
|
| 586 |
+
const ggml_tensor * KQV = dst;
|
| 587 |
+
|
| 588 |
+
constexpr int KQ_stride = D <= 128 ? 64 : 32;
|
| 589 |
+
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
|
| 590 |
+
cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
|
| 591 |
+
constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
|
| 592 |
+
|
| 593 |
+
float logit_softcap;
|
| 594 |
+
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 595 |
+
|
| 596 |
+
fattn_kernel_t fattn_kernel;
|
| 597 |
+
if (logit_softcap == 0.0f) {
|
| 598 |
+
constexpr bool use_logit_softcap = false;
|
| 599 |
+
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
|
| 600 |
+
} else {
|
| 601 |
+
constexpr bool use_logit_softcap = true;
|
| 602 |
+
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
|
| 603 |
+
}
|
| 604 |
+
launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \
|
| 608 |
+
template void ggml_cuda_flash_attn_ext_mma_f16_case \
|
| 609 |
+
<D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
| 610 |
+
|
| 611 |
+
extern DECL_FATTN_MMA_F16_CASE( 64, 8);
|
| 612 |
+
extern DECL_FATTN_MMA_F16_CASE( 80, 8);
|
| 613 |
+
extern DECL_FATTN_MMA_F16_CASE( 96, 8);
|
| 614 |
+
extern DECL_FATTN_MMA_F16_CASE(112, 8);
|
| 615 |
+
extern DECL_FATTN_MMA_F16_CASE(128, 8);
|
| 616 |
+
extern DECL_FATTN_MMA_F16_CASE(256, 8);
|
| 617 |
+
|
| 618 |
+
extern DECL_FATTN_MMA_F16_CASE( 64, 16);
|
| 619 |
+
extern DECL_FATTN_MMA_F16_CASE( 80, 16);
|
| 620 |
+
extern DECL_FATTN_MMA_F16_CASE( 96, 16);
|
| 621 |
+
extern DECL_FATTN_MMA_F16_CASE(112, 16);
|
| 622 |
+
extern DECL_FATTN_MMA_F16_CASE(128, 16);
|
| 623 |
+
extern DECL_FATTN_MMA_F16_CASE(256, 16);
|
| 624 |
+
|
| 625 |
+
extern DECL_FATTN_MMA_F16_CASE( 64, 32);
|
| 626 |
+
extern DECL_FATTN_MMA_F16_CASE( 80, 32);
|
| 627 |
+
extern DECL_FATTN_MMA_F16_CASE( 96, 32);
|
| 628 |
+
extern DECL_FATTN_MMA_F16_CASE(112, 32);
|
| 629 |
+
extern DECL_FATTN_MMA_F16_CASE(128, 32);
|
| 630 |
+
extern DECL_FATTN_MMA_F16_CASE(256, 32);
|
| 631 |
+
|
| 632 |
+
extern DECL_FATTN_MMA_F16_CASE( 64, 64);
|
| 633 |
+
extern DECL_FATTN_MMA_F16_CASE( 80, 64);
|
| 634 |
+
extern DECL_FATTN_MMA_F16_CASE( 96, 64);
|
| 635 |
+
extern DECL_FATTN_MMA_F16_CASE(112, 64);
|
| 636 |
+
extern DECL_FATTN_MMA_F16_CASE(128, 64);
|
| 637 |
+
extern DECL_FATTN_MMA_F16_CASE(256, 64);
|
ggml/src/ggml-cuda/fattn-tile-f16.cu
CHANGED
|
@@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
| 45 |
const int ne2,
|
| 46 |
const int ne3) {
|
| 47 |
#ifdef FP16_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
// Skip unused kernel variants for faster compilation:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
| 50 |
NO_DEVICE_CODE;
|
| 51 |
return;
|
|
@@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 288 |
const ggml_tensor * Q = dst->src[0];
|
| 289 |
switch (Q->ne[0]) {
|
| 290 |
case 64: {
|
| 291 |
-
constexpr int
|
| 292 |
-
constexpr int
|
|
|
|
| 293 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 294 |
-
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps,
|
| 295 |
} break;
|
| 296 |
case 128: {
|
| 297 |
-
constexpr int
|
| 298 |
-
constexpr int
|
|
|
|
| 299 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 300 |
-
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps,
|
| 301 |
} break;
|
| 302 |
default: {
|
| 303 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
|
|
|
| 45 |
const int ne2,
|
| 46 |
const int ne3) {
|
| 47 |
#ifdef FP16_AVAILABLE
|
| 48 |
+
|
| 49 |
+
#ifndef FLASH_ATTN_AVAILABLE
|
| 50 |
+
NO_DEVICE_CODE;
|
| 51 |
+
return;
|
| 52 |
+
#endif // FLASH_ATTN_AVAILABLE
|
| 53 |
+
|
| 54 |
// Skip unused kernel variants for faster compilation:
|
| 55 |
+
#ifdef FP16_MMA_AVAILABLE
|
| 56 |
+
NO_DEVICE_CODE;
|
| 57 |
+
return;
|
| 58 |
+
#endif // FP16_MMA_AVAILABLE
|
| 59 |
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
| 60 |
NO_DEVICE_CODE;
|
| 61 |
return;
|
|
|
|
| 298 |
const ggml_tensor * Q = dst->src[0];
|
| 299 |
switch (Q->ne[0]) {
|
| 300 |
case 64: {
|
| 301 |
+
constexpr int D = 64;
|
| 302 |
+
constexpr int nwarps = 8;
|
| 303 |
+
constexpr size_t nbytes_shared = 0;
|
| 304 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 305 |
+
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 306 |
} break;
|
| 307 |
case 128: {
|
| 308 |
+
constexpr int D = 128;
|
| 309 |
+
constexpr int nwarps = 8;
|
| 310 |
+
constexpr size_t nbytes_shared = 0;
|
| 311 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 312 |
+
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 313 |
} break;
|
| 314 |
default: {
|
| 315 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
ggml/src/ggml-cuda/fattn-tile-f32.cu
CHANGED
@@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32(
     NO_DEVICE_CODE;
     return;
 #endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
@@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
-            constexpr int
-            constexpr int
+            constexpr int    D             = 64;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps,
+            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
-            constexpr int
-            constexpr int
+            constexpr int    D             = 128;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps,
+            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
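In both tile kernels the launch site now passes cols_per_block and a KQ stride as template arguments and an explicit shared-memory byte count (nbytes_shared) as a runtime argument, replacing the old launch_fattn<D, parallel_blocks>(...) form seen in the removed lines. The sketch below reproduces only that dispatch shape with dummy types; kernel_t, launch() and tile_kernel_stub are inventions of the sketch, used to make the per-head-size constexpr selection explicit.

// Host-only sketch of the head-size dispatch pattern (compiles with nvcc or g++).
#include <cstddef>
#include <cstdio>

using kernel_t = void (*)(int);  // stand-in for ggml's fattn_kernel_t

template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
void launch(kernel_t kernel, int nwarps, std::size_t nbytes_shared) {
    // launch_fattn would configure the CUDA grid and shared memory here;
    // the sketch only records which compile-time configuration was chosen.
    std::printf("D=%d cols_per_block=%d parallel_blocks=%d KQ_stride=%d nwarps=%d nbytes_shared=%zu\n",
                D, cols_per_block, parallel_blocks, KQ_stride, nwarps, nbytes_shared);
    kernel(D);
}

template <int D>
void tile_kernel_stub(int /*head_size*/) {
    // a real build would run flash_attn_tile_ext_f32<D, ...> on the device here
}

template <int cols_per_block, int parallel_blocks>
void dispatch_by_head_size(int head_size) {
    switch (head_size) {
        case 64: {
            constexpr int         D             = 64;
            constexpr int         nwarps        = 8;
            constexpr std::size_t nbytes_shared = 0;
            launch<D, cols_per_block, parallel_blocks, -1>(tile_kernel_stub<D>, nwarps, nbytes_shared);
        } break;
        case 128: {
            constexpr int         D             = 128;
            constexpr int         nwarps        = 8;
            constexpr std::size_t nbytes_shared = 0;
            launch<D, cols_per_block, parallel_blocks, -1>(tile_kernel_stub<D>, nwarps, nbytes_shared);
        } break;
        default:
            std::printf("unsupported head size %d\n", head_size);
    }
}

int main() {
    dispatch_by_head_size</*cols_per_block=*/32, /*parallel_blocks=*/4>(64);
    dispatch_by_head_size<32, 4>(128);
    return 0;
}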
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
@@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
     const int ne2,
     const int ne3) {
 #ifdef FP16_AVAILABLE
+
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED
@@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32(
     const int ne1,
     const int ne2,
     const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
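Both vector kernels keep the existing rule that logit-softcap variants are only instantiated for head sizes 128 and 256. As a reading aid: where softcapping is enabled, the CUDA kernels later in this commit multiply a tanh of the already-scaled KQ value by logit_softcap, so the capped score is bounded by plus or minus logit_softcap; on the half2 path there is no h2tanh intrinsic, so the WMMA kernel rebuilds tanh from h2exp via the standard identity

    \tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1}

which is exactly what the h2exp followed by (t - 1)/(t + 1) sequence in that kernel computes.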
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ADDED
|
@@ -0,0 +1,648 @@
| 1 |
+
// Old and deprecated WMMA FlashAttention implementation.
|
| 2 |
+
// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
|
| 3 |
+
// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
|
| 4 |
+
|
| 5 |
+
#include "common.cuh"
|
| 6 |
+
#include "fattn-common.cuh"
|
| 7 |
+
#include "fattn-wmma-f16.cuh"
|
| 8 |
+
|
| 9 |
+
#ifdef FP16_MMA_AVAILABLE
|
| 10 |
+
#include <mma.h>
|
| 11 |
+
#endif // FP16_MMA_AVAILABLE
|
| 12 |
+
|
| 13 |
+
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
| 14 |
+
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
| 15 |
+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 16 |
+
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 17 |
+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 18 |
+
static __global__ void flash_attn_ext_f16(
|
| 19 |
+
const char * __restrict__ Q,
|
| 20 |
+
const char * __restrict__ K,
|
| 21 |
+
const char * __restrict__ V,
|
| 22 |
+
const char * __restrict__ mask,
|
| 23 |
+
float * __restrict__ dst,
|
| 24 |
+
float2 * __restrict__ dst_meta,
|
| 25 |
+
const float scale,
|
| 26 |
+
const float max_bias,
|
| 27 |
+
const float m0,
|
| 28 |
+
const float m1,
|
| 29 |
+
const uint32_t n_head_log2,
|
| 30 |
+
const float logit_softcap,
|
| 31 |
+
const int ne00,
|
| 32 |
+
const int ne01,
|
| 33 |
+
const int ne02,
|
| 34 |
+
const int ne03,
|
| 35 |
+
const int ne10,
|
| 36 |
+
const int ne11,
|
| 37 |
+
const int ne12,
|
| 38 |
+
const int ne13,
|
| 39 |
+
const int ne31,
|
| 40 |
+
const int nb31,
|
| 41 |
+
const int nb01,
|
| 42 |
+
const int nb02,
|
| 43 |
+
const int nb03,
|
| 44 |
+
const int nb11,
|
| 45 |
+
const int nb12,
|
| 46 |
+
const int nb13,
|
| 47 |
+
const int nb21,
|
| 48 |
+
const int nb22,
|
| 49 |
+
const int nb23,
|
| 50 |
+
const int ne0,
|
| 51 |
+
const int ne1,
|
| 52 |
+
const int ne2,
|
| 53 |
+
const int ne3) {
|
| 54 |
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
| 55 |
+
// Skip unused kernel variants for faster compilation:
|
| 56 |
+
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
| 57 |
+
NO_DEVICE_CODE;
|
| 58 |
+
return;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 62 |
+
|
| 63 |
+
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
| 64 |
+
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 65 |
+
|
| 66 |
+
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
| 67 |
+
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
| 68 |
+
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
| 69 |
+
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
| 70 |
+
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
| 71 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
| 72 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
| 73 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
| 74 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
| 75 |
+
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
| 76 |
+
|
| 77 |
+
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
| 78 |
+
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
| 79 |
+
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
|
| 80 |
+
|
| 81 |
+
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
|
| 82 |
+
constexpr int D_padded = D + 8;
|
| 83 |
+
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
| 84 |
+
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
| 85 |
+
|
| 86 |
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 87 |
+
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
| 88 |
+
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
|
| 89 |
+
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
| 90 |
+
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
| 91 |
+
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
| 92 |
+
|
| 93 |
+
const int stride_Q = nb01 / sizeof(float);
|
| 94 |
+
const int stride_KV = nb11 / sizeof(half);
|
| 95 |
+
|
| 96 |
+
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
| 97 |
+
const half slopeh = __float2half(slopef);
|
| 98 |
+
const half2 slope2 = make_half2(slopef, slopef);
|
| 99 |
+
|
| 100 |
+
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
|
| 101 |
+
|
| 102 |
+
frag_b Q_b[D/16][ncols/frag_n];
|
| 103 |
+
|
| 104 |
+
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
| 105 |
+
constexpr int mem_KQ = ncols*kqs_padded*kqar;
|
| 106 |
+
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
|
| 107 |
+
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
|
| 108 |
+
float * KQ_f = (float *) KQ;
|
| 109 |
+
half2 * KQ2 = (half2 *) KQ;
|
| 110 |
+
|
| 111 |
+
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
|
| 112 |
+
float KQ_max_f[ncols/nwarps];
|
| 113 |
+
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
|
| 114 |
+
|
| 115 |
+
#pragma unroll
|
| 116 |
+
for (int j = 0; j < ncols/nwarps; ++j) {
|
| 117 |
+
KQ_max_f[j] = -FLT_MAX/2.0f;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
| 121 |
+
half2 KQ_max_h2[ncols/nwarps];
|
| 122 |
+
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
| 123 |
+
|
| 124 |
+
#pragma unroll
|
| 125 |
+
for (int j = 0; j < ncols/nwarps; ++j) {
|
| 126 |
+
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
| 130 |
+
half2 * VKQ2 = (half2 *) VKQ;
|
| 131 |
+
#pragma unroll
|
| 132 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 133 |
+
const int j = j0 + threadIdx.y;
|
| 134 |
+
#pragma unroll
|
| 135 |
+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 136 |
+
const int i = i0 + threadIdx.x;
|
| 137 |
+
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
| 138 |
+
break;
|
| 139 |
+
}
|
| 140 |
+
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// Convert Q to half and apply scale, temporarily store in KQ:
|
| 145 |
+
#pragma unroll
|
| 146 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 147 |
+
const int j = j0 + threadIdx.y;
|
| 148 |
+
#pragma unroll
|
| 149 |
+
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
| 150 |
+
const int i = i0 + threadIdx.x;
|
| 151 |
+
if (i0 + WARP_SIZE > D && i >= D) {
|
| 152 |
+
break;
|
| 153 |
+
}
|
| 154 |
+
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
__syncthreads();
|
| 159 |
+
|
| 160 |
+
// Load Q into tensor core fragments/registers since it will be used frequently:
|
| 161 |
+
#pragma unroll
|
| 162 |
+
for (int i0 = 0; i0 < D; i0 += 16) {
|
| 163 |
+
#pragma unroll
|
| 164 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 165 |
+
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
__syncthreads();
|
| 170 |
+
|
| 171 |
+
// Iterate over ne11 == previous tokens:
|
| 172 |
+
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
|
| 173 |
+
// Calculate tile of KQ:
|
| 174 |
+
#pragma unroll
|
| 175 |
+
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
| 176 |
+
frag_c_KQ KQ_c[ncols/frag_n];
|
| 177 |
+
#pragma unroll
|
| 178 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 179 |
+
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
| 180 |
+
}
|
| 181 |
+
#pragma unroll
|
| 182 |
+
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
| 183 |
+
frag_a_K K_a;
|
| 184 |
+
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
| 185 |
+
#pragma unroll
|
| 186 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 187 |
+
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
#pragma unroll
|
| 191 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 192 |
+
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
__syncthreads();
|
| 197 |
+
|
| 198 |
+
// Calculate softmax for each KQ column using the current max. value.
|
| 199 |
+
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 200 |
+
#pragma unroll
|
| 201 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 202 |
+
const int j = j0 + threadIdx.y;
|
| 203 |
+
|
| 204 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 205 |
+
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
|
| 206 |
+
#pragma unroll
|
| 207 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 208 |
+
const int k = k0 + threadIdx.x;
|
| 209 |
+
|
| 210 |
+
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
| 211 |
+
|
| 212 |
+
if (use_logit_softcap) {
|
| 213 |
+
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
float KQ_max_new = KQ_max_f[j0/nwarps];
|
| 218 |
+
#pragma unroll
|
| 219 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 220 |
+
const int k = k0 + threadIdx.x;
|
| 221 |
+
|
| 222 |
+
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
| 223 |
+
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
|
| 224 |
+
}
|
| 225 |
+
KQ_max_new = warp_reduce_max(KQ_max_new);
|
| 226 |
+
|
| 227 |
+
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
|
| 228 |
+
KQ_max_scale_f[j0/nwarps] = expf(diff);
|
| 229 |
+
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
| 230 |
+
KQ_max_scale_f[j0/nwarps] = 0.0f;
|
| 231 |
+
}
|
| 232 |
+
KQ_max_f[j0/nwarps] = KQ_max_new;
|
| 233 |
+
|
| 234 |
+
float KQ_rowsum_add = 0.0f;
|
| 235 |
+
#pragma unroll
|
| 236 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 237 |
+
const int k = k0 + threadIdx.x;
|
| 238 |
+
|
| 239 |
+
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
|
| 240 |
+
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
|
| 241 |
+
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
| 242 |
+
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
|
| 243 |
+
}
|
| 244 |
+
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
|
| 245 |
+
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
|
| 246 |
+
}
|
| 247 |
+
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
| 248 |
+
|
| 249 |
+
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 250 |
+
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
|
| 251 |
+
} else {
|
| 252 |
+
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
| 253 |
+
#pragma unroll
|
| 254 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 255 |
+
const int k = k0 + threadIdx.x;
|
| 256 |
+
|
| 257 |
+
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
| 258 |
+
|
| 259 |
+
if (use_logit_softcap) {
|
| 260 |
+
// There is no dedicated tangens hyperbolicus function for half2.
|
| 261 |
+
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
| 262 |
+
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
| 263 |
+
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
| 264 |
+
|
| 265 |
+
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
| 270 |
+
#pragma unroll
|
| 271 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 272 |
+
const int k = k0 + threadIdx.x;
|
| 273 |
+
|
| 274 |
+
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
| 275 |
+
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
| 276 |
+
}
|
| 277 |
+
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
| 278 |
+
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
| 279 |
+
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
| 280 |
+
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
| 281 |
+
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
| 282 |
+
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
| 283 |
+
|
| 284 |
+
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
| 285 |
+
#pragma unroll
|
| 286 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 287 |
+
const int k = k0 + threadIdx.x;
|
| 288 |
+
|
| 289 |
+
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
| 290 |
+
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
| 291 |
+
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
| 292 |
+
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
| 293 |
+
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
| 294 |
+
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
| 295 |
+
}
|
| 296 |
+
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
| 297 |
+
|
| 298 |
+
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 299 |
+
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
__syncthreads();
|
| 304 |
+
|
| 305 |
+
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
|
| 306 |
+
#pragma unroll
|
| 307 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 308 |
+
#pragma unroll
|
| 309 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
| 310 |
+
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
| 311 |
+
nvcuda::wmma::load_matrix_sync(
|
| 312 |
+
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
| 313 |
+
KQ + j0*(kqar*kqs_padded) + k,
|
| 314 |
+
kqar*kqs_padded);
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
|
| 319 |
+
#pragma unroll
|
| 320 |
+
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
| 321 |
+
#pragma unroll
|
| 322 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 323 |
+
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
#pragma unroll
|
| 327 |
+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
| 328 |
+
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
| 329 |
+
|
| 330 |
+
frag_a_V v_a;
|
| 331 |
+
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
| 332 |
+
#pragma unroll
|
| 333 |
+
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 334 |
+
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
__syncthreads();
|
| 340 |
+
|
| 341 |
+
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
| 342 |
+
#pragma unroll
|
| 343 |
+
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
| 344 |
+
#pragma unroll
|
| 345 |
+
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 346 |
+
nvcuda::wmma::store_matrix_sync(
|
| 347 |
+
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
| 348 |
+
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
| 349 |
+
D_padded, nvcuda::wmma::mem_col_major);
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
__syncthreads();
|
| 354 |
+
|
| 355 |
+
#pragma unroll
|
| 356 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 357 |
+
const int j = j0 + threadIdx.y;
|
| 358 |
+
|
| 359 |
+
half2 VKQ_scale;
|
| 360 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 361 |
+
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
|
| 362 |
+
} else {
|
| 363 |
+
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
#pragma unroll
|
| 367 |
+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 368 |
+
const int i = i0 + threadIdx.x;
|
| 369 |
+
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
| 370 |
+
break;
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
half2 VKQ_add = make_half2(0.0f, 0.0f);
|
| 374 |
+
#pragma unroll
|
| 375 |
+
for (int l = 0; l < VKQ_ratio; ++l) {
|
| 376 |
+
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
|
| 377 |
+
}
|
| 378 |
+
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
__syncthreads();
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
#pragma unroll
|
| 386 |
+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 387 |
+
const int j_VKQ = j0 + threadIdx.y;
|
| 388 |
+
if (ic0 + j_VKQ >= ne01) {
|
| 389 |
+
return;
|
| 390 |
+
}
|
| 391 |
+
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
| 392 |
+
|
| 393 |
+
float KQ_rowsum_j;
|
| 394 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 395 |
+
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
|
| 396 |
+
} else {
|
| 397 |
+
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
#pragma unroll
|
| 401 |
+
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
| 402 |
+
const int i = i0 + threadIdx.x;
|
| 403 |
+
if (i0 + WARP_SIZE > D && i >= D) {
|
| 404 |
+
break;
|
| 405 |
+
}
|
| 406 |
+
float dst_val = VKQ[j_VKQ*D_padded + i];
|
| 407 |
+
if (parallel_blocks == 1) {
|
| 408 |
+
dst_val /= KQ_rowsum_j;
|
| 409 |
+
}
|
| 410 |
+
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
if (parallel_blocks == 1 || threadIdx.x != 0) {
|
| 414 |
+
continue;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
float2 dst_meta_val;
|
| 418 |
+
if (std::is_same<KQ_acc_t, float>::value) {
|
| 419 |
+
dst_meta_val.x = KQ_max_f[j0/nwarps];
|
| 420 |
+
} else {
|
| 421 |
+
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
| 422 |
+
}
|
| 423 |
+
dst_meta_val.y = KQ_rowsum_j;
|
| 424 |
+
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
|
| 425 |
+
}
|
| 426 |
+
#else
|
| 427 |
+
NO_DEVICE_CODE;
|
| 428 |
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
constexpr int get_max_power_of_2(int x) {
|
| 432 |
+
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
|
| 436 |
+
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
|
| 437 |
+
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
|
| 438 |
+
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
|
| 439 |
+
|
| 440 |
+
// Number of VKQ rows calculated in parallel:
|
| 441 |
+
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
|
| 442 |
+
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
|
| 446 |
+
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
|
| 447 |
+
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
|
| 448 |
+
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
|
| 449 |
+
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
|
| 450 |
+
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
|
| 451 |
+
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
|
| 452 |
+
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
| 453 |
+
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
| 454 |
+
|
| 455 |
+
template <int D, int cols_per_block, typename KQ_acc_t>
|
| 456 |
+
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 457 |
+
const ggml_tensor * KQV = dst;
|
| 458 |
+
const ggml_tensor * Q = dst->src[0];
|
| 459 |
+
|
| 460 |
+
constexpr int nwarps = 4;
|
| 461 |
+
|
| 462 |
+
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
|
| 463 |
+
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
| 464 |
+
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
| 465 |
+
|
| 466 |
+
float logit_softcap;
|
| 467 |
+
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
| 468 |
+
|
| 469 |
+
if (4*blocks_num_pb1 < 2*nsm) {
|
| 470 |
+
constexpr int parallel_blocks = 4;
|
| 471 |
+
fattn_kernel_t fattn_kernel;
|
| 472 |
+
if (logit_softcap == 0.0f) {
|
| 473 |
+
constexpr bool use_logit_softcap = false;
|
| 474 |
+
fattn_kernel = flash_attn_ext_f16<
|
| 475 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 476 |
+
} else {
|
| 477 |
+
constexpr bool use_logit_softcap = true;
|
| 478 |
+
fattn_kernel = flash_attn_ext_f16<
|
| 479 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 480 |
+
}
|
| 481 |
+
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 482 |
+
return;
|
| 483 |
+
}
|
| 484 |
+
if (2*blocks_num_pb1 < 2*nsm) {
|
| 485 |
+
constexpr int parallel_blocks = 2;
|
| 486 |
+
fattn_kernel_t fattn_kernel;
|
| 487 |
+
if (logit_softcap == 0.0f) {
|
| 488 |
+
constexpr bool use_logit_softcap = false;
|
| 489 |
+
fattn_kernel = flash_attn_ext_f16<
|
| 490 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 491 |
+
} else {
|
| 492 |
+
constexpr bool use_logit_softcap = true;
|
| 493 |
+
fattn_kernel = flash_attn_ext_f16<
|
| 494 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 495 |
+
}
|
| 496 |
+
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 497 |
+
return;
|
| 498 |
+
}
|
| 499 |
+
constexpr int parallel_blocks = 1;
|
| 500 |
+
fattn_kernel_t fattn_kernel;
|
| 501 |
+
if (logit_softcap == 0.0f) {
|
| 502 |
+
constexpr bool use_logit_softcap = false;
|
| 503 |
+
fattn_kernel = flash_attn_ext_f16<
|
| 504 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 505 |
+
} else {
|
| 506 |
+
constexpr bool use_logit_softcap = true;
|
| 507 |
+
fattn_kernel = flash_attn_ext_f16<
|
| 508 |
+
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 509 |
+
}
|
| 510 |
+
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 514 |
+
const ggml_tensor * KQV = dst;
|
| 515 |
+
const ggml_tensor * Q = dst->src[0];
|
| 516 |
+
|
| 517 |
+
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
| 518 |
+
|
| 519 |
+
if (prec != GGML_PREC_DEFAULT) {
|
| 520 |
+
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
| 521 |
+
constexpr int cols_per_block = 16;
|
| 522 |
+
switch (Q->ne[0]) {
|
| 523 |
+
case 64:
|
| 524 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
| 525 |
+
break;
|
| 526 |
+
case 80:
|
| 527 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
| 528 |
+
break;
|
| 529 |
+
case 96:
|
| 530 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
| 531 |
+
break;
|
| 532 |
+
case 112:
|
| 533 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
| 534 |
+
break;
|
| 535 |
+
case 128:
|
| 536 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
| 537 |
+
break;
|
| 538 |
+
case 256:
|
| 539 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
|
| 540 |
+
break;
|
| 541 |
+
default:
|
| 542 |
+
GGML_ABORT("fatal error");
|
| 543 |
+
break;
|
| 544 |
+
}
|
| 545 |
+
} else {
|
| 546 |
+
constexpr int cols_per_block = 32;
|
| 547 |
+
switch (Q->ne[0]) {
|
| 548 |
+
case 64:
|
| 549 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
|
| 550 |
+
break;
|
| 551 |
+
case 80:
|
| 552 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
|
| 553 |
+
break;
|
| 554 |
+
case 96:
|
| 555 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
|
| 556 |
+
break;
|
| 557 |
+
case 112:
|
| 558 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
|
| 559 |
+
break;
|
| 560 |
+
case 128:
|
| 561 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
| 562 |
+
break;
|
| 563 |
+
// case 256:
|
| 564 |
+
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
|
| 565 |
+
// break;
|
| 566 |
+
default:
|
| 567 |
+
GGML_ABORT("fatal error");
|
| 568 |
+
break;
|
| 569 |
+
}
|
| 570 |
+
}
|
| 571 |
+
return;
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
|
| 575 |
+
constexpr int cols_per_block = 8;
|
| 576 |
+
switch (Q->ne[0]) {
|
| 577 |
+
case 64:
|
| 578 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
| 579 |
+
break;
|
| 580 |
+
case 96:
|
| 581 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
| 582 |
+
break;
|
| 583 |
+
case 128:
|
| 584 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
| 585 |
+
break;
|
| 586 |
+
case 256:
|
| 587 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
| 588 |
+
break;
|
| 589 |
+
default:
|
| 590 |
+
GGML_ABORT("fatal error");
|
| 591 |
+
break;
|
| 592 |
+
}
|
| 593 |
+
return;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
if (Q->ne[1] <= 32) {
|
| 597 |
+
constexpr int cols_per_block = 16;
|
| 598 |
+
switch (Q->ne[0]) {
|
| 599 |
+
case 64:
|
| 600 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
| 601 |
+
break;
|
| 602 |
+
case 80:
|
| 603 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
| 604 |
+
break;
|
| 605 |
+
case 96:
|
| 606 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
| 607 |
+
break;
|
| 608 |
+
case 112:
|
| 609 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
| 610 |
+
break;
|
| 611 |
+
case 128:
|
| 612 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
| 613 |
+
break;
|
| 614 |
+
case 256:
|
| 615 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
| 616 |
+
break;
|
| 617 |
+
default:
|
| 618 |
+
GGML_ABORT("fatal error");
|
| 619 |
+
break;
|
| 620 |
+
}
|
| 621 |
+
return;
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
constexpr int cols_per_block = 32;
|
| 625 |
+
switch (Q->ne[0]) {
|
| 626 |
+
case 64:
|
| 627 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
|
| 628 |
+
break;
|
| 629 |
+
case 80:
|
| 630 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
|
| 631 |
+
break;
|
| 632 |
+
case 96:
|
| 633 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
|
| 634 |
+
break;
|
| 635 |
+
case 112:
|
| 636 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
|
| 637 |
+
break;
|
| 638 |
+
case 128:
|
| 639 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
|
| 640 |
+
break;
|
| 641 |
+
case 256:
|
| 642 |
+
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
|
| 643 |
+
break;
|
| 644 |
+
default:
|
| 645 |
+
GGML_ABORT("fatal error");
|
| 646 |
+
break;
|
| 647 |
+
}
|
| 648 |
+
}
|
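The numerically interesting part of the WMMA kernel above is the blockwise (online) softmax bookkeeping: KQ_max holds the running row maximum, KQ_max_scale = exp(old_max - new_max) rescales the previously accumulated KQ_rowsum and VKQ whenever the maximum grows, and the divisor KQ_rowsum is applied only once at the end. Below is a small stand-alone sketch of that bookkeeping with plain floats (head size 1, invented toy numbers); it illustrates the technique and is not ggml code.

// Stand-alone sketch of the online-softmax bookkeeping (KQ_max / KQ_max_scale /
// KQ_rowsum / VKQ) used by the FlashAttention kernels, on plain floats.
// Toy data; head size 1; the KV sequence is processed in tiles of 2.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> s = {0.3f, 2.1f, -1.0f, 0.7f, 1.5f, -0.2f}; // scores of one Q row
    const std::vector<float> v = {1.0f, 2.0f,  3.0f, 4.0f, 5.0f,  6.0f}; // matching V entries
    const std::size_t tile = 2;                                          // like FATTN_KQ_STRIDE

    float m = -INFINITY; // running maximum             (KQ_max)
    float r = 0.0f;      // running softmax denominator (KQ_rowsum)
    float o = 0.0f;      // running unnormalized output (VKQ)

    for (std::size_t k0 = 0; k0 < s.size(); k0 += tile) {
        float m_new = m;
        for (std::size_t k = k0; k < k0 + tile; ++k) {
            m_new = std::fmax(m_new, s[k]);
        }
        const float scale = std::exp(m - m_new); // KQ_max_scale: rescales what was accumulated so far
        float r_add = 0.0f;
        float o_add = 0.0f;
        for (std::size_t k = k0; k < k0 + tile; ++k) {
            const float p = std::exp(s[k] - m_new);
            r_add += p;
            o_add += p*v[k];
        }
        m = m_new;
        r = scale*r + r_add; // "Scale previous KQ_rowsum to account for a potential increase in KQ_max"
        o = scale*o + o_add; // the VKQ accumulator gets the same rescaling
    }
    std::printf("online softmax result: %f\n", o/r); // divisor applied once at the end

    // Reference: direct softmax over the whole row.
    float m_ref = -INFINITY;
    for (float x : s) m_ref = std::fmax(m_ref, x);
    float r_ref = 0.0f, o_ref = 0.0f;
    for (std::size_t k = 0; k < s.size(); ++k) {
        const float p = std::exp(s[k] - m_ref);
        r_ref += p;
        o_ref += p*v[k];
    }
    std::printf("direct softmax result: %f\n", o_ref/r_ref);
    return 0;
}

Both prints agree to float precision, which is the invariant the rescaling lines in the kernel maintain across KV tiles.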
ggml/src/ggml-cuda/fattn-wmma-f16.cuh
CHANGED
|
@@ -1,543 +1,3 @@
|
|
| 1 |
#include "common.cuh"
|
| 2 |
-
#include "fattn-common.cuh"
|
| 3 |
|
| 4 |
-
|
| 5 |
-
#include <mma.h>
|
| 6 |
-
#endif // FP16_MMA_AVAILABLE
|
| 7 |
-
|
| 8 |
-
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
| 9 |
-
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
| 10 |
-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 11 |
-
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
| 12 |
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 13 |
-
static __global__ void flash_attn_ext_f16(
|
| 14 |
-
const char * __restrict__ Q,
|
| 15 |
-
const char * __restrict__ K,
|
| 16 |
-
const char * __restrict__ V,
|
| 17 |
-
const char * __restrict__ mask,
|
| 18 |
-
float * __restrict__ dst,
|
| 19 |
-
float2 * __restrict__ dst_meta,
|
| 20 |
-
const float scale,
|
| 21 |
-
const float max_bias,
|
| 22 |
-
const float m0,
|
| 23 |
-
const float m1,
|
| 24 |
-
const uint32_t n_head_log2,
|
| 25 |
-
const float logit_softcap,
|
| 26 |
-
const int ne00,
|
| 27 |
-
const int ne01,
|
| 28 |
-
const int ne02,
|
| 29 |
-
const int ne03,
|
| 30 |
-
const int ne10,
|
| 31 |
-
const int ne11,
|
| 32 |
-
const int ne12,
|
| 33 |
-
const int ne13,
|
| 34 |
-
const int ne31,
|
| 35 |
-
const int nb31,
|
| 36 |
-
const int nb01,
|
| 37 |
-
const int nb02,
|
| 38 |
-
const int nb03,
|
| 39 |
-
const int nb11,
|
| 40 |
-
const int nb12,
|
| 41 |
-
const int nb13,
|
| 42 |
-
const int nb21,
|
| 43 |
-
const int nb22,
|
| 44 |
-
const int nb23,
|
| 45 |
-
const int ne0,
|
| 46 |
-
const int ne1,
|
| 47 |
-
const int ne2,
|
| 48 |
-
const int ne3) {
|
| 49 |
-
#ifdef FP16_MMA_AVAILABLE
|
| 50 |
-
// Skip unused kernel variants for faster compilation:
|
| 51 |
-
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
| 52 |
-
NO_DEVICE_CODE;
|
| 53 |
-
return;
|
| 54 |
-
}
|
| 55 |
-
|
| 56 |
-
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 57 |
-
|
| 58 |
-
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
| 59 |
-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 60 |
-
|
| 61 |
-
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
| 62 |
-
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
| 63 |
-
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
| 64 |
-
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
| 65 |
-
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
| 66 |
-
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
|
| 67 |
-
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
|
| 68 |
-
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
| 69 |
-
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
| 70 |
-
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
| 71 |
-
|
| 72 |
-
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
| 73 |
-
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
| 74 |
-
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
|
| 75 |
-
|
| 76 |
-
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
|
| 77 |
-
constexpr int D_padded = D + 8;
|
| 78 |
-
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
| 79 |
-
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
| 80 |
-
|
| 81 |
-
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 82 |
-
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
|
| 83 |
-
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
|
| 84 |
-
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
| 85 |
-
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
| 86 |
-
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
| 87 |
-
|
| 88 |
-
const int stride_Q = nb01 / sizeof(float);
|
| 89 |
-
const int stride_KV = nb11 / sizeof(half);
|
| 90 |
-
|
| 91 |
-
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
| 92 |
-
const half slopeh = __float2half(slopef);
|
| 93 |
-
const half2 slope2 = make_half2(slopef, slopef);
|
| 94 |
-
|
| 95 |
-
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
|
| 96 |
-
|
| 97 |
-
frag_b Q_b[D/16][ncols/frag_n];
|
| 98 |
-
|
| 99 |
-
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
| 100 |
-
constexpr int mem_KQ = ncols*kqs_padded*kqar;
|
| 101 |
-
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
|
| 102 |
-
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
|
| 103 |
-
float * KQ_f = (float *) KQ;
|
| 104 |
-
half2 * KQ2 = (half2 *) KQ;
|
| 105 |
-
|
| 106 |
-
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
|
| 107 |
-
float KQ_max_f[ncols/nwarps];
|
| 108 |
-
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
|
| 109 |
-
|
| 110 |
-
#pragma unroll
|
| 111 |
-
for (int j = 0; j < ncols/nwarps; ++j) {
|
| 112 |
-
KQ_max_f[j] = -FLT_MAX/2.0f;
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
| 116 |
-
half2 KQ_max_h2[ncols/nwarps];
|
| 117 |
-
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
|
| 118 |
-
|
| 119 |
-
#pragma unroll
|
| 120 |
-
for (int j = 0; j < ncols/nwarps; ++j) {
|
| 121 |
-
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
| 125 |
-
half2 * VKQ2 = (half2 *) VKQ;
|
| 126 |
-
#pragma unroll
|
| 127 |
-
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 128 |
-
const int j = j0 + threadIdx.y;
|
| 129 |
-
#pragma unroll
|
| 130 |
-
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 131 |
-
const int i = i0 + threadIdx.x;
|
| 132 |
-
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
| 133 |
-
break;
|
| 134 |
-
}
|
| 135 |
-
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
| 136 |
-
}
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
// Convert Q to half and apply scale, temporarily store in KQ:
|
| 140 |
-
#pragma unroll
|
| 141 |
-
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 142 |
-
const int j = j0 + threadIdx.y;
|
| 143 |
-
#pragma unroll
|
| 144 |
-
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
| 145 |
-
const int i = i0 + threadIdx.x;
|
| 146 |
-
if (i0 + WARP_SIZE > D && i >= D) {
|
| 147 |
-
break;
|
| 148 |
-
}
|
| 149 |
-
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
| 150 |
-
}
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
__syncthreads();
|
| 154 |
-
|
| 155 |
-
// Load Q into tensor core fragments/registers since it will be used frequently:
|
| 156 |
-
#pragma unroll
|
| 157 |
-
for (int i0 = 0; i0 < D; i0 += 16) {
|
| 158 |
-
#pragma unroll
|
| 159 |
-
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 160 |
-
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
| 161 |
-
}
|
| 162 |
-
}
|
| 163 |
-
|
| 164 |
-
__syncthreads();
|
| 165 |
-
|
| 166 |
-
// Iterate over ne11 == previous tokens:
|
| 167 |
-
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
|
| 168 |
-
// Calculate tile of KQ:
|
| 169 |
-
#pragma unroll
|
| 170 |
-
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
| 171 |
-
frag_c_KQ KQ_c[ncols/frag_n];
|
| 172 |
-
#pragma unroll
|
| 173 |
-
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 174 |
-
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
| 175 |
-
}
|
| 176 |
-
#pragma unroll
|
| 177 |
-
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
| 178 |
-
frag_a_K K_a;
|
| 179 |
-
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
| 180 |
-
#pragma unroll
|
| 181 |
-
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 182 |
-
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
| 183 |
-
}
|
| 184 |
-
}
|
| 185 |
-
#pragma unroll
|
| 186 |
-
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 187 |
-
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
| 188 |
-
}
|
| 189 |
-
}
|
| 190 |
-
|
| 191 |
-
__syncthreads();
|
| 192 |
-
|
| 193 |
-
// Calculate softmax for each KQ column using the current max. value.
|
| 194 |
-
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 195 |
-
#pragma unroll
|
| 196 |
-
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 197 |
-
const int j = j0 + threadIdx.y;
|
| 198 |
-
|
| 199 |
-
if (std::is_same<KQ_acc_t, float>::value) {
|
| 200 |
-
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
|
| 201 |
-
#pragma unroll
|
| 202 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 203 |
-
const int k = k0 + threadIdx.x;
|
| 204 |
-
|
| 205 |
-
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
| 206 |
-
|
| 207 |
-
if (use_logit_softcap) {
|
| 208 |
-
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
| 209 |
-
}
|
| 210 |
-
}
|
| 211 |
-
|
| 212 |
-
float KQ_max_new = KQ_max_f[j0/nwarps];
|
| 213 |
-
#pragma unroll
|
| 214 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 215 |
-
const int k = k0 + threadIdx.x;
|
| 216 |
-
|
| 217 |
-
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
|
| 218 |
-
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
|
| 219 |
-
}
|
| 220 |
-
KQ_max_new = warp_reduce_max(KQ_max_new);
|
| 221 |
-
|
| 222 |
-
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
|
| 223 |
-
KQ_max_scale_f[j0/nwarps] = expf(diff);
|
| 224 |
-
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
| 225 |
-
KQ_max_scale_f[j0/nwarps] = 0.0f;
|
| 226 |
-
}
|
| 227 |
-
KQ_max_f[j0/nwarps] = KQ_max_new;
|
| 228 |
-
|
| 229 |
-
float KQ_rowsum_add = 0.0f;
|
| 230 |
-
#pragma unroll
|
| 231 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
|
| 232 |
-
const int k = k0 + threadIdx.x;
|
| 233 |
-
|
| 234 |
-
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
|
| 235 |
-
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
|
| 236 |
-
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
|
| 237 |
-
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
|
| 238 |
-
}
|
| 239 |
-
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
|
| 240 |
-
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
|
| 241 |
-
}
|
| 242 |
-
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
| 243 |
-
|
| 244 |
-
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 245 |
-
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
|
| 246 |
-
} else {
|
| 247 |
-
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
|
| 248 |
-
#pragma unroll
|
| 249 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 250 |
-
const int k = k0 + threadIdx.x;
|
| 251 |
-
|
| 252 |
-
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
| 253 |
-
|
| 254 |
-
if (use_logit_softcap) {
|
| 255 |
-
// There is no dedicated tangens hyperbolicus function for half2.
|
| 256 |
-
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
| 257 |
-
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
| 258 |
-
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
| 259 |
-
|
| 260 |
-
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
|
| 261 |
-
}
|
| 262 |
-
}
|
| 263 |
-
|
| 264 |
-
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
| 265 |
-
#pragma unroll
|
| 266 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 267 |
-
const int k = k0 + threadIdx.x;
|
| 268 |
-
|
| 269 |
-
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
| 270 |
-
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
|
| 271 |
-
}
|
| 272 |
-
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
| 273 |
-
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
|
| 274 |
-
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
|
| 275 |
-
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
| 276 |
-
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
|
| 277 |
-
KQ_max_h2[j0/nwarps] = KQ_max_new;
|
| 278 |
-
|
| 279 |
-
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
| 280 |
-
#pragma unroll
|
| 281 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
| 282 |
-
const int k = k0 + threadIdx.x;
|
| 283 |
-
|
| 284 |
-
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
|
| 285 |
-
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
|
| 286 |
-
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
|
| 287 |
-
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
|
| 288 |
-
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
|
| 289 |
-
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
|
| 290 |
-
}
|
| 291 |
-
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
| 292 |
-
|
| 293 |
-
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 294 |
-
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
|
| 295 |
-
}
|
| 296 |
-
}
|
| 297 |
-
|
| 298 |
-
__syncthreads();
|
| 299 |
-
|
| 300 |
-
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
|
| 301 |
-
#pragma unroll
|
| 302 |
-
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 303 |
-
#pragma unroll
|
| 304 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
| 305 |
-
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
| 306 |
-
nvcuda::wmma::load_matrix_sync(
|
| 307 |
-
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
| 308 |
-
KQ + j0*(kqar*kqs_padded) + k,
|
| 309 |
-
kqar*kqs_padded);
|
| 310 |
-
}
|
| 311 |
-
}
|
| 312 |
-
|
| 313 |
-
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
|
| 314 |
-
#pragma unroll
|
| 315 |
-
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
| 316 |
-
#pragma unroll
|
| 317 |
-
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 318 |
-
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
| 319 |
-
}
|
| 320 |
-
|
| 321 |
-
#pragma unroll
|
| 322 |
-
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
| 323 |
-
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
| 324 |
-
|
| 325 |
-
frag_a_V v_a;
|
| 326 |
-
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
| 327 |
-
#pragma unroll
|
| 328 |
-
for (int j = 0; j < ncols/frag_n; ++j) {
|
| 329 |
-
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
| 330 |
-
}
|
| 331 |
-
}
|
| 332 |
-
}
|
| 333 |
-
|
| 334 |
-
__syncthreads();
|
| 335 |
-
|
| 336 |
-
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
| 337 |
-
#pragma unroll
|
| 338 |
-
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
| 339 |
-
#pragma unroll
|
| 340 |
-
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
| 341 |
-
nvcuda::wmma::store_matrix_sync(
|
| 342 |
-
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
| 343 |
-
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
| 344 |
-
D_padded, nvcuda::wmma::mem_col_major);
|
| 345 |
-
}
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
__syncthreads();
|
| 349 |
-
|
| 350 |
-
#pragma unroll
|
| 351 |
-
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
| 352 |
-
const int j = j0 + threadIdx.y;
|
| 353 |
-
|
| 354 |
-
half2 VKQ_scale;
|
| 355 |
-
if (std::is_same<KQ_acc_t, float>::value) {
|
| 356 |
-
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
|
| 357 |
-
} else {
|
| 358 |
-
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
#pragma unroll
|
| 362 |
-
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 363 |
-
const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                    break;
-                }
-
-                half2 VKQ_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int l = 0; l < VKQ_ratio; ++l) {
-                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
-                }
-                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j_VKQ = j0 + threadIdx.y;
-        if (ic0 + j_VKQ >= ne01) {
-            return;
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-
-        float KQ_rowsum_j;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
-        } else {
-            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
-                dst_val /= KQ_rowsum_j;
-            }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
-        }
-
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
-            continue;
-        }
-
-        float2 dst_meta_val;
-        if (std::is_same<KQ_acc_t, float>::value) {
-            dst_meta_val.x = KQ_max_f[j0/nwarps];
-        } else {
-            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
-        }
-        dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
-    }
-#else
-    NO_DEVICE_CODE;
-#endif // FP16_MMA_AVAILABLE
-}
-
-constexpr int get_max_power_of_2(int x) {
-    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
-}
-
-static_assert(get_max_power_of_2(1) == 1, "Test failed.");
-static_assert(get_max_power_of_2(2) == 2, "Test failed.");
-static_assert(get_max_power_of_2(4) == 4, "Test failed.");
-static_assert(get_max_power_of_2(6) == 2, "Test failed.");
-
-// Number of VKQ rows calculated in parallel:
-constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
-    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
-}
-
-static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
-static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
-static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
-static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
-static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
-static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
-static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
-
-template <int D, int cols_per_block, typename KQ_acc_t>
-void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    constexpr int nwarps = 4;
-
-    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
-    float logit_softcap;
-    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
-
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel;
-    if (logit_softcap == 0.0f) {
-        constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    } else {
-        constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    }
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-}
-
-#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
-    template void ggml_cuda_flash_attn_ext_wmma_f16_case \
-    <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
-// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(128,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(256,  8, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);

 #include "common.cuh"
 
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/fattn.cu
CHANGED
@@ -1,5 +1,6 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile-f16.cuh"
 #include "fattn-tile-f32.cuh"
 #include "fattn-vec-f16.cuh"
@@ -7,144 +8,56 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-
 
-static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
-
-    if (prec != GGML_PREC_DEFAULT) {
-        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
-            constexpr int cols_per_block = 16;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                case 256:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
-                    break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        } else {
-            constexpr int cols_per_block = 32;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                // case 256:
-                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                //     break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
-        constexpr int cols_per_block = 8;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-        return;
-    }
-
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 16;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-        return;
-    }
-
-    constexpr int cols_per_block = 32;
     switch (Q->ne[0]) {
         case 64:
-
             break;
         case 80:
-
             break;
         case 96:
-
             break;
        case 112:
-
             break;
         case 128:
-
             break;
         case 256:
-
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
 }
 
 #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
         ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
@@ -322,11 +235,19 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (!
-        if (
-
         } else {
-
         }
         return;
     }
@@ -341,5 +262,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
     }
 
-
 }

 #include "common.cuh"
 #include "fattn-common.cuh"
+#include "fattn-mma-f16.cuh"
 #include "fattn-tile-f16.cuh"
 #include "fattn-tile-f32.cuh"
 #include "fattn-vec-f16.cuh"
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
+template <int cols_per_block>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
     switch (Q->ne[0]) {
         case 64:
+            ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
             break;
         case 80:
+            ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
             break;
         case 96:
+            ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
             break;
         case 112:
+            ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
             break;
         case 128:
+            ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
             break;
         case 256:
+            ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
 }
+
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    if (Q->ne[1] <= 8) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 16) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
+}
+
 #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
         ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
 
         return;
     }
 
+    if (!new_mma_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT) {
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            }
         } else {
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            }
         }
         return;
     }
 
         }
     }
 
+    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
+    if (cc == GGML_CUDA_CC_VOLTA) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
 }
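To make the new dispatch in ggml_cuda_flash_attn_ext easier to follow, here is a small, self-contained C++ sketch of the decision order that the hunks above implement. It is a reader's aid only: the function name pick_fattn_path, the enum, and the boolean parameters are hypothetical and not part of the commit, and only the branches visible in this diff (vec/tile fallback when the new mma path is unavailable, WMMA on Volta, the mma kernels otherwise) are reflected.

// Hypothetical summary of the kernel selection after this commit (not the commit's code).
enum class fattn_path { vec_f16, vec_f32, tile_f16, tile_f32, wmma_f16, mma_f16 };

static fattn_path pick_fattn_path(bool has_new_mma, bool is_volta, bool prec_default, long n_q_tokens) {
    if (!has_new_mma) {                       // no mma PTX path: fall back to the vec/tile kernels
        if (prec_default) {
            return n_q_tokens <= 8 ? fattn_path::vec_f16 : fattn_path::tile_f16;
        }
        return n_q_tokens <= 8 ? fattn_path::vec_f32 : fattn_path::tile_f32;
    }
    if (is_volta) {                           // Volta keeps the old WMMA kernel
        return fattn_path::wmma_f16;
    }
    return fattn_path::mma_f16;               // Turing and newer use the new mma-based kernels
}

Note that new_mma_available(cc) is used by the diff but not defined in it; presumably it reports whether the device can run the new mma PTX path at all.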
ggml/src/ggml-cuda/mma.cuh
CHANGED
@@ -1,11 +1,67 @@
 #include "common.cuh"
 
-
     static constexpr int I  = 16;
     static constexpr int K  = 4;
     static constexpr int ne = 2;
 
-
 
     static __device__ __forceinline__ int get_i(const int l) {
         const int ret = (l%2) * (I/2) + threadIdx.x / K;
@@ -21,27 +77,35 @@ struct mma_int_A_I16K4 {
         return ret;
     }
 
-    __device__ __forceinline__ void
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
-            : "l"(xs));
-#else
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_i(l)*stride + get_k(l)];
         }
-
     }
 };
 
-
     static constexpr int I  = 16;
     static constexpr int K  = 8;
     static constexpr int ne = 4;
 
-
 
     static __device__ __forceinline__ int get_i(const int l) {
         const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
@@ -57,31 +121,62 @@ struct mma_int_A_I16K8 {
         return ret;
     }
 
-    __device__ __forceinline__ void
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "l"(xs));
-#else
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_i(l)*stride + get_k(l)];
         }
-#endif // defined(INT8_MMA_AVAILABLE)
     }
 
-    __device__ __forceinline__ void
-
     }
 };
 
-
     static constexpr int J  = 8;
     static constexpr int K  = 4;
     static constexpr int ne = 1;
 
-
 
     static __device__ __forceinline__ int get_j(const int /* l */) {
         const int ret = threadIdx.x / K;
@@ -97,27 +192,34 @@ struct mma_int_B_J8K4 {
         return ret;
     }
 
-    __device__ __forceinline__ void
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
-            : "+r"(x[0])
-            : "l"(xs));
-#else
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_j(l)*stride + get_k(l)];
         }
-
     }
 };
 
-
     static constexpr int J  = 8;
     static constexpr int K  = 8;
     static constexpr int ne = 2;
 
-
 
     static __device__ __forceinline__ int get_j(const int /* l */) {
         const int ret = threadIdx.x / (K/2);
@@ -133,22 +235,31 @@ struct mma_int_B_J8K8 {
         return ret;
     }
 
-    __device__ __forceinline__ void
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
-            : "l"(xs));
-#else
 #pragma unroll
         for (int l = 0; l < ne; ++l) {
            x[l] = xs0[get_j(l)*stride + get_k(l)];
         }
-
     }
 };
 
-
     static constexpr int I  = 16;
     static constexpr int J  = 8;
     static constexpr int ne = 4;
@@ -169,8 +280,8 @@ struct mma_int_C_I16J8 {
         return ret;
     }
 
-    __device__ __forceinline__ void
-#ifdef
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@@ -188,11 +299,11 @@
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
         NO_DEVICE_CODE;
-#endif //
     }
 
-    __device__ __forceinline__ void
-#ifdef
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@@ -216,6 +327,132 @@
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
         NO_DEVICE_CODE;
-#endif //
     }
 };

+// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
+// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
+// The documentation for the PTX instructions can be found under:
+//   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
+//
+// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
+// A is a row-major matrix with shape I x K.
+// B is a column-major matrix with shape K x J.
+// C is a column-major matrix with shape I x J.
+// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
+// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// All matrix tiles have ne physical 32 bit elements per warp.
+//
+// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+
 #include "common.cuh"
 
+
+#if CUDART_VERSION >= 11800
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    int ret = 0;
+
+#ifdef NEW_MMA_AVAILABLE
+    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+        : "+r"(ret) : "r"(x));
+#else
+    NO_DEVICE_CODE;
+#endif // defined(NEW_MMA_AVAILABLE)
+    return ret;
+}
+
+#else
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    // Imagine transposing row-major matrix to column-major matrix.
+    const int src_i_low  = 2 * (threadIdx.x % 4);
+    const int src_i_high = src_i_low + 1;
+    const int src_j      = threadIdx.x / 4;
+
+    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
+    const int src_laneid_high = src_i_high * 4 + src_j / 2;
+
+    const int shift_low  = ((src_j + 0) % 2) * 16;
+    const int shift_high = ((src_j + 1) % 2) * 16;
+
+    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
+    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+
+    return ret_low | ret_high;
+}
+
+#endif // CUDART_VERSION >= 11800
+
+
+template <typename T>
+struct mma_A_I16K4 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int I  = 16;
     static constexpr int K  = 4;
     static constexpr int ne = 2;
 
+    T x[ne];
 
     static __device__ __forceinline__ int get_i(const int l) {
         const int ret = (l%2) * (I/2) + threadIdx.x / K;
 
         return ret;
     }
 
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_i(l)*stride + get_k(l)];
         }
+    }
+
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "l"(xs));
+#else
+        load_generic(xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 };
 
+template <typename T>
+struct mma_A_I16K8 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int I  = 16;
     static constexpr int K  = 8;
     static constexpr int ne = 4;
 
+    T x[ne];
 
     static __device__ __forceinline__ int get_i(const int l) {
         const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
 
         return ret;
     }
 
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_i(l)*stride + get_k(l)];
         }
     }
 
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "l"(xs));
+#else
+        GGML_UNUSED(xs0);
+        GGML_UNUSED(stride);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
+            : "l"(xs));
+#else
+        GGML_UNUSED(xs0);
+        GGML_UNUSED(stride);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ void transpose() {
+        int * xi = (int *) x;
+        xi[0] = ggml_cuda_movmatrix(xi[0]);
+
+        const int tmp = ggml_cuda_movmatrix(xi[1]);
+        xi[1] = ggml_cuda_movmatrix(xi[2]);
+        xi[2] = tmp;
+
+        xi[3] = ggml_cuda_movmatrix(xi[3]);
     }
 };
 
+template <typename T>
+struct mma_B_J8K4 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int J  = 8;
     static constexpr int K  = 4;
     static constexpr int ne = 1;
 
+    T x[ne];
 
     static __device__ __forceinline__ int get_j(const int /* l */) {
         const int ret = threadIdx.x / K;
 
         return ret;
     }
 
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_j(l)*stride + get_k(l)];
         }
+    }
+
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(xi[0]) : "l"(xs));
+#else
+        load_generic(xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 };
 
+template <typename T>
+struct mma_B_J8K8 {
+    static_assert(sizeof(T) == 4, "bad type size");
+
     static constexpr int J  = 8;
     static constexpr int K  = 8;
     static constexpr int ne = 2;
 
+    T x[ne];
 
     static __device__ __forceinline__ int get_j(const int /* l */) {
         const int ret = threadIdx.x / (K/2);
 
         return ret;
     }
 
+    __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
         for (int l = 0; l < ne; ++l) {
             x[l] = xs0[get_j(l)*stride + get_k(l)];
         }
+    }
+
+    __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) x;
+        const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "l"(xs));
+#else
+        load_generic(xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 };
 
+template <typename T>
+struct mma_C_I16J8 {};
+
+template <>
+struct mma_C_I16J8<int> {
     static constexpr int I  = 16;
     static constexpr int J  = 8;
     static constexpr int ne = 4;
 
         return ret;
     }
 
+    __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
 
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
         NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
     }
 
+    __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
 
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
         NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+};
+
+template <>
+struct mma_C_I16J8<half2> {
+    static constexpr int I  = 16;
+    static constexpr int J  = 4;
+    static constexpr int ne = 2;
+
+    half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = l * (I/2) + threadIdx.x / J;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x % J;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
+        int * Axi = (int *) mma_A.x;
+        int * Bxi = (int *) mma_B.x;
+        int * xi  = (int *) x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(xi[0]), "+r"(xi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
+        mma_B_J8K8<half2> mma_B;
+
+        int * xi  = (int *) x;
+        int * Bxi = (int *) mma_B.x;
+        Bxi[0] = ggml_cuda_movmatrix(xi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(xi[1]);
+
+        return mma_B;
+    }
+};
+
+template <>
+struct mma_C_I16J8<float> {
+    static constexpr int I  = 16;
+    static constexpr int J  = 8;
+    static constexpr int ne = 4;
+
+    float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
+#ifdef NEW_MMA_AVAILABLE
+        int * Axi = (int *) mma_A.x;
+        int * Bxi = (int *) mma_B.x;
+        int * xi  = (int *) x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
+        mma_B_J8K8<half2> mma_B;
+        mma_B.x[0] = make_half2(x[0], x[1]);
+        mma_B.x[1] = make_half2(x[2], x[3]);
+
+        int * Bxi = (int *) mma_B.x;
+        Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
+        Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
+
+        return mma_B;
+    }
+
+    __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_i(l)];
+        }
     }
 };
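The comment block at the top of the new mma.cuh describes the A/B/C tile layout in the abstract; the fragment below shows how those primitives are typically combined inside a single warp. It is an illustrative sketch, not code from the commit: the function name example_tile_matmul, the shared-memory buffers, the strides, and the output layout are assumptions, while the tile types and their members (load_ldmatrix, mma, get_i, get_j, ne, x) are exactly the ones defined above, so it only compiles inside a codebase that includes mma.cuh.

// Illustrative, single-warp usage of the mma tile primitives (hypothetical wrapper).
__device__ void example_tile_matmul(const half2 * __restrict__ A_sh, // 16 x 8 half2 elements, row-major, in shared memory
                                    const half2 * __restrict__ B_sh, //  8 x 8 half2 elements, in shared memory
                                    float       * __restrict__ C_out) {
    mma_A_I16K8<half2> A;
    mma_B_J8K8<half2>  B;
    mma_C_I16J8<float> C;                 // accumulator, zero-initialized

    A.load_ldmatrix(A_sh, /*stride=*/8);  // pointers must be 16-byte aligned shared memory
    B.load_ldmatrix(B_sh, /*stride=*/8);

    C.mma(A, B);                          // C += A @ B (one m16n8k16 mma, or 2x m8n8k8 on Turing)

    // Each thread holds C.ne values; get_i/get_j give their position in the 16x8 result tile.
#pragma unroll
    for (int l = 0; l < C.ne; ++l) {
        C_out[C.get_i(l)*8 + C.get_j(l)] = C.x[l];
    }
}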
ggml/src/ggml-cuda/mmq.cu
CHANGED
@@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (
+    if (new_mma_available(cc)) {
         return true;
     }
 
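The check added here relies on new_mma_available(cc), whose definition is not part of this diff. Assuming it mirrors the compile-time NEW_MMA_AVAILABLE guard used throughout the device code, a plausible host-side sketch looks like this; the _sketch suffix marks it as hypothetical.

// Hypothetical: presumably the real helper returns true only for NVIDIA GPUs of
// compute capability Turing or newer, matching the NEW_MMA_AVAILABLE device guard.
static bool new_mma_available_sketch(const int cc) {
    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
}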
ggml/src/ggml-cuda/mmq.cuh
CHANGED
@@ -87,7 +87,7 @@ struct tile_x_sizes {
 };
 
 static constexpr int get_mmq_x_max_host(const int cc) {
-    return
 #ifdef GGML_CUDA_FORCE_MMQ
         cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
 #else
@@ -96,9 +96,9 @@ static constexpr int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef
     return 128;
-#else //
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return 128;
@@ -116,7 +116,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif //
 }
 
 static constexpr int get_mmq_y_host(const int cc) {
@@ -209,10 +209,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    return
 }
 
-#ifdef
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
@@ -220,21 +220,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
 static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
     return 8;
 }
-#endif //
 
 // ------------------------------------------------------------
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
 
     const int kbx  = threadIdx.x / QI4_0;
     const int kqsx = threadIdx.x % QI4_0;
@@ -250,12 +250,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#ifdef
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif //
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -271,11 +271,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif //
     }
 }
 
@@ -322,14 +322,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif //
 
     const int kbx  = threadIdx.x / QI4_1;
     const int kqsx = threadIdx.x % QI4_1;
@@ -345,12 +345,12 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#ifdef
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif //
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -366,11 +366,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif //
     }
 }
 
@@ -417,14 +417,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
 
     const int kbx  = threadIdx.x / QI5_0;
     const int kqsx = threadIdx.x % QI5_0;
@@ -456,13 +456,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
         qs1  = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#ifdef
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif //
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
@@ -478,25 +478,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif //
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif //
 
     const int kbx  = threadIdx.x / QI5_1;
     const int kqsx = threadIdx.x % QI5_1;
@@ -526,13 +526,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
 
-#ifdef
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif //
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -548,25 +548,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif //
     }
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_tile + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif //
 
     const int kbx  = threadIdx.x / QI8_0;
     const int kqsx = threadIdx.x % QI8_0;
@@ -581,13 +581,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#ifdef
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
 #else
         x_qs[i*(2*WARP_SIZE + 1) + 0         + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
         x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif //
     }
 
     const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
@@ -603,11 +603,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
         x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif //
     }
 }
 
@@ -645,9 +645,9 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef
-    typedef
-    typedef
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -672,7 +672,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_0].
         }
 
 #pragma unroll
@@ -695,7 +695,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         mma_B B;
         float dB[mma_C::ne/2];
 
-        B.
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -711,7 +711,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
             mma_C C;
-            C.
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne; ++l) {
@@ -756,9 +756,9 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef
-    typedef
-    typedef
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -782,7 +782,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_1].
         }
 
 #pragma unroll
@@ -805,7 +805,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         mma_B B;
         float2 dsB[mma_C::ne/2];
 
-        B.
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -817,7 +817,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
             mma_C C;
-            C.
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne; ++l) {
@@ -864,12 +864,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef
 
-    typedef
-    typedef
-    typedef
-    typedef
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -893,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/8].
         }
 
 #pragma unroll
@@ -916,8 +916,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         mma_B B[2];
         float dB[mma_C::ne/2];
 
-
-        B[
 
 #pragma unroll
         for (int l = 0; l < mma_C::ne/2; ++l) {
@@ -929,8 +930,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
             mma_C C[2];
-            C[0].
-            C[1].
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne; ++l) {
@@ -942,20 +943,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif //
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif //
 
     const int kqsx = threadIdx.x % QI2_K;
 
@@ -977,11 +978,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#ifdef
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
-#endif //
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -992,11 +993,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#ifdef
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
         x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
-#endif //
     }
 }
 
@@ -1051,12 +1052,12 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef
 
-    typedef
-    typedef
-    typedef
-    typedef
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
@@ -1081,7 +1082,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/QI8_1].
         }
     }
 
@@ -1118,24 +1119,25 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
         mma_B B[2];
 
-
-        B[
 
         mma_C Cm[2];
         if (k01 >= WARP_SIZE * 3/4) {
             mma_A A1;
             A1.x[0] = 0x01010101;
             A1.x[1] = 0x01010101;
-            Cm[0].
-            Cm[1].
         }
 
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
             mma_C Cd[2];
 
-            Cd[0].
-            Cd[1].
 
 #pragma unroll
             for (int l = 0; l < mma_C::ne; ++l) {
@@ -1172,13 +1174,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif //
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
@@ -1186,7 +1188,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_df + txs.dm);
-#endif //
 
     const int kqsx = threadIdx.x % QI3_K;
 
@@ -1212,11 +1214,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#ifdef
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
-#endif //
         }
     }
 
@@ -1242,7 +1244,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
     const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
| 1245 |
-
#ifdef
|
| 1246 |
const int8_t * sc8 = (const int8_t *) &sc;
|
| 1247 |
const float d = bxi->d;
|
| 1248 |
|
|
@@ -1252,10 +1254,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1252 |
}
|
| 1253 |
#else
|
| 1254 |
x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
|
| 1255 |
-
#endif //
|
| 1256 |
}
|
| 1257 |
|
| 1258 |
-
#ifndef
|
| 1259 |
#pragma unroll
|
| 1260 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
|
| 1261 |
int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
|
|
@@ -1268,7 +1270,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1268 |
|
| 1269 |
x_df[i] = bxi->d;
|
| 1270 |
}
|
| 1271 |
-
#endif //
|
| 1272 |
}
|
| 1273 |
|
| 1274 |
template <int mmq_x, int mmq_y, int nwarps>
|
|
@@ -1317,7 +1319,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
|
|
| 1317 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
|
| 1318 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1319 |
|
| 1320 |
-
#ifdef
|
| 1321 |
int * x_qs = (int *) x_tile;
|
| 1322 |
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
|
| 1323 |
#else
|
|
@@ -1325,7 +1327,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1325 |
int * x_qs = (int *) x_tile;
|
| 1326 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1327 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1328 |
-
#endif //
|
| 1329 |
|
| 1330 |
#pragma unroll
|
| 1331 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
@@ -1338,15 +1340,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1338 |
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
|
| 1339 |
const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
|
| 1340 |
|
| 1341 |
-
#ifdef
|
| 1342 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
| 1343 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
|
| 1344 |
#else
|
| 1345 |
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
|
| 1346 |
-
#endif //
|
| 1347 |
}
|
| 1348 |
|
| 1349 |
-
#ifdef
|
| 1350 |
|
| 1351 |
#pragma unroll
|
| 1352 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
|
|
@@ -1407,7 +1409,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1407 |
|
| 1408 |
x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
|
| 1409 |
}
|
| 1410 |
-
#endif //
|
| 1411 |
}
|
| 1412 |
|
| 1413 |
template <int mmq_x, int mmq_y, int nwarps>
|
|
@@ -1446,7 +1448,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
|
| 1446 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
|
| 1447 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1448 |
|
| 1449 |
-
#ifdef
|
| 1450 |
int * x_qs = (int *) x_tile;
|
| 1451 |
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
|
| 1452 |
#else
|
|
@@ -1454,7 +1456,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1454 |
int * x_qs = (int *) x_tile;
|
| 1455 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1456 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1457 |
-
#endif //
|
| 1458 |
|
| 1459 |
#pragma unroll
|
| 1460 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
@@ -1478,16 +1480,16 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1478 |
const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
|
| 1479 |
const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
|
| 1480 |
|
| 1481 |
-
#ifdef
|
| 1482 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
|
| 1483 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
|
| 1484 |
#else
|
| 1485 |
x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
|
| 1486 |
x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
|
| 1487 |
-
#endif //
|
| 1488 |
}
|
| 1489 |
|
| 1490 |
-
#ifdef
|
| 1491 |
|
| 1492 |
#pragma unroll
|
| 1493 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
|
|
@@ -1548,7 +1550,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1548 |
|
| 1549 |
x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
|
| 1550 |
}
|
| 1551 |
-
#endif //
|
| 1552 |
}
|
| 1553 |
|
| 1554 |
template <int mmq_x, int mmq_y, int nwarps>
|
|
@@ -1587,7 +1589,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
|
| 1587 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
|
| 1588 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1589 |
|
| 1590 |
-
#ifdef
|
| 1591 |
int * x_qs = (int *) x_tile;
|
| 1592 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 1593 |
int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
|
|
@@ -1596,7 +1598,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1596 |
int * x_qs = (int *) x_tile;
|
| 1597 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1598 |
int * x_sc = (int *) (x_df + txs.dm);
|
| 1599 |
-
#endif //
|
| 1600 |
|
| 1601 |
#pragma unroll
|
| 1602 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
@@ -1619,13 +1621,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1619 |
const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
|
| 1620 |
const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
|
| 1621 |
|
| 1622 |
-
#ifdef
|
| 1623 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 1624 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 1625 |
#else
|
| 1626 |
x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 1627 |
x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 1628 |
-
#endif //
|
| 1629 |
}
|
| 1630 |
|
| 1631 |
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
|
|
@@ -1641,11 +1643,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1641 |
|
| 1642 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
|
| 1643 |
|
| 1644 |
-
#ifdef
|
| 1645 |
x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
|
| 1646 |
#else
|
| 1647 |
x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
|
| 1648 |
-
#endif //
|
| 1649 |
}
|
| 1650 |
|
| 1651 |
#pragma unroll
|
|
@@ -1658,11 +1660,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1658 |
|
| 1659 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
|
| 1660 |
|
| 1661 |
-
#ifdef
|
| 1662 |
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
|
| 1663 |
#else
|
| 1664 |
x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
|
| 1665 |
-
#endif //
|
| 1666 |
}
|
| 1667 |
}
|
| 1668 |
|
|
@@ -1702,11 +1704,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
|
|
| 1702 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1703 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
| 1704 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 1705 |
-
#ifdef
|
| 1706 |
|
| 1707 |
-
typedef
|
| 1708 |
-
typedef
|
| 1709 |
-
typedef
|
| 1710 |
|
| 1711 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 1712 |
constexpr int rows_per_warp = 2 * granularity;
|
|
@@ -1732,8 +1734,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 1732 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
| 1733 |
const int k0 = k00 + k01;
|
| 1734 |
|
| 1735 |
-
A[n][k01/4 + 0].
|
| 1736 |
-
A[n][k01/4 + 1].
|
| 1737 |
}
|
| 1738 |
|
| 1739 |
#pragma unroll
|
|
@@ -1771,8 +1773,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 1771 |
mma_B B[2];
|
| 1772 |
float dB[mma_C::ne/2];
|
| 1773 |
|
| 1774 |
-
|
| 1775 |
-
B[
|
|
|
|
| 1776 |
|
| 1777 |
#pragma unroll
|
| 1778 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
@@ -1784,8 +1787,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 1784 |
#pragma unroll
|
| 1785 |
for (int n = 0; n < ntx; ++n) {
|
| 1786 |
mma_C C[2];
|
| 1787 |
-
C[0].
|
| 1788 |
-
C[1].
|
| 1789 |
|
| 1790 |
#pragma unroll
|
| 1791 |
for (int l = 0; l < mma_C::ne; ++l) {
|
|
@@ -1805,20 +1808,20 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|
| 1805 |
#else
|
| 1806 |
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
| 1807 |
NO_DEVICE_CODE;
|
| 1808 |
-
#endif //
|
| 1809 |
}
|
| 1810 |
|
| 1811 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
|
| 1812 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1813 |
|
| 1814 |
-
#ifdef
|
| 1815 |
int * x_qs = (int *) x_tile;
|
| 1816 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 1817 |
#else
|
| 1818 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
|
| 1819 |
int * x_qs = (int *) x_tile;
|
| 1820 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1821 |
-
#endif //
|
| 1822 |
|
| 1823 |
const int kbx = threadIdx.x / QI4_NL;
|
| 1824 |
const int kqsx = threadIdx.x % QI4_NL;
|
|
@@ -1836,13 +1839,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1836 |
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
|
| 1837 |
const int2 v = get_int_from_table_16(aux_q4);
|
| 1838 |
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
|
| 1839 |
-
#ifdef
|
| 1840 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
| 1841 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
| 1842 |
#else
|
| 1843 |
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
|
| 1844 |
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
|
| 1845 |
-
#endif //
|
| 1846 |
}
|
| 1847 |
|
| 1848 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
|
|
@@ -1858,25 +1861,25 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1858 |
|
| 1859 |
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
|
| 1860 |
|
| 1861 |
-
#ifdef
|
| 1862 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
|
| 1863 |
#else
|
| 1864 |
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
|
| 1865 |
-
#endif //
|
| 1866 |
}
|
| 1867 |
}
|
| 1868 |
|
| 1869 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
|
| 1870 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1871 |
|
| 1872 |
-
#ifdef
|
| 1873 |
int * x_qs = (int *) x_tile;
|
| 1874 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 1875 |
#else
|
| 1876 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
|
| 1877 |
int * x_qs = (int *) x_tile;
|
| 1878 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1879 |
-
#endif //
|
| 1880 |
|
| 1881 |
const int kqsx = threadIdx.x % (QI2_XXS/2);
|
| 1882 |
|
|
@@ -1905,36 +1908,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1905 |
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
| 1906 |
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
| 1907 |
|
| 1908 |
-
#ifdef
|
| 1909 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
| 1910 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
|
| 1911 |
#else
|
| 1912 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
|
| 1913 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
|
| 1914 |
-
#endif //
|
| 1915 |
}
|
| 1916 |
|
| 1917 |
const int ls = aux32 >> 28;
|
| 1918 |
const float d = bxi->d;
|
| 1919 |
-
#ifdef
|
| 1920 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
|
| 1921 |
#else
|
| 1922 |
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
|
| 1923 |
-
#endif //
|
| 1924 |
}
|
| 1925 |
}
|
| 1926 |
|
| 1927 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
|
| 1928 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1929 |
|
| 1930 |
-
#ifdef
|
| 1931 |
int * x_qs = (int *) x_tile;
|
| 1932 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 1933 |
#else
|
| 1934 |
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
| 1935 |
int * x_qs = (int *) x_tile;
|
| 1936 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1937 |
-
#endif //
|
| 1938 |
|
| 1939 |
const int kqsx = threadIdx.x % (QI2_XS/2);
|
| 1940 |
|
|
@@ -1959,38 +1962,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 1959 |
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
| 1960 |
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
| 1961 |
|
| 1962 |
-
#ifdef
|
| 1963 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
| 1964 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
| 1965 |
#else
|
| 1966 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 1967 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 1968 |
-
#endif //
|
| 1969 |
}
|
| 1970 |
|
| 1971 |
const int ls = bxi->scales[kqsx];
|
| 1972 |
const float d = bxi->d;
|
| 1973 |
-
#ifdef
|
| 1974 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 1975 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 1976 |
#else
|
| 1977 |
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 1978 |
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 1979 |
-
#endif //
|
| 1980 |
}
|
| 1981 |
}
|
| 1982 |
|
| 1983 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
|
| 1984 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1985 |
|
| 1986 |
-
#ifdef
|
| 1987 |
int * x_qs = (int *) x_tile;
|
| 1988 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 1989 |
#else
|
| 1990 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
|
| 1991 |
int * x_qs = (int *) x_tile;
|
| 1992 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1993 |
-
#endif //
|
| 1994 |
|
| 1995 |
const int kqsx = threadIdx.x % (QI2_S/2);
|
| 1996 |
|
|
@@ -2022,38 +2025,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 2022 |
const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
|
| 2023 |
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
|
| 2024 |
|
| 2025 |
-
#ifdef
|
| 2026 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2027 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2028 |
#else
|
| 2029 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2030 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2031 |
-
#endif //
|
| 2032 |
}
|
| 2033 |
|
| 2034 |
const int ls = bxi->scales[kqsx];
|
| 2035 |
const float d = bxi->d;
|
| 2036 |
-
#ifdef
|
| 2037 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2038 |
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2039 |
#else
|
| 2040 |
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
|
| 2041 |
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
|
| 2042 |
-
#endif //
|
| 2043 |
}
|
| 2044 |
}
|
| 2045 |
|
| 2046 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
|
| 2047 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 2048 |
|
| 2049 |
-
#ifdef
|
| 2050 |
int * x_qs = (int *) x_tile;
|
| 2051 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 2052 |
#else
|
| 2053 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
|
| 2054 |
int * x_qs = (int *) x_tile;
|
| 2055 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2056 |
-
#endif //
|
| 2057 |
|
| 2058 |
const int kqsx = threadIdx.x % (QI3_XXS/2);
|
| 2059 |
|
|
@@ -2080,36 +2083,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 2080 |
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
| 2081 |
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
| 2082 |
|
| 2083 |
-
#ifdef
|
| 2084 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2085 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2086 |
#else
|
| 2087 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
|
| 2088 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
|
| 2089 |
-
#endif //
|
| 2090 |
}
|
| 2091 |
|
| 2092 |
const int ls = aux32 >> 28;
|
| 2093 |
const float d = bxi->d;
|
| 2094 |
-
#ifdef
|
| 2095 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
|
| 2096 |
#else
|
| 2097 |
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
|
| 2098 |
-
#endif //
|
| 2099 |
}
|
| 2100 |
}
|
| 2101 |
|
| 2102 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
|
| 2103 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 2104 |
|
| 2105 |
-
#ifdef
|
| 2106 |
int * x_qs = (int *) x_tile;
|
| 2107 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 2108 |
#else
|
| 2109 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
| 2110 |
int * x_qs = (int *) x_tile;
|
| 2111 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2112 |
-
#endif //
|
| 2113 |
|
| 2114 |
const int kqsx = threadIdx.x % (QI3_S/2);
|
| 2115 |
|
|
@@ -2143,36 +2146,36 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 2143 |
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
| 2144 |
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
| 2145 |
|
| 2146 |
-
#ifdef
|
| 2147 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
|
| 2148 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
|
| 2149 |
#else
|
| 2150 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
|
| 2151 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
|
| 2152 |
-
#endif //
|
| 2153 |
}
|
| 2154 |
|
| 2155 |
const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
|
| 2156 |
const float d = bxi->d;
|
| 2157 |
-
#ifdef
|
| 2158 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
|
| 2159 |
#else
|
| 2160 |
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
|
| 2161 |
-
#endif //
|
| 2162 |
}
|
| 2163 |
}
|
| 2164 |
|
| 2165 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
|
| 2166 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 2167 |
|
| 2168 |
-
#ifdef
|
| 2169 |
int * x_qs = (int *) x_tile;
|
| 2170 |
half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
|
| 2171 |
#else
|
| 2172 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
|
| 2173 |
int * x_qs = (int *) x_tile;
|
| 2174 |
half2 * x_ds = (half2 *) (x_qs + txs.qs);
|
| 2175 |
-
#endif //
|
| 2176 |
|
| 2177 |
const int kqsx = threadIdx.x % QI1_S;
|
| 2178 |
|
|
@@ -2198,37 +2201,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 2198 |
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
|
| 2199 |
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
|
| 2200 |
|
| 2201 |
-
#ifdef
|
| 2202 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
|
| 2203 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
|
| 2204 |
#else
|
| 2205 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
|
| 2206 |
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
|
| 2207 |
-
#endif //
|
| 2208 |
}
|
| 2209 |
|
| 2210 |
const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
|
| 2211 |
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
|
| 2212 |
|
| 2213 |
-
#ifdef
|
| 2214 |
x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
|
| 2215 |
#else
|
| 2216 |
x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
|
| 2217 |
-
#endif //
|
| 2218 |
}
|
| 2219 |
}
|
| 2220 |
|
| 2221 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
|
| 2222 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 2223 |
|
| 2224 |
-
#ifdef
|
| 2225 |
int * x_qs = (int *) x_tile;
|
| 2226 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 2227 |
#else
|
| 2228 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
|
| 2229 |
int * x_qs = (int *) x_tile;
|
| 2230 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 2231 |
-
#endif //
|
| 2232 |
|
| 2233 |
const int kbx = 0; // threadIdx.x / QI4_XS
|
| 2234 |
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
|
|
@@ -2246,13 +2249,13 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 2246 |
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
|
| 2247 |
const int2 v = get_int_from_table_16(aux_q4);
|
| 2248 |
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
|
| 2249 |
-
#ifdef
|
| 2250 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
|
| 2251 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
|
| 2252 |
#else
|
| 2253 |
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
|
| 2254 |
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
|
| 2255 |
-
#endif //
|
| 2256 |
}
|
| 2257 |
|
| 2258 |
#pragma unroll
|
|
@@ -2270,11 +2273,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
| 2270 |
const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
|
| 2271 |
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
|
| 2272 |
|
| 2273 |
-
#ifdef
|
| 2274 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
|
| 2275 |
#else
|
| 2276 |
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
|
| 2277 |
-
#endif //
|
| 2278 |
}
|
| 2279 |
}
|
| 2280 |
|
|
@@ -2307,16 +2310,16 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
|
| 2307 |
static __device__ __forceinline__ void mmq_write_back_mma(
|
| 2308 |
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
|
| 2309 |
|
| 2310 |
-
typedef
|
| 2311 |
|
| 2312 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 2313 |
constexpr int rows_per_warp = 2 * granularity;
|
| 2314 |
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
|
| 2315 |
|
| 2316 |
const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
|
| 2317 |
-
#ifdef
|
| 2318 |
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
| 2319 |
-
#endif //
|
| 2320 |
|
| 2321 |
#pragma unroll
|
| 2322 |
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
|
@@ -2505,13 +2508,13 @@ static __device__ void mul_mat_q_process_tile(
|
|
| 2505 |
int * tile_y = (int *) data_mul_mat_q;
|
| 2506 |
int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
|
| 2507 |
|
| 2508 |
-
#ifdef
|
| 2509 |
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
|
| 2510 |
constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
| 2511 |
#else
|
| 2512 |
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
|
| 2513 |
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
| 2514 |
-
#endif //
|
| 2515 |
|
| 2516 |
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
|
| 2517 |
|
|
@@ -2643,7 +2646,7 @@ static __global__ void mul_mat_q(
| 2643 | const int jt = kbc / (blocks_per_ne00*nty);
| 2644 | const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
| 2645 |
| 2646 | - constexpr bool fixup = true; // Last index writes
| 2647 | mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
| 2648 | (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
| 2649 | it, jt, kb0_start, kb0_stop);
@@ -2749,7 +2752,7 @@ template<ggml_type type>
| 2749 | static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
| 2750 | const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
| 2751 | const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
| 2752 | - const int shmem_x =
| 2753 | const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
| 2754 | return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
| 2755 | }
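The mmq_get_shmem hunk above sizes the dynamic shared memory for one tile: an x-tile part (the "-" row being changed is cut off in this view) plus mmq_x*sizeof(block_q8_1_mmq) bytes for the y tile, with GGML_PAD rounding the y size up to a multiple of MMQ_NWARPS*WARP_SIZE ints so the x tile that is placed after the y tile in shared memory stays aligned (the same padding appears in mul_mat_q_process_tile above). A small stand-alone sketch of that shape of computation; pad_to stands in for GGML_PAD, and the x-tile expression is an assumption, not the value used in the commit:

#include <cstddef>

// Stand-in for GGML_PAD: round n up to a multiple of align.
static size_t pad_to(size_t n, size_t align) { return ((n + align - 1) / align) * align; }

// Illustrative shared-memory budget for one MMQ tile (all sizes in bytes).
static size_t mmq_shmem_sketch(size_t mmq_x, size_t mmq_y, size_t tile_x_k,
                               size_t sizeof_block_q8_1_mmq, size_t nwarps, size_t warp_size) {
    const size_t bytes_x = mmq_y * tile_x_k * sizeof(int);   // x tile (assumed mma layout)
    const size_t bytes_y = mmq_x * sizeof_block_q8_1_mmq;    // y tile, as in the hunk above
    return bytes_x + pad_to(bytes_y, nwarps * warp_size * sizeof(int));
}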
| 87 | };
| 88 |
| 89 | static constexpr int get_mmq_x_max_host(const int cc) {
| 90 | + return new_mma_available(cc) ? 128 :
| 91 | #ifdef GGML_CUDA_FORCE_MMQ
| 92 | cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
| 93 | #else
| 96 | }
| 97 |
| 98 | static constexpr __device__ int get_mmq_x_max_device() {
| 99 | + #ifdef NEW_MMA_AVAILABLE
| 100 | return 128;
| 101 | + #else // NEW_MMA_AVAILABLE
| 102 |
| 103 | #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
| 104 | return 128;
| 116 | #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
| 117 |
| 118 | #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
| 119 | + #endif // NEW_MMA_AVAILABLE
| 120 | }
| 121 |
| 122 | static constexpr int get_mmq_y_host(const int cc) {
| 209 | #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
| 210 |
| 211 | static int mmq_get_granularity_host(const int mmq_x, const int cc) {
| 212 | + return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
| 213 | }
| 214 |
| 215 | + #ifdef NEW_MMA_AVAILABLE
| 216 | static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
| 217 | return mmq_x >= 48 ? 16 : 8;
| 218 | }
| 220 | static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
| 221 | return 8;
| 222 | }
| 223 | + #endif // NEW_MMA_AVAILABLE
| 224 |
| 225 | // ------------------------------------------------------------
| 226 |
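get_mmq_x_max_host and mmq_get_granularity_host above are the runtime (per-device cc) counterparts of the compile-time __device__ helpers next to them, so the launch code and the kernel agree on the tile shape. A hypothetical host-side helper, only to illustrate the contract between the two functions; the clamping and rounding policy here is made up:

// Hypothetical launcher-side helper; only get_mmq_x_max_host and mmq_get_granularity_host
// come from the file above, the policy below is illustrative.
static int pick_mmq_x_sketch(const int cc, const int requested) {
    const int mmq_x_max   = get_mmq_x_max_host(cc);              // 128 when new_mma_available(cc)
    int       mmq_x       = requested < mmq_x_max ? requested : mmq_x_max;
    const int granularity = mmq_get_granularity_host(mmq_x, cc); // 16 for mmq_x >= 48 on the new mma path, else 8
    mmq_x -= mmq_x % granularity;                                // keep mmq_x a multiple of the granularity
    return mmq_x > 0 ? mmq_x : granularity;
}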
| 227 | template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
| 228 | const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
| 229 |
| 230 | + #ifdef NEW_MMA_AVAILABLE
| 231 | int * x_qs = (int *) x_tile;
| 232 | float * x_df = (float *) (x_qs + 2*WARP_SIZE);
| 233 | #else
| 234 | constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
| 235 | int * x_qs = (int *) x_tile;
| 236 | float * x_df = (float *) (x_qs + txs.qs);
| 237 | + #endif // NEW_MMA_AVAILABLE
| 238 |
| 239 | const int kbx = threadIdx.x / QI4_0;
| 240 | const int kqsx = threadIdx.x % QI4_0;
| 250 | const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
| 251 | const int qs0 = get_int_b2(bxi->qs, kqsx);
| 252 |
| 253 | + #ifdef NEW_MMA_AVAILABLE
| 254 | x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
| 255 | x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
| 256 | #else
| 257 | x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
| 258 | + #endif // NEW_MMA_AVAILABLE
| 259 | }
| 260 |
| 261 | const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
| 271 |
| 272 | const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
| 273 |
| 274 | + #ifdef NEW_MMA_AVAILABLE
| 275 | x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
| 276 | #else
| 277 | x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
| 278 | + #endif // NEW_MMA_AVAILABLE
| 279 | }
| 280 | }
| 281 |
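load_tiles_q4_0 above unpacks eight 4-bit quants from each 32-bit word and recenters them from [0, 15] to [-8, 7] with a single byte-wise saturating subtract, storing low and high nibbles QI4_0 columns apart in the x tile. A self-contained sketch of just that unpacking step; __vsubss4 is the CUDA per-byte saturating subtract used in the diff, while the kernel name and the flat indexing are illustrative:

// Each thread unpacks one packed int (8 x 4-bit quants) into two ints of signed bytes.
__global__ void unpack_q4_0_sketch(const int * qs_packed, int * lo, int * hi) {
    const int qs0 = qs_packed[threadIdx.x];
    lo[threadIdx.x] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); // low nibbles,  bytes in [-8, 7]
    hi[threadIdx.x] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); // high nibbles, bytes in [-8, 7]
}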
| 322 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
|
| 323 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 324 |
|
| 325 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 326 |
int * x_qs = (int *) x_tile;
|
| 327 |
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
|
| 328 |
#else
|
| 329 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
|
| 330 |
int * x_qs = (int *) x_tile;
|
| 331 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 332 |
+
#endif // NEW_MMA_AVAILABLE
|
| 333 |
|
| 334 |
const int kbx = threadIdx.x / QI4_1;
|
| 335 |
const int kqsx = threadIdx.x % QI4_1;
|
|
|
|
| 345 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
|
| 346 |
const int qs0 = get_int_b4(bxi->qs, kqsx);
|
| 347 |
|
| 348 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 349 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
| 350 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
|
| 351 |
#else
|
| 352 |
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
|
| 353 |
+
#endif // NEW_MMA_AVAILABLE
|
| 354 |
}
|
| 355 |
|
| 356 |
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
|
|
|
|
| 366 |
|
| 367 |
const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
|
| 368 |
|
| 369 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 370 |
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
| 371 |
#else
|
| 372 |
x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
|
| 373 |
+
#endif // NEW_MMA_AVAILABLE
|
| 374 |
}
|
| 375 |
}
|
| 376 |
|
|
|
|
| 417 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
|
| 418 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 419 |
|
| 420 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 421 |
int * x_qs = (int *) x_tile;
|
| 422 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 423 |
#else
|
| 424 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
|
| 425 |
int * x_qs = (int *) x_tile;
|
| 426 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 427 |
+
#endif // NEW_MMA_AVAILABLE
|
| 428 |
|
| 429 |
const int kbx = threadIdx.x / QI5_0;
|
| 430 |
const int kqsx = threadIdx.x % QI5_0;
|
|
|
|
| 456 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 457 |
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
|
| 458 |
|
| 459 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 460 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
| 461 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
| 462 |
#else
|
| 463 |
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
|
| 464 |
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
|
| 465 |
+
#endif // NEW_MMA_AVAILABLE
|
| 466 |
}
|
| 467 |
|
| 468 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
|
|
|
|
| 478 |
|
| 479 |
const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
|
| 480 |
|
| 481 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 482 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 483 |
#else
|
| 484 |
x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
|
| 485 |
+
#endif // NEW_MMA_AVAILABLE
|
| 486 |
}
|
| 487 |
}
|
| 488 |
|
| 489 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
|
| 490 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 491 |
|
| 492 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 493 |
int * x_qs = (int *) x_tile;
|
| 494 |
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
|
| 495 |
#else
|
| 496 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
| 497 |
int * x_qs = (int *) x_tile;
|
| 498 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 499 |
+
#endif // NEW_MMA_AVAILABLE
|
| 500 |
|
| 501 |
const int kbx = threadIdx.x / QI5_1;
|
| 502 |
const int kqsx = threadIdx.x % QI5_1;
|
|
|
|
| 526 |
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
|
| 527 |
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
|
| 528 |
|
| 529 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 530 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
| 531 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
| 532 |
#else
|
| 533 |
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
|
| 534 |
x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
|
| 535 |
+
#endif // NEW_MMA_AVAILABLE
|
| 536 |
}
|
| 537 |
|
| 538 |
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
|
|
|
|
| 548 |
|
| 549 |
const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
|
| 550 |
|
| 551 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 552 |
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
|
| 553 |
#else
|
| 554 |
x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
|
| 555 |
+
#endif // NEW_MMA_AVAILABLE
|
| 556 |
}
|
| 557 |
}
|
| 558 |
|
| 559 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
|
| 560 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 561 |
|
| 562 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 563 |
int * x_qs = (int *) x_tile;
|
| 564 |
float * x_df = (float *) (x_tile + 2*WARP_SIZE);
|
| 565 |
#else
|
| 566 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
|
| 567 |
int * x_qs = (int *) x_tile;
|
| 568 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 569 |
+
#endif // NEW_MMA_AVAILABLE
|
| 570 |
|
| 571 |
const int kbx = threadIdx.x / QI8_0;
|
| 572 |
const int kqsx = threadIdx.x % QI8_0;
|
|
|
|
| 581 |
|
| 582 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
|
| 583 |
|
| 584 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 585 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
|
| 586 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
|
| 587 |
#else
|
| 588 |
x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
|
| 589 |
x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
|
| 590 |
+
#endif // NEW_MMA_AVAILABLE
|
| 591 |
}
|
| 592 |
|
| 593 |
const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
|
|
|
|
| 603 |
|
| 604 |
const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
|
| 605 |
|
| 606 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 607 |
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
|
| 608 |
#else
|
| 609 |
x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
|
| 610 |
+
#endif // NEW_MMA_AVAILABLE
|
| 611 |
}
|
| 612 |
}
|
| 613 |
|
|
|
|
| 645 | static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
| 646 | const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
| 647 |
| 648 | + typedef mma_A_I16K8<int> mma_A;
| 649 | + typedef mma_B_J8K8<int> mma_B;
| 650 | + typedef mma_C_I16J8<int> mma_C;
| 651 |
| 652 | constexpr int granularity = mmq_get_granularity_device(mmq_x);
| 653 | constexpr int rows_per_warp = 2 * granularity;
| 672 | for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
| 673 | const int k0 = k00 + k01;
| 674 |
| 675 | + A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
| 676 | }
| 677 |
| 678 | #pragma unroll
| 695 | mma_B B;
| 696 | float dB[mma_C::ne/2];
| 697 |
| 698 | + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
| 699 |
| 700 | #pragma unroll
| 701 | for (int l = 0; l < mma_C::ne/2; ++l) {
| 711 | #pragma unroll
| 712 | for (int n = 0; n < ntx; ++n) {
| 713 | mma_C C;
| 714 | + C.mma(A[n][k01/QI8_0], B);
| 715 |
| 716 | #pragma unroll
| 717 | for (int l = 0; l < mma_C::ne; ++l) {
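vec_dot_q8_0_q8_1_mma above follows a fixed pattern: the A (x-tile) fragments are loaded with load_ldmatrix, the B (y-tile) fragment is loaded with load_generic (the in-diff comment notes this is faster than load_ldmatrix for that layout), and each C.mma(A, B) call accumulates one 16x8 block of int32 dot products. A condensed outline of that inner step, assuming ggml/src/ggml-cuda/mma.cuh is included; the stride, offset and function parameters are placeholders for the real MMQ_MMA_TILE_X_K_Q8_0 / MMQ_TILE_Y_K bookkeeping, so this is not the actual kernel body:

// Simplified per-warp tile product; illustrative only.
static __device__ __forceinline__ void tile_product_sketch(
        const int * x_qs, const int * y_qs, const int stride_x, const int stride_y,
        const int row0, const int col0, const int k0) {
    typedef mma_A_I16K8<int> mma_A;
    typedef mma_B_J8K8<int>  mma_B;
    typedef mma_C_I16J8<int> mma_C;

    mma_A A;
    A.load_ldmatrix(x_qs + row0*stride_x + k0, stride_x); // weights: ldmatrix-backed load, reused across columns

    mma_B B;
    B.load_generic(y_qs + col0*stride_y + k0, stride_y);  // activations: plain loads, faster here per the diff

    mma_C C;
    C.mma(A, B);                                          // int8 x int8 -> int32 16x8 accumulator
    // the accumulator values still have to be scaled by the per-block scales before being added to sum[]
}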
| 756 | static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
| 757 | const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
| 758 |
| 759 | + typedef mma_A_I16K8<int> mma_A;
| 760 | + typedef mma_B_J8K8<int> mma_B;
| 761 | + typedef mma_C_I16J8<int> mma_C;
| 762 |
| 763 | constexpr int granularity = mmq_get_granularity_device(mmq_x);
| 764 | constexpr int rows_per_warp = 2 * granularity;
| 782 | for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
| 783 | const int k0 = k00 + k01;
| 784 |
| 785 | + A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
| 786 | }
| 787 |
| 788 | #pragma unroll
| 805 | mma_B B;
| 806 | float2 dsB[mma_C::ne/2];
| 807 |
| 808 | + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
| 809 |
| 810 | #pragma unroll
| 811 | for (int l = 0; l < mma_C::ne/2; ++l) {
| 817 | #pragma unroll
| 818 | for (int n = 0; n < ntx; ++n) {
| 819 | mma_C C;
| 820 | + C.mma(A[n][k01/QI8_1], B);
| 821 |
| 822 | #pragma unroll
| 823 | for (int l = 0; l < mma_C::ne; ++l) {
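vec_dot_q8_1_q8_1_mma above differs from the q8_0 variant mainly in its scale bookkeeping: it keeps a float2 dsB per column instead of a single float dB, the second component being the extra term that the asymmetric (dm-based) x formats need. A hedged sketch of the two dequantization shapes; the field meanings in the asymmetric case are an assumption, and any format-specific normalization factor is deliberately left out:

// Symmetric case (q8_0-style x data): one scale per side, matching the single float dB above.
static __device__ __forceinline__ float scale_dot_sym(const int idot, const float d_x, const float d_y) {
    return d_x * d_y * (float) idot;
}

// Asymmetric case (q4_1/q5_1-style x data against q8_1 y blocks): dm_x = {d, m}, ds_y = {d, s} is
// assumed here; the second product supplies the offset correction, normalization omitted.
static __device__ __forceinline__ float scale_dot_asym(const int idot, const float2 dm_x, const float2 ds_y) {
    return dm_x.x * ds_y.x * (float) idot + dm_x.y * ds_y.y;
}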
| 864 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 865 |
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
| 866 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 867 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 868 |
|
| 869 |
+
typedef mma_A_I16K4<int> mma_A;
|
| 870 |
+
typedef mma_A_I16K8<int> mma_A_K8;
|
| 871 |
+
typedef mma_B_J8K4<int> mma_B;
|
| 872 |
+
typedef mma_C_I16J8<int> mma_C;
|
| 873 |
|
| 874 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 875 |
constexpr int rows_per_warp = 2 * granularity;
|
|
|
|
| 893 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
| 894 |
const int k0 = k00 + k01;
|
| 895 |
|
| 896 |
+
((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
| 897 |
}
|
| 898 |
|
| 899 |
#pragma unroll
|
|
|
|
| 916 |
mma_B B[2];
|
| 917 |
float dB[mma_C::ne/2];
|
| 918 |
|
| 919 |
+
// Here load_generic is faster than load_ldmatrix.
|
| 920 |
+
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
| 921 |
+
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
|
| 922 |
|
| 923 |
#pragma unroll
|
| 924 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
|
|
| 930 |
#pragma unroll
|
| 931 |
for (int n = 0; n < ntx; ++n) {
|
| 932 |
mma_C C[2];
|
| 933 |
+
C[0].mma(A[n][k01/4 + 0], B[0]);
|
| 934 |
+
C[1].mma(A[n][k01/4 + 1], B[1]);
|
| 935 |
|
| 936 |
#pragma unroll
|
| 937 |
for (int l = 0; l < mma_C::ne; ++l) {
|
|
|
|
| 943 |
#else
|
| 944 |
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
| 945 |
NO_DEVICE_CODE;
|
| 946 |
+
#endif // NEW_MMA_AVAILABLE
|
| 947 |
}
|
| 948 |
|
| 949 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
| 950 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 951 |
|
| 952 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 953 |
int * x_qs = (int *) x_tile;
|
| 954 |
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
|
| 955 |
#else
|
| 956 |
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
| 957 |
int * x_qs = (int *) x_tile;
|
| 958 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 959 |
+
#endif // NEW_MMA_AVAILABLE
|
| 960 |
|
| 961 |
const int kqsx = threadIdx.x % QI2_K;
|
| 962 |
|
|
|
|
| 978 |
|
| 979 |
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
|
| 980 |
|
| 981 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 982 |
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
|
| 983 |
#else
|
| 984 |
x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
|
| 985 |
+
#endif // NEW_MMA_AVAILABLE
|
| 986 |
}
|
| 987 |
|
| 988 |
const int sc_m = bxi->scales[kqsx];
|
|
|
|
| 993 |
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
| 994 |
#endif // FAST_FP16_AVAILABLE
|
| 995 |
|
| 996 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 997 |
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
|
| 998 |
#else
|
| 999 |
x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
|
| 1000 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1001 |
}
|
| 1002 |
}
|
| 1003 |
|
|
|
|
| 1052 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1053 |
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
| 1054 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 1055 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1056 |
|
| 1057 |
+
typedef mma_A_I16K4<int> mma_A;
|
| 1058 |
+
typedef mma_A_I16K8<int> mma_A_K8;
|
| 1059 |
+
typedef mma_B_J8K4<int> mma_B;
|
| 1060 |
+
typedef mma_C_I16J8<int> mma_C;
|
| 1061 |
|
| 1062 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 1063 |
constexpr int rows_per_warp = 2 * granularity;
|
|
|
|
| 1082 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
| 1083 |
const int k0 = k00 + k01;
|
| 1084 |
|
| 1085 |
+
((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
| 1086 |
}
|
| 1087 |
}
|
| 1088 |
|
|
|
|
| 1119 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
|
| 1120 |
mma_B B[2];
|
| 1121 |
|
| 1122 |
+
// Here load_generic is faster than load_ldmatrix.
|
| 1123 |
+
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
|
| 1124 |
+
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
|
| 1125 |
|
| 1126 |
mma_C Cm[2];
|
| 1127 |
if (k01 >= WARP_SIZE * 3/4) {
|
| 1128 |
mma_A A1;
|
| 1129 |
A1.x[0] = 0x01010101;
|
| 1130 |
A1.x[1] = 0x01010101;
|
| 1131 |
+
Cm[0].mma(A1, B[0]);
|
| 1132 |
+
Cm[1].mma(A1, B[1]);
|
| 1133 |
}
|
| 1134 |
|
| 1135 |
#pragma unroll
|
| 1136 |
for (int n = 0; n < ntx; ++n) {
|
| 1137 |
mma_C Cd[2];
|
| 1138 |
|
| 1139 |
+
Cd[0].mma(A[n][k01/4 + 0], B[0]);
|
| 1140 |
+
Cd[1].mma(A[n][k01/4 + 1], B[1]);
|
| 1141 |
|
| 1142 |
#pragma unroll
|
| 1143 |
for (int l = 0; l < mma_C::ne; ++l) {
|
|
|
|
| 1174 |
#else
|
| 1175 |
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
| 1176 |
NO_DEVICE_CODE;
|
| 1177 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1178 |
}
|
| 1179 |
|
| 1180 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
|
| 1181 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1182 |
|
| 1183 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1184 |
int * x_qs = (int *) x_tile;
|
| 1185 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 1186 |
#else
|
|
|
|
| 1188 |
int * x_qs = (int *) x_tile;
|
| 1189 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1190 |
int * x_sc = (int *) (x_df + txs.dm);
|
| 1191 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1192 |
|
| 1193 |
const int kqsx = threadIdx.x % QI3_K;
|
| 1194 |
|
|
|
|
| 1214 |
|
| 1215 |
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
|
| 1216 |
|
| 1217 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1218 |
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
|
| 1219 |
#else
|
| 1220 |
x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
|
| 1221 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1222 |
}
|
| 1223 |
}
|
| 1224 |
|
|
|
|
| 1244 |
|
| 1245 |
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
|
| 1246 |
|
| 1247 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1248 |
const int8_t * sc8 = (const int8_t *) &sc;
|
| 1249 |
const float d = bxi->d;
|
| 1250 |
|
|
|
|
| 1254 |
}
|
| 1255 |
#else
|
| 1256 |
x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
|
| 1257 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1258 |
}
|
| 1259 |
|
| 1260 |
+
#ifndef NEW_MMA_AVAILABLE
|
| 1261 |
#pragma unroll
|
| 1262 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
|
| 1263 |
int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
|
|
|
|
| 1270 |
|
| 1271 |
x_df[i] = bxi->d;
|
| 1272 |
}
|
| 1273 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1274 |
}
|
| 1275 |
|
| 1276 |
template <int mmq_x, int mmq_y, int nwarps>
|
|
|
|
| 1319 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
|
| 1320 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1321 |
|
| 1322 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1323 |
int * x_qs = (int *) x_tile;
|
| 1324 |
half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
|
| 1325 |
#else
|
|
|
|
| 1327 |
int * x_qs = (int *) x_tile;
|
| 1328 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1329 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1330 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1331 |
|
| 1332 |
#pragma unroll
|
| 1333 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
|
|
| 1340 |
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
|
| 1341 |
const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
|
| 1342 |
|
| 1343 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1344 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
| 1345 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
|
| 1346 |
#else
|
| 1347 |
x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
|
| 1348 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1349 |
}
|
| 1350 |
|
| 1351 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1352 |
|
| 1353 |
#pragma unroll
|
| 1354 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
|
|
|
|
| 1409 |
|
| 1410 |
x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
|
| 1411 |
}
|
| 1412 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1413 |
}
|
| 1414 |
|
| 1415 |
template <int mmq_x, int mmq_y, int nwarps>
|
|
|
|
| 1448 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
|
| 1449 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1450 |
|
| 1451 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1452 |
int * x_qs = (int *) x_tile;
|
| 1453 |
half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
|
| 1454 |
#else
|
|
|
|
| 1456 |
int * x_qs = (int *) x_tile;
|
| 1457 |
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
| 1458 |
int * x_sc = (int *) (x_dm + txs.dm);
|
| 1459 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1460 |
|
| 1461 |
#pragma unroll
|
| 1462 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
|
|
| 1480 |
const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
|
| 1481 |
const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
|
| 1482 |
|
| 1483 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1484 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
|
| 1485 |
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
|
| 1486 |
#else
|
| 1487 |
x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
|
| 1488 |
x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
|
| 1489 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1490 |
}
|
| 1491 |
|
| 1492 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1493 |
|
| 1494 |
#pragma unroll
|
| 1495 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
|
|
|
|
| 1550 |
|
| 1551 |
x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
|
| 1552 |
}
|
| 1553 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1554 |
}
|
| 1555 |
|
| 1556 |
template <int mmq_x, int mmq_y, int nwarps>
|
|
|
|
| 1589 |
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
|
| 1590 |
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
|
| 1591 |
|
| 1592 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1593 |
int * x_qs = (int *) x_tile;
|
| 1594 |
float * x_df = (float *) (x_qs + WARP_SIZE*2);
|
| 1595 |
int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
|
|
|
|
| 1598 |
int * x_qs = (int *) x_tile;
|
| 1599 |
float * x_df = (float *) (x_qs + txs.qs);
|
| 1600 |
int * x_sc = (int *) (x_df + txs.dm);
|
| 1601 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1602 |
|
| 1603 |
#pragma unroll
|
| 1604 |
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
|
|
| 1621 |
const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
|
| 1622 |
const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
|
| 1623 |
|
| 1624 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1625 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 1626 |
x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 1627 |
#else
|
| 1628 |
x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
|
| 1629 |
x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
|
| 1630 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1631 |
}
|
| 1632 |
|
| 1633 |
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
|
|
|
|
| 1643 |
|
| 1644 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
|
| 1645 |
|
| 1646 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1647 |
x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
|
| 1648 |
#else
|
| 1649 |
x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
|
| 1650 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1651 |
}
|
| 1652 |
|
| 1653 |
#pragma unroll
|
|
|
|
| 1660 |
|
| 1661 |
const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
|
| 1662 |
|
| 1663 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1664 |
x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
|
| 1665 |
#else
|
| 1666 |
x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
|
| 1667 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1668 |
}
|
| 1669 |
}
|
| 1670 |
|
|
|
|
| 1704 |
template <int mmq_x, int mmq_y, int nwarps>
|
| 1705 |
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
| 1706 |
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
|
| 1707 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 1708 |
|
| 1709 |
+
typedef mma_A_I16K4<int> mma_A;
|
| 1710 |
+
typedef mma_B_J8K4<int> mma_B;
|
| 1711 |
+
typedef mma_C_I16J8<int> mma_C;
|
| 1712 |
|
| 1713 |
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
| 1714 |
constexpr int rows_per_warp = 2 * granularity;
|
|
|
|
| 1734 |
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
|
| 1735 |
const int k0 = k00 + k01;
|
| 1736 |
|
| 1737 |
+
A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
|
| 1738 |
+
A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
|
| 1739 |
}
|
| 1740 |
|
| 1741 |
#pragma unroll
|
|
|
|
| 1773 |
mma_B B[2];
|
| 1774 |
float dB[mma_C::ne/2];
|
| 1775 |
|
| 1776 |
+
// Here load_generic is faster than load_ldmatrix.
|
| 1777 |
+
B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
|
| 1778 |
+
B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
|
| 1779 |
|
| 1780 |
#pragma unroll
|
| 1781 |
for (int l = 0; l < mma_C::ne/2; ++l) {
|
|
|
|
| 1787 |
#pragma unroll
|
| 1788 |
for (int n = 0; n < ntx; ++n) {
|
| 1789 |
mma_C C[2];
|
| 1790 |
+
C[0].mma(A[n][k01/4 + 0], B[0]);
|
| 1791 |
+
C[1].mma(A[n][k01/4 + 1], B[1]);
|
| 1792 |
|
| 1793 |
#pragma unroll
|
| 1794 |
for (int l = 0; l < mma_C::ne; ++l) {
|
|
|
|
| 1808 |
#else
|
| 1809 |
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
| 1810 |
NO_DEVICE_CODE;
|
| 1811 |
+
#endif // NEW_MMA_AVAILABLE
|
| 1812 |
}
|
| 1813 |
|

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kbx = threadIdx.x / QI4_NL;
const int kqsx = threadIdx.x % QI4_NL;

const int aux_q4 = get_int_b2(bxi->qs, kqsx);
const int2 v = get_int_from_table_16(aux_q4);
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // NEW_MMA_AVAILABLE
}

const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;

const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;

+#ifdef NEW_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
#else
x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
+#endif // NEW_MMA_AVAILABLE
}
}
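In the dp4a branches above the x_qs rows use a stride of 2*WARP_SIZE + 1 rather than 2*WARP_SIZE; the extra element is the usual shared-memory padding trick so that a warp stepping down a column does not have every thread hit the same bank. A hypothetical stand-alone sketch of that padding idea (a plain tile transpose, not ggml code; it assumes a square matrix, blockDim == (TILE_DIM, TILE_DIM) and a width that is a multiple of TILE_DIM):

// Illustrative padded shared-memory tile: a 32x32 int tile stored with a row
// stride of 33 so that column accesses by a warp land in different banks.
#include <cuda_runtime.h>

#define TILE_DIM 32
#define TILE_PAD (TILE_DIM + 1) // +1 column of padding, as in the 2*WARP_SIZE + 1 stride above

__global__ void transpose_tile(const int * __restrict__ src, int * __restrict__ dst, const int width) {
    __shared__ int tile[TILE_DIM][TILE_PAD];

    const int x = blockIdx.x*TILE_DIM + threadIdx.x;
    const int y = blockIdx.y*TILE_DIM + threadIdx.y;
    tile[threadIdx.y][threadIdx.x] = src[y*width + x]; // coalesced global read

    __syncthreads();

    const int xt = blockIdx.y*TILE_DIM + threadIdx.x;
    const int yt = blockIdx.x*TILE_DIM + threadIdx.y;
    dst[yt*width + xt] = tile[threadIdx.x][threadIdx.y]; // padded stride avoids bank conflicts
}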

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kqsx = threadIdx.x % (QI2_XXS/2);

const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);

+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // NEW_MMA_AVAILABLE
}

const int ls = aux32 >> 28;
const float d = bxi->d;
+#ifdef NEW_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
#else
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // NEW_MMA_AVAILABLE
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kqsx = threadIdx.x % (QI2_XS/2);

const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);

+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // NEW_MMA_AVAILABLE
}

const int ls = bxi->scales[kqsx];
const float d = bxi->d;
+#ifdef NEW_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // NEW_MMA_AVAILABLE
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kqsx = threadIdx.x % (QI2_S/2);

const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);

+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // NEW_MMA_AVAILABLE
}

const int ls = bxi->scales[kqsx];
const float d = bxi->d;
+#ifdef NEW_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
#else
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // NEW_MMA_AVAILABLE
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kqsx = threadIdx.x % (QI3_XXS/2);

const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);

+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // NEW_MMA_AVAILABLE
}

const int ls = aux32 >> 28;
const float d = bxi->d;
+#ifdef NEW_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
#else
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // NEW_MMA_AVAILABLE
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kqsx = threadIdx.x % (QI3_S/2);

const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);

+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // NEW_MMA_AVAILABLE
}

const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
const float d = bxi->d;
+#ifdef NEW_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
#else
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
+#endif // NEW_MMA_AVAILABLE
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
int * x_qs = (int *) x_tile;
half2 * x_ds = (half2 *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kqsx = threadIdx.x % QI1_S;

const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;

+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // NEW_MMA_AVAILABLE
}

const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);

+#ifdef NEW_MMA_AVAILABLE
x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
#else
x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // NEW_MMA_AVAILABLE
}
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

+#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
+#endif // NEW_MMA_AVAILABLE

const int kbx = 0; // threadIdx.x / QI4_XS
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS

const int aux_q4 = get_int_b4(bxi->qs, kqsx);
const int2 v = get_int_from_table_16(aux_q4);
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef NEW_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
#else
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // NEW_MMA_AVAILABLE
}

#pragma unroll

const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
| (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);

+#ifdef NEW_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
#else
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // NEW_MMA_AVAILABLE
}
}
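The ls expression above stitches a 6-bit IQ4_XS sub-block scale together from a packed 4-bit low nibble in scales_l and a packed 2-bit high part in scales_h, and the next lines apply it as d * (ls - 32). A hedged host-side sketch of the same decode with made-up inputs (the field widths here are assumptions for illustration, not the exact ggml block layout):

// Illustrative decode of one IQ4_XS-style sub-block scale, mirroring the
// index arithmetic shown above; scales_l/scales_h sizes are assumed.
#include <cstdint>
#include <cstdio>

static float decode_subblock_scale(const uint8_t scales_l[4], const uint16_t scales_h,
                                   const float d, const int j /* sub-block 0..7 */) {
    const int lo = (scales_l[j/2] >> (4*(j % 2))) & 0x0F; // 4 low bits
    const int hi = (scales_h >> (2*j)) & 0x03;            // 2 high bits
    const int ls = lo | (hi << 4);                        // 6-bit scale in [0, 63]
    return d * (ls - 32);                                 // recentred around zero
}

int main() {
    const uint8_t  scales_l[4] = {0x4A, 0x31, 0xFF, 0x00}; // made-up packed nibbles
    const uint16_t scales_h    = 0x1B2D;                   // made-up packed 2-bit parts
    const float    d           = 0.0125f;
    for (int j = 0; j < 8; ++j) {
        printf("sub-block %d: scale = %f\n", j, decode_subblock_scale(scales_l, scales_h, d, j));
    }
    return 0;
}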

static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

+typedef mma_C_I16J8<int> mma_C;

constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.

const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+#ifdef NEW_MMA_AVAILABLE
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+#endif // NEW_MMA_AVAILABLE

#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {

int * tile_y = (int *) data_mul_mat_q;
int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);

+#ifdef NEW_MMA_AVAILABLE
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // NEW_MMA_AVAILABLE

constexpr int blocks_per_iter = MMQ_ITER_K / qk;

const int jt = kbc / (blocks_per_ne00*nty);
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;

+constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
it, jt, kb0_start, kb0_stop);

static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
+const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
}
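mmq_get_shmem above returns the number of bytes of dynamic shared memory the kernel will need: an x tile whose layout depends on whether the mma path is compiled in, plus a padded y tile. A hedged sketch of the general mechanics of launching a kernel with such a computed dynamic shared-memory size (generic CUDA, not the ggml launch code; the kernel and sizes are made up):

// Hypothetical launcher: opt in to the requested dynamic shared-memory size,
// then pass the byte count as the third launch parameter.
#include <cuda_runtime.h>

__global__ void tile_kernel(const int * x, int * dst) {
    extern __shared__ int tile[]; // dynamic shared memory, sized by the host
    tile[threadIdx.x] = x[blockIdx.x*blockDim.x + threadIdx.x];
    __syncthreads();
    dst[blockIdx.x*blockDim.x + threadIdx.x] = tile[blockDim.x - 1 - threadIdx.x]; // reverse each block
}

static void launch_tile_kernel(const int * x, int * dst, const int nblocks, const int nthreads) {
    const size_t shmem = nthreads*sizeof(int); // analogous to shmem_x + shmem_y above

    // Kernels needing more than the default 48 KiB of dynamic smem must opt in first.
    cudaFuncSetAttribute(tile_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) shmem);

    tile_kernel<<<nblocks, nthreads, shmem>>>(x, dst);
}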
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu
ADDED
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16);
+DECL_FATTN_MMA_F16_CASE(80, 16);
+DECL_FATTN_MMA_F16_CASE(96, 16);
+DECL_FATTN_MMA_F16_CASE(112, 16);
+DECL_FATTN_MMA_F16_CASE(128, 16);
+DECL_FATTN_MMA_F16_CASE(256, 16);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu
ADDED
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32);
+DECL_FATTN_MMA_F16_CASE(80, 32);
+DECL_FATTN_MMA_F16_CASE(96, 32);
+DECL_FATTN_MMA_F16_CASE(112, 32);
+DECL_FATTN_MMA_F16_CASE(128, 32);
+DECL_FATTN_MMA_F16_CASE(256, 32);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu
ADDED
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64);
+DECL_FATTN_MMA_F16_CASE(80, 64);
+DECL_FATTN_MMA_F16_CASE(96, 64);
+DECL_FATTN_MMA_F16_CASE(112, 64);
+DECL_FATTN_MMA_F16_CASE(128, 64);
+DECL_FATTN_MMA_F16_CASE(256, 64);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu
ADDED
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8);
+DECL_FATTN_MMA_F16_CASE(80, 8);
+DECL_FATTN_MMA_F16_CASE(96, 8);
+DECL_FATTN_MMA_F16_CASE(112, 8);
+DECL_FATTN_MMA_F16_CASE(128, 8);
+DECL_FATTN_MMA_F16_CASE(256, 8);
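Each of the four new instance files fixes one cols_per_block value and lists every supported head size, so the template instantiations are spread over small translation units that can build in parallel. The DECL_FATTN_MMA_F16_CASE macro itself lives in fattn-mma-f16.cuh, which is not part of this excerpt; purely as an illustration of the declare-one-case pattern, a macro of this general shape could look like the following (hypothetical names, not the actual ggml macro):

// Hypothetical example only: a templated kernel plus a macro that emits one
// explicit instantiation per (head size, columns per block) pair.
#include <cuda_fp16.h>

template <int D, int ncols> // head size, columns per block
__global__ void flash_attn_sketch(const half * Q, const half * K, const half * V, float * dst) {
    if (threadIdx.x == 0) {
        dst[blockIdx.x] = __half2float(Q[0]) + __half2float(K[0]) + __half2float(V[0]); // placeholder body
    }
}

#define DECL_FLASH_ATTN_SKETCH_CASE(D, ncols) \
    template __global__ void flash_attn_sketch<D, ncols>(const half * Q, const half * K, const half * V, float * dst)

// One generated instance file would then contain only lines such as:
DECL_FLASH_ATTN_SKETCH_CASE( 64, 16);
DECL_FLASH_ATTN_SKETCH_CASE(128, 16);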
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
DELETED
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 16, float);
-DECL_FATTN_WMMA_F16_CASE(80, 16, float);
-DECL_FATTN_WMMA_F16_CASE(96, 16, float);
-DECL_FATTN_WMMA_F16_CASE(112, 16, float);
-DECL_FATTN_WMMA_F16_CASE(128, 16, float);
-DECL_FATTN_WMMA_F16_CASE(256, 16, float);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
DELETED
@@ -1,9 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 32, float);
-DECL_FATTN_WMMA_F16_CASE(80, 32, float);
-DECL_FATTN_WMMA_F16_CASE(96, 32, float);
-DECL_FATTN_WMMA_F16_CASE(112, 32, float);
-DECL_FATTN_WMMA_F16_CASE(128, 32, float);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
DELETED
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 16, half);
-DECL_FATTN_WMMA_F16_CASE(80, 16, half);
-DECL_FATTN_WMMA_F16_CASE(96, 16, half);
-DECL_FATTN_WMMA_F16_CASE(112, 16, half);
-DECL_FATTN_WMMA_F16_CASE(128, 16, half);
-DECL_FATTN_WMMA_F16_CASE(256, 16, half);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
DELETED
@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 32, half);
-DECL_FATTN_WMMA_F16_CASE(80, 32, half);
-DECL_FATTN_WMMA_F16_CASE(96, 32, half);
-DECL_FATTN_WMMA_F16_CASE(112, 32, half);
-DECL_FATTN_WMMA_F16_CASE(128, 32, half);
-DECL_FATTN_WMMA_F16_CASE(256, 32, half);
ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
DELETED
@@ -1,8 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-wmma-f16.cuh"
-
-DECL_FATTN_WMMA_F16_CASE(64, 8, half);
-DECL_FATTN_WMMA_F16_CASE(96, 8, half);
-DECL_FATTN_WMMA_F16_CASE(128, 8, half);
-DECL_FATTN_WMMA_F16_CASE(256, 8, half);
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
CHANGED
@@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p
DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
"""

-
-#include "../fattn-

"""

-

TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",

DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
"""

+SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

+#include "../fattn-mma-f16.cuh"

"""

+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"

TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",

@@ -57,20 +57,12 @@ for vkq_size in [16, 32]:
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))

-for
-
-
-continue

-
-f.write(
-
-for head_size in [64, 80, 96, 112, 128, 256]:
-if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
-continue
-if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
-continue
-f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))

for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:

with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))

+for cols_per_block in [8, 16, 32, 64]:
+with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
+f.write(SOURCE_FATTN_MMA_START)

+for head_size in [64, 80, 96, 112, 128, 256]:
+f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))

for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
ggml/src/ggml-cuda/vendors/hip.h
CHANGED
@@ -25,6 +25,7 @@
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate

#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
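The added __shfl_sync macro maps the CUDA-style synchronized warp shuffle onto HIP's unsynchronized __shfl, so CUDA code that calls __shfl_sync also compiles under ROCm. A minimal CUDA sketch of the kind of call being mapped (illustrative only, not taken from the fattn kernels; it assumes blockDim.x is a multiple of 32):

// Lane 0 of each warp reads a value and broadcasts it to the other 31 lanes.
// On HIP the macro above rewrites the __shfl_sync call to __shfl.
#include <cuda_runtime.h>

__global__ void broadcast_from_lane0(const float * x, float * dst, const int n) {
    const int lane = threadIdx.x % 32;
    const int i    = blockIdx.x*blockDim.x + threadIdx.x;

    float v = 0.0f;
    if (lane == 0 && i < n) {
        v = x[i]; // only lane 0 reads from global memory
    }
    v = __shfl_sync(0xFFFFFFFF, v, 0, 32); // every lane receives lane 0's value

    if (i < n) {
        dst[i] = v;
    }
}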
ggml/src/ggml-hip/CMakeLists.txt
CHANGED
@@ -50,7 +50,7 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")

file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
-file(GLOB SRCS "../ggml-cuda/template-instances/fattn-
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})

list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")

file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
+file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
ggml/src/ggml-musa/CMakeLists.txt
CHANGED
@@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")

file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
-file(GLOB SRCS "../ggml-cuda/template-instances/fattn-
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})

list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")

file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
+file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})