JohannesGaessler and ggerganov committed
Commit 88864af · 1 Parent(s): 0bd2be3

CUDA: fix quantized KV cache + multiple sequences (llama/14822)


* CUDA: fix quantized KV cache + multiple sequences

* Update ggml/src/ggml-cuda/fattn-common.cuh

Co-authored-by: Georgi Gerganov <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
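
With multiple sequences in the batch, the quantized K/V views handed to FlashAttention are no longer contiguously allocated, so the previous conversion path, which asserted contiguity and dequantized the cache as one flat run of ggml_nelements() values, could not be applied. The updated dequantize_block kernel therefore takes the tensor extents and per-dimension strides (in units of quant blocks) and maps each logical element to its source block and to a densely packed FP16 output. A minimal host-side sketch of that index mapping follows; the helper name and the example numbers are assumptions for illustration, not code from this commit.

// Hypothetical illustration (not from the commit) of the index mapping used by
// the updated dequantize_block kernel for non-contiguous quantized tensors.
#include <cstdint>
#include <cstdio>

struct mapping { int64_t ib; int64_t iy; };

// i00..i03: logical element coordinates; ne00..ne02: tensor extents;
// s01/s02/s03: strides in units of quant blocks; qk: values per quant block.
static mapping map_element(int64_t i00, int64_t i01, int64_t i02, int64_t i03,
                           int64_t ne00, int64_t ne01, int64_t ne02,
                           int64_t s01, int64_t s02, int64_t s03, int64_t qk) {
    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;               // first quant block of this row
    const int64_t ib   = ibx0 + i00/qk;                             // source quant block holding i00
    const int64_t iy   = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;  // index in the dense FP16 output
    return {ib, iy};
}

int main() {
    // Assumed example: a Q8_0 view (qk = 32) of shape 128 x 4 x 2 x 1 whose rows
    // are spaced 8 quant blocks apart, i.e. the view is not contiguous.
    const mapping m = map_element(/*i00=*/65, /*i01=*/3, /*i02=*/1, /*i03=*/0,
                                  /*ne00=*/128, /*ne01=*/4, /*ne02=*/2,
                                  /*s01=*/8, /*s02=*/32, /*s03=*/64, /*qk=*/32);
    std::printf("source block %lld -> output element %lld\n", (long long) m.ib, (long long) m.iy);
    return 0;
}

In the kernel itself the same arithmetic is driven by a 3D launch grid: blockIdx.y selects i01 and blockIdx.z encodes i02 and i03, as dequantize_block_cuda sets up in the diff below.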

ggml/src/ggml-cuda/convert.cu CHANGED
@@ -6,24 +6,33 @@
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 
-    if (i >= k) {
+    if (i00 >= ne00) {
         return;
     }
 
-    const int64_t ib   = i/qk;      // block index
-    const int64_t iqs  = (i%qk)/qr; // quant index
-    const int64_t iybs = i - i%qk;  // y block start index
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
+    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+    const int64_t ib   = ibx0 + i00/qk; // block index
+    const int64_t iqs  = (i00%qk)/qr;   // quant index
+    const int64_t iybs = i00 - i00%qk;  // y block start index
     const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
    dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
-    y[iybs + iqs + 0]        = v.x;
-    y[iybs + iqs + y_offset] = v.y;
+    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+    y[iy0 + 0]        = v.x;
+    y[iy0 + y_offset] = v.y;
 }
 
 template <bool need_check>
@@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
-    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
-    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+static void dequantize_block_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
 }
 
 static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
@@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
             if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+            return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+            return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+            return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -722,6 +739,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16>;
         default:
@@ -733,6 +760,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float, nv_bfloat16>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half, nv_bfloat16>;
         default:
@@ -744,6 +781,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
     switch (type) {
        case GGML_TYPE_F16:
             return convert_unary_cuda<half, float>;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16, float>;
         default:
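
In convert.cu the contiguous entry points keep their old behavior through the new dequantize_block_cont_cuda wrapper, which treats a flat run of k values as a single row with a stride of k/qk blocks, while the ggml_get_to_*_nc_cuda getters now also handle Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0 via the stride-aware dequantize_block_cuda. A hypothetical call-site sketch for the non-contiguous FP16 path, mirroring what launch_fattn does in fattn-common.cuh below (the helper name and surrounding variables are assumptions, not code from this commit):

// Sketch only: convert a non-contiguous quantized tensor view to a densely
// packed FP16 buffer holding ggml_nelements(src) halves.
static void convert_view_to_f16(const ggml_tensor * src, half * dst_f16, cudaStream_t stream) {
    const size_t ts = ggml_type_size(src->type);   // bytes per quant block
    GGML_ASSERT(src->nb[0] == ts);                 // the rows themselves must be contiguous

    // Strides are passed in units of quant blocks, i.e. byte strides divided by ts.
    const int64_t s01 = src->nb[1] / ts;
    const int64_t s02 = src->nb[2] / ts;
    const int64_t s03 = src->nb[3] / ts;

    to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(src->type);
    to_fp16(src->data, dst_f16, src->ne[0], src->ne[1], src->ne[2], src->ne[3], s01, s02, s03, stream);
}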
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -745,33 +745,58 @@ void launch_fattn(
     size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(K));
-        K_f16.alloc(ggml_nelements(K));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
-        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
-        K_data = (char *) K_f16.ptr;
-
         const size_t bs = ggml_blck_size(K->type);
         const size_t ts = ggml_type_size(K->type);
 
-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+            nb11 = nb11*bs*sizeof(half)/ts;
+            nb12 = nb12*bs*sizeof(half)/ts;
+            nb13 = nb13*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
     }
 
     if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
         const size_t bs = ggml_blck_size(V->type);
         const size_t ts = ggml_type_size(V->type);
 
-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        V_f16.alloc(ggml_nelements(V));
+        if (ggml_is_contiguously_allocated(V)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = nb21*bs*sizeof(half)/ts;
+            nb22 = nb22*bs*sizeof(half)/ts;
+            nb23 = nb23*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(V->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+            const int64_t s01 = nb21 / ts;
+            const int64_t s02 = nb22 / ts;
+            const int64_t s03 = nb23 / ts;
+            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+            nb21 = V->ne[0] * sizeof(half);
+            nb22 = V->ne[1] * nb21;
+            nb23 = V->ne[2] * nb22;
+        }
+        V_data = (char *) V_f16.ptr;
     }
 
     int parallel_blocks = 1;
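
After conversion, the two branches account for the new element size differently: the contiguous path keeps the original layout and only rescales the byte strides from quantized blocks to halves, whereas the non-contiguous path writes a dense copy and rebuilds the strides from the tensor extents. A worked example with assumed numbers (Q8_0, not taken from the commit):

// Assumed example: one K row of 4096 values stored as Q8_0.
// QK8_0 = 32 values per block, sizeof(block_q8_0) = 34 bytes per block.
constexpr size_t bs = 32;                        // values per quant block
constexpr size_t ts = 34;                        // bytes per quant block
constexpr size_t nb11_quant = (4096 / bs) * ts;  // 128 blocks -> 4352 bytes per row

// Contiguous path: same layout, rescale from bytes-per-block to bytes-per-half.
constexpr size_t nb11_f16 = nb11_quant * bs * sizeof(half) / ts;  // 8192 bytes

// Non-contiguous path: the FP16 copy is packed densely, so the strides are
// rebuilt from the extents instead: nb11 = ne[0]*sizeof(half), nb12 = ne[1]*nb11, ...
static_assert(nb11_f16 == 4096 * sizeof(half), "rescaled stride equals the dense row size");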