ggerganov committed
Commit bf6ccee · 1 Parent(s): e2461ca

ggml : sync sycl (skip) (#0)

ggml/src/ggml-sycl/backend.hpp CHANGED
@@ -19,5 +19,8 @@
19
  #include "dmmv.hpp"
20
  #include "mmq.hpp"
21
  #include "mmvq.hpp"
22
 
23
  #endif // GGML_SYCL_BACKEND_HPP
 
19
  #include "dmmv.hpp"
20
  #include "mmq.hpp"
21
  #include "mmvq.hpp"
22
+ #include "rope.hpp"
23
+ #include "norm.hpp"
24
+ #include "softmax.hpp"
25
 
26
  #endif // GGML_SYCL_BACKEND_HPP
ggml/src/ggml-sycl/common.hpp CHANGED
@@ -17,6 +17,7 @@
17
  #include <iostream>
18
 
19
  #include "dpct/helper.hpp"
 
20
  #include "presets.hpp"
21
 
22
  #define GGML_COMMON_DECL_SYCL
@@ -46,10 +47,6 @@ static int g_ggml_sycl_debug = 0;
46
  } \
47
  }()
48
 
49
- // #define DEBUG_SYCL_MALLOC
50
-
51
- static int g_work_group_size = 0;
52
- // typedef sycl::half ggml_fp16_t;
53
 
54
  #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
55
  #define VER_4VEC 610 // todo for hardward optimize.
@@ -192,6 +189,8 @@ struct ggml_sycl_device_info {
192
  sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
193
 
194
  std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
 
 
195
  };
196
 
197
  const ggml_sycl_device_info & ggml_sycl_info();
@@ -294,5 +293,57 @@ struct ggml_backend_sycl_context {
294
  }
295
  };
296
297
 
298
  #endif // GGML_SYCL_COMMON_HPP
 
17
  #include <iostream>
18
 
19
  #include "dpct/helper.hpp"
20
+ #include "ggml-sycl.h"
21
  #include "presets.hpp"
22
 
23
  #define GGML_COMMON_DECL_SYCL
 
47
  } \
48
  }()
49
50
 
51
  #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
52
  #define VER_4VEC 610 // todo for hardward optimize.
 
189
  sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
190
 
191
  std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
192
+
193
+ int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0};
194
  };
195
 
196
  const ggml_sycl_device_info & ggml_sycl_info();
 
293
  }
294
  };
295
 
296
+ // common device functions
297
+
298
+ static __dpct_inline__ float warp_reduce_sum(float x,
299
+ const sycl::nd_item<3>& item_ct1) {
300
+ #pragma unroll
301
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
302
+ /*
303
+ DPCT1096:98: The right-most dimension of the work-group used in the SYCL
304
+ kernel that calls this function may be less than "32". The function
305
+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the
306
+ CPU device. Modify the size of the work-group to ensure that the value
307
+ of the right-most dimension is a multiple of "32".
308
+ */
309
+ x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
310
+ }
311
+ return x;
312
+ }
313
+
314
+ static __dpct_inline__ sycl::float2
315
+ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
316
+ #pragma unroll
317
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
318
+ a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(),
319
+ mask);
320
+ a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(),
321
+ mask);
322
+ }
323
+ return a;
324
+ }
325
+
326
+ static __dpct_inline__ float warp_reduce_max(float x,
327
+ const sycl::nd_item<3>& item_ct1) {
328
+ #pragma unroll
329
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
330
+ /*
331
+ DPCT1096:97: The right-most dimension of the work-group used in the SYCL
332
+ kernel that calls this function may be less than "32". The function
333
+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the
334
+ CPU device. Modify the size of the work-group to ensure that the value
335
+ of the right-most dimension is a multiple of "32".
336
+ */
337
+ x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
338
+ item_ct1.get_sub_group(), x, mask));
339
+ }
340
+ return x;
341
+ }
342
+
343
+ // Helper for vec loading aligned data
344
+ template <typename Tp, int n>
345
+ inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
346
+ return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
347
+ }
348
 
349
  #endif // GGML_SYCL_COMMON_HPP
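
The new warp_reduce_sum / warp_reduce_max helpers added to common.hpp perform an xor-butterfly reduction over a sub-group, and the same for (int mask = WARP_SIZE / 2; ...) loop replaces the hard-coded 16 in the files below. As a rough illustration only (not part of this commit), here is a host-side simulation of that butterfly, assuming the usual sub-group size of 32:

// Host-side illustration only: simulates the xor-butterfly reduction that
// warp_reduce_sum performs across a sub-group of WARP_SIZE lanes. Each step adds
// the value held by the lane whose id differs in exactly one bit, so after
// log2(WARP_SIZE) steps every lane holds the full sum.
#include <cstdio>

constexpr int WARP_SIZE = 32;   // assumed sub-group size used by the SYCL backend

int main() {
    float lane[WARP_SIZE];
    for (int i = 0; i < WARP_SIZE; ++i) lane[i] = float(i);     // per-lane partial sums

    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        float next[WARP_SIZE];
        for (int i = 0; i < WARP_SIZE; ++i)
            next[i] = lane[i] + lane[i ^ mask];                 // permute_sub_group_by_xor + add
        for (int i = 0; i < WARP_SIZE; ++i) lane[i] = next[i];
    }

    std::printf("lane 0 = %.0f, lane 31 = %.0f (expected 496)\n",
                lane[0], lane[WARP_SIZE - 1]);
    return 0;
}

Because every lane ends up holding the total, the kernels can use the reduced value immediately without an extra work-group barrier.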
ggml/src/ggml-sycl/convert.cpp CHANGED
@@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
152
  dpct::has_capability_or_fail(stream->get_device(),
153
  {sycl::aspect::fp16});
154
 
155
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
 
 
156
  sycl::range<3>(1, 1, 32),
157
  sycl::range<3>(1, 1, 32)),
158
  [=](sycl::nd_item<3> item_ct1) {
159
- dequantize_block_q4_K(vx, y, item_ct1);
160
  });
 
161
  }
162
  }
163
 
 
152
  dpct::has_capability_or_fail(stream->get_device(),
153
  {sycl::aspect::fp16});
154
 
155
+ stream->submit([&](sycl::handler &cgh) {
156
+ sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
157
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
158
  sycl::range<3>(1, 1, 32),
159
  sycl::range<3>(1, 1, 32)),
160
  [=](sycl::nd_item<3> item_ct1) {
161
+ dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
162
  });
163
+ });
164
  }
165
  }
166
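
The convert.cpp change wraps the q4_K launch in stream->submit() so that a 12-byte sycl::local_accessor can be created and handed to dequantize_block_q4_K. A minimal self-contained sketch of that submit + local_accessor pattern (hypothetical toy kernel, not ggml code):

// Illustration only: stage a small per-work-group table in SYCL local memory,
// the way the q4_K dequantize path now stages its 12 scale bytes.
#include <sycl/sycl.hpp>
#include <cstdint>
#include <cstdio>

int main() {
    sycl::queue q;
    int *out = sycl::malloc_shared<int>(1, q);

    q.submit([&](sycl::handler &cgh) {
        // one 12-byte scratch buffer per work-group, like the q4_K scales
        sycl::local_accessor<uint8_t, 1> scratch(sycl::range<1>(12), cgh);
        cgh.parallel_for(
            sycl::nd_range<1>(sycl::range<1>(32), sycl::range<1>(32)),
            [=](sycl::nd_item<1> it) {
                const int tid = it.get_local_id(0);
                if (tid < 12) scratch[tid] = (uint8_t) tid;          // stage cooperatively
                it.barrier(sycl::access::fence_space::local_space);  // publish to all lanes
                if (tid == 0) {
                    int sum = 0;
                    for (int i = 0; i < 12; ++i) sum += scratch[i];
                    *out = sum;                                      // 0 + 1 + ... + 11 = 66
                }
            });
    }).wait();

    std::printf("sum staged through local memory: %d\n", *out);
    sycl::free(out, q);
    return 0;
}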
 
ggml/src/ggml-sycl/dequantize.hpp CHANGED
@@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
293
  #if QK_K == 256
294
  static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
295
  if (j < 4) {
296
- d = q[j] & 63; m = q[j + 4] & 63;
 
297
  } else {
298
  d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
299
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
@@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
303
 
304
  template<typename dst_t>
305
  static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
306
- const sycl::nd_item<3> &item_ct1) {
307
  const block_q4_K * x = (const block_q4_K *) vx;
308
 
309
  const int i = item_ct1.get_group(2);
@@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
318
 
319
  dst_t * y = yy + i*QK_K + 64*il + n*ir;
320
 
321
- const float dall = x[i].dm[0];
322
- const float dmin = x[i].dm[1];
 
323
 
324
- const uint8_t * q = x[i].qs + 32*il + n*ir;
 
 
325
 
326
  uint8_t sc, m;
327
- get_scale_min_k4(is + 0, x[i].scales, sc, m);
328
- const float d1 = dall * sc; const float m1 = dmin * m;
329
- get_scale_min_k4(is + 1, x[i].scales, sc, m);
330
- const float d2 = dall * sc; const float m2 = dmin * m;
331
  for (int l = 0; l < n; ++l) {
332
- y[l + 0] = d1 * (q[l] & 0xF) - m1;
333
- y[l +32] = d2 * (q[l] >> 4) - m2;
334
  }
335
  #else
336
  const int tid = item_ct1.get_local_id(2);
 
293
  #if QK_K == 256
294
  static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
295
  if (j < 4) {
296
+ d = q[j] & 63;
297
+ m = q[j + 4] & 63;
298
  } else {
299
  d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
300
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
 
304
 
305
  template<typename dst_t>
306
  static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
307
+ uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
308
  const block_q4_K * x = (const block_q4_K *) vx;
309
 
310
  const int i = item_ct1.get_group(2);
 
319
 
320
  dst_t * y = yy + i*QK_K + 64*il + n*ir;
321
 
322
+ const sycl::half2 dm = x[i].dm;
323
+ const float dall = dm[0];
324
+ const float dmin = dm[1];
325
 
326
+ if (tid < 12)
327
+ scales_local[tid] = x[i].scales[tid];
328
+ item_ct1.barrier(sycl::access::fence_space::local_space);
329
 
330
  uint8_t sc, m;
331
+ get_scale_min_k4(is + 0, scales_local, sc, m);
332
+ const float d1 = dall * sc;
333
+ const float m1 = dmin * m;
334
+ get_scale_min_k4(is + 1, scales_local, sc, m);
335
+ const float d2 = dall * sc;
336
+ const float m2 = dmin * m;
337
+
338
+ sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
339
  for (int l = 0; l < n; ++l) {
340
+ y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
341
+ y[l +32] = d2 * (q_vec[l] >> 4) - m2;
342
  }
343
  #else
344
  const int tid = item_ct1.get_local_id(2);
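
dequantize_block_q4_K now stages the 12 scale bytes of each block in local memory, decodes two 6-bit (scale, min) pairs with get_scale_min_k4, and reads the quants through an aligned vector load. A host-side walk-through of the per-byte decode, y = d * (q & 0xF) - m for the low nibble and y = d * (q >> 4) - m for the high one (the numeric inputs below are invented for illustration):

// Illustration only: decodes one packed q4_K byte the way the kernel does,
// using the same 6-bit scale/min unpacking. All numeric values are arbitrary.
#include <cstdint>
#include <cstdio>

static void get_scale_min_k4(int j, const uint8_t *q, uint8_t &d, uint8_t &m) {
    if (j < 4) {
        d = q[j] & 63;
        m = q[j + 4] & 63;
    } else {
        d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
    }
}

int main() {
    const uint8_t scales[12] = {17, 20, 23, 26, 5, 6, 7, 8, 33, 67, 101, 135}; // example packing
    const float dall = 0.05f, dmin = 0.01f; // would come from x[i].dm in the kernel
    const uint8_t q   = 0xA3;               // low nibble 3, high nibble 10

    uint8_t sc, m;
    get_scale_min_k4(0, scales, sc, m);
    const float d1 = dall * sc, m1 = dmin * m;
    get_scale_min_k4(1, scales, sc, m);
    const float d2 = dall * sc, m2 = dmin * m;

    std::printf("y[0]  = %f\n", d1 * (q & 0xF) - m1);   // 0.05*17*3  - 0.01*5
    std::printf("y[32] = %f\n", d2 * (q >> 4) - m2);    // 0.05*20*10 - 0.01*6
    return 0;
}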
ggml/src/ggml-sycl/dmmv.cpp CHANGED
@@ -3,6 +3,7 @@
3
  #include "dequantize.hpp"
4
  #include "presets.hpp"
5
 
 
6
  static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
7
  const sycl::half *x = (const sycl::half *)vx;
8
 
@@ -76,7 +77,7 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
76
 
77
  // sum up partial sums and write back result
78
  #pragma unroll
79
- for (int mask = 16; mask > 0; mask >>= 1) {
80
  tmp +=
81
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
82
  }
@@ -104,7 +105,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
104
 
105
  stream->parallel_for(
106
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
107
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
108
  dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
109
  nrows, item_ct1);
110
  });
@@ -227,7 +228,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
227
 
228
  // sum up partial sums and write back result
229
  #pragma unroll
230
- for (int mask = 16; mask > 0; mask >>= 1) {
231
  tmp +=
232
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
233
  }
@@ -346,7 +347,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
346
 
347
  // sum up partial sums and write back result
348
  #pragma unroll
349
- for (int mask = 16; mask > 0; mask >>= 1) {
350
  tmp +=
351
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
352
  }
@@ -499,7 +500,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
499
 
500
  // sum up partial sums and write back result
501
  #pragma unroll
502
- for (int mask = 16; mask > 0; mask >>= 1) {
503
  tmp +=
504
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
505
  }
@@ -633,7 +634,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
633
 
634
  // sum up partial sums and write back result
635
  #pragma unroll
636
- for (int mask = 16; mask > 0; mask >>= 1) {
637
  tmp +=
638
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
639
  }
@@ -748,7 +749,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
748
 
749
  // sum up partial sums and write back result
750
  #pragma unroll
751
- for (int mask = 16; mask > 0; mask >>= 1) {
752
  tmp +=
753
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
754
  }
@@ -774,7 +775,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
774
 
775
  stream->parallel_for(
776
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
777
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
778
  dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
779
  vx, y, dst, ncols, nrows, item_ct1);
780
  });
@@ -795,7 +796,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
795
 
796
  stream->parallel_for(
797
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
798
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
799
  dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
800
  vx, y, dst, ncols, nrows, item_ct1);
801
  });
@@ -816,7 +817,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
816
 
817
  stream->parallel_for(
818
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
819
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
820
  dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
821
  vx, y, dst, ncols, nrows, item_ct1);
822
  });
@@ -837,7 +838,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
837
 
838
  stream->parallel_for(
839
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
840
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
841
  dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
842
  vx, y, dst, ncols, nrows, item_ct1);
843
  });
@@ -858,7 +859,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
858
 
859
  stream->parallel_for(
860
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
861
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
862
  dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
863
  vx, y, dst, ncols, nrows, item_ct1);
864
  });
@@ -873,10 +874,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
873
  const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
874
  const int block_num_y = (nrows + ny - 1) / ny;
875
  const sycl::range<3> block_nums(1, 1, block_num_y);
876
- const sycl::range<3> block_dims(1, ny, 32);
877
  stream->parallel_for(
878
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
879
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
880
  dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
881
  });
882
  }
@@ -889,10 +890,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
889
  const int ny = 2 / K_QUANTS_PER_ITERATION;
890
  const int block_num_y = (nrows + ny - 1) / ny;
891
  const sycl::range<3> block_nums(1, 1, block_num_y);
892
- const sycl::range<3> block_dims(1, ny, 32);
893
  stream->parallel_for(
894
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
895
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
896
  dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
897
  });
898
  }
@@ -905,10 +906,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
905
  const int ny = 2 / K_QUANTS_PER_ITERATION;
906
  const int block_num_y = (nrows + ny - 1) / ny;
907
  const sycl::range<3> block_nums(1, 1, block_num_y);
908
- const sycl::range<3> block_dims(1, ny, 32);
909
  stream->parallel_for(
910
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
911
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
912
  dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
913
  });
914
  }
@@ -918,10 +919,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
918
  const int nrows,
919
  dpct::queue_ptr stream) {
920
  GGML_ASSERT(ncols % QK_K == 0);
921
- const sycl::range<3> block_dims(1, 1, 32);
922
  stream->parallel_for(
923
  sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
924
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
925
  dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
926
  });
927
  }
@@ -934,10 +935,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
934
  const int ny = 2 / K_QUANTS_PER_ITERATION;
935
  const int block_num_y = (nrows + ny - 1) / ny;
936
  const sycl::range<3> block_nums(1, 1, block_num_y);
937
- const sycl::range<3> block_dims(1, ny, 32);
938
  stream->parallel_for(
939
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
940
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
941
  dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
942
  });
943
  }
 
3
  #include "dequantize.hpp"
4
  #include "presets.hpp"
5
 
6
+
7
  static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
8
  const sycl::half *x = (const sycl::half *)vx;
9
 
 
77
 
78
  // sum up partial sums and write back result
79
  #pragma unroll
80
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
81
  tmp +=
82
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
83
  }
 
105
 
106
  stream->parallel_for(
107
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
108
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
109
  dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
110
  nrows, item_ct1);
111
  });
 
228
 
229
  // sum up partial sums and write back result
230
  #pragma unroll
231
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
232
  tmp +=
233
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
234
  }
 
347
 
348
  // sum up partial sums and write back result
349
  #pragma unroll
350
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
351
  tmp +=
352
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
353
  }
 
500
 
501
  // sum up partial sums and write back result
502
  #pragma unroll
503
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
504
  tmp +=
505
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
506
  }
 
634
 
635
  // sum up partial sums and write back result
636
  #pragma unroll
637
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
638
  tmp +=
639
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
640
  }
 
749
 
750
  // sum up partial sums and write back result
751
  #pragma unroll
752
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
753
  tmp +=
754
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
755
  }
 
775
 
776
  stream->parallel_for(
777
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
778
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
779
  dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
780
  vx, y, dst, ncols, nrows, item_ct1);
781
  });
 
796
 
797
  stream->parallel_for(
798
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
799
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
800
  dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
801
  vx, y, dst, ncols, nrows, item_ct1);
802
  });
 
817
 
818
  stream->parallel_for(
819
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
820
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
821
  dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
822
  vx, y, dst, ncols, nrows, item_ct1);
823
  });
 
838
 
839
  stream->parallel_for(
840
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
841
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
842
  dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
843
  vx, y, dst, ncols, nrows, item_ct1);
844
  });
 
859
 
860
  stream->parallel_for(
861
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
862
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
863
  dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
864
  vx, y, dst, ncols, nrows, item_ct1);
865
  });
 
874
  const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
875
  const int block_num_y = (nrows + ny - 1) / ny;
876
  const sycl::range<3> block_nums(1, 1, block_num_y);
877
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
878
  stream->parallel_for(
879
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
880
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
881
  dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
882
  });
883
  }
 
890
  const int ny = 2 / K_QUANTS_PER_ITERATION;
891
  const int block_num_y = (nrows + ny - 1) / ny;
892
  const sycl::range<3> block_nums(1, 1, block_num_y);
893
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
894
  stream->parallel_for(
895
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
896
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
897
  dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
898
  });
899
  }
 
906
  const int ny = 2 / K_QUANTS_PER_ITERATION;
907
  const int block_num_y = (nrows + ny - 1) / ny;
908
  const sycl::range<3> block_nums(1, 1, block_num_y);
909
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
910
  stream->parallel_for(
911
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
912
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
913
  dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
914
  });
915
  }
 
919
  const int nrows,
920
  dpct::queue_ptr stream) {
921
  GGML_ASSERT(ncols % QK_K == 0);
922
+ const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
923
  stream->parallel_for(
924
  sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
925
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
926
  dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
927
  });
928
  }
 
935
  const int ny = 2 / K_QUANTS_PER_ITERATION;
936
  const int block_num_y = (nrows + ny - 1) / ny;
937
  const sycl::range<3> block_nums(1, 1, block_num_y);
938
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
939
  stream->parallel_for(
940
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
941
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
942
  dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
943
  });
944
  }
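
Throughout dmmv.cpp the hard-coded sub-group size 32 becomes WARP_SIZE for the convert/dequantize paths and QK_WARP_SIZE for the k-quant paths, and the [[intel::reqd_sub_group_size(...)]] attributes follow suit. Whatever value those macros expand to has to be a sub-group size the device actually supports; a small stand-alone query (illustrative only, not part of the commit) lists them:

// Illustration only: enumerate the sub-group sizes the active SYCL device supports.
// A value used in [[intel::reqd_sub_group_size(N)]] must be one of these.
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
    sycl::queue q;
    const sycl::device dev = q.get_device();
    std::printf("device: %s\n", dev.get_info<sycl::info::device::name>().c_str());
    for (size_t s : dev.get_info<sycl::info::device::sub_group_sizes>()) {
        std::printf("  supported sub-group size: %zu\n", s);
    }
    return 0;
}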
ggml/src/ggml-sycl/dpct/helper.hpp CHANGED
@@ -255,7 +255,7 @@ namespace dpct
255
  void set_pitch(size_t pitch) { _pitch = pitch; }
256
 
257
  size_t get_x() { return _x; }
258
- void set_x(size_t x) { _x = x; };
259
 
260
  size_t get_y() { return _y; }
261
  void set_y(size_t y) { _y = y; }
@@ -1056,7 +1056,7 @@ namespace dpct
1056
  #error "Only support Windows and Linux."
1057
  #endif
1058
  next_free = mapped_address_space;
1059
- };
1060
 
1061
  public:
1062
  using buffer_id_t = int;
@@ -1077,7 +1077,7 @@ namespace dpct
1077
  #else
1078
  #error "Only support Windows and Linux."
1079
  #endif
1080
- };
1081
 
1082
  mem_mgr(const mem_mgr &) = delete;
1083
  mem_mgr &operator=(const mem_mgr &) = delete;
@@ -2426,6 +2426,7 @@ namespace dpct
2426
  b, ldb, beta, c, ldc, batch_size);
2427
  break;
2428
  }
 
2429
  case detail::get_type_combination_id(
2430
  library_data_t::real_int8, library_data_t::real_int8,
2431
  library_data_t::real_int32, library_data_t::real_int32):
@@ -2458,7 +2459,6 @@ namespace dpct
2458
  batch_size);
2459
  break;
2460
  }
2461
- #endif
2462
  case detail::get_type_combination_id(
2463
  library_data_t::real_half, library_data_t::real_half,
2464
  library_data_t::real_half, library_data_t::real_float):
@@ -2595,6 +2595,7 @@ namespace dpct
2595
  stride_c, batch_size);
2596
  break;
2597
  }
 
2598
  case detail::get_type_combination_id(
2599
  library_data_t::real_int8, library_data_t::real_int8,
2600
  library_data_t::real_int32, library_data_t::real_int32):
@@ -2623,7 +2624,6 @@ namespace dpct
2623
  beta, c, ldc, stride_c, batch_size);
2624
  break;
2625
  }
2626
- #endif
2627
  case detail::get_type_combination_id(
2628
  library_data_t::real_half, library_data_t::real_half,
2629
  library_data_t::real_half, library_data_t::real_float):
 
255
  void set_pitch(size_t pitch) { _pitch = pitch; }
256
 
257
  size_t get_x() { return _x; }
258
+ void set_x(size_t x) { _x = x; }
259
 
260
  size_t get_y() { return _y; }
261
  void set_y(size_t y) { _y = y; }
 
1056
  #error "Only support Windows and Linux."
1057
  #endif
1058
  next_free = mapped_address_space;
1059
+ }
1060
 
1061
  public:
1062
  using buffer_id_t = int;
 
1077
  #else
1078
  #error "Only support Windows and Linux."
1079
  #endif
1080
+ }
1081
 
1082
  mem_mgr(const mem_mgr &) = delete;
1083
  mem_mgr &operator=(const mem_mgr &) = delete;
 
2426
  b, ldb, beta, c, ldc, batch_size);
2427
  break;
2428
  }
2429
+ #endif
2430
  case detail::get_type_combination_id(
2431
  library_data_t::real_int8, library_data_t::real_int8,
2432
  library_data_t::real_int32, library_data_t::real_int32):
 
2459
  batch_size);
2460
  break;
2461
  }
 
2462
  case detail::get_type_combination_id(
2463
  library_data_t::real_half, library_data_t::real_half,
2464
  library_data_t::real_half, library_data_t::real_float):
 
2595
  stride_c, batch_size);
2596
  break;
2597
  }
2598
+ #endif
2599
  case detail::get_type_combination_id(
2600
  library_data_t::real_int8, library_data_t::real_int8,
2601
  library_data_t::real_int32, library_data_t::real_int32):
 
2624
  beta, c, ldc, stride_c, batch_size);
2625
  break;
2626
  }
 
2627
  case detail::get_type_combination_id(
2628
  library_data_t::real_half, library_data_t::real_half,
2629
  library_data_t::real_half, library_data_t::real_float):
ggml/src/ggml-sycl/mmvq.cpp CHANGED
@@ -37,7 +37,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
37
 
38
  // sum up partial sums and write back result
39
  #pragma unroll
40
- for (int mask = 16; mask > 0; mask >>= 1) {
41
  tmp +=
42
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
43
  }
@@ -85,7 +85,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
85
 
86
  // sum up partial sums and write back result
87
  #pragma unroll
88
- for (int mask = 16; mask > 0; mask >>= 1) {
89
  tmp +=
90
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
91
  }
@@ -133,7 +133,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
133
 
134
  // sum up partial sums and write back result
135
  #pragma unroll
136
- for (int mask = 16; mask > 0; mask >>= 1) {
137
  tmp +=
138
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
139
  }
@@ -181,7 +181,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
181
 
182
  // sum up partial sums and write back result
183
  #pragma unroll
184
- for (int mask = 16; mask > 0; mask >>= 1) {
185
  tmp +=
186
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
187
  }
@@ -229,7 +229,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
229
 
230
  // sum up partial sums and write back result
231
  #pragma unroll
232
- for (int mask = 16; mask > 0; mask >>= 1) {
233
  tmp +=
234
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
235
  }
@@ -277,7 +277,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
277
 
278
  // sum up partial sums and write back result
279
  #pragma unroll
280
- for (int mask = 16; mask > 0; mask >>= 1) {
281
  tmp +=
282
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
283
  }
@@ -325,7 +325,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
325
 
326
  // sum up partial sums and write back result
327
  #pragma unroll
328
- for (int mask = 16; mask > 0; mask >>= 1) {
329
  tmp +=
330
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
331
  }
@@ -373,7 +373,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
373
 
374
  // sum up partial sums and write back result
375
  #pragma unroll
376
- for (int mask = 16; mask > 0; mask >>= 1) {
377
  tmp +=
378
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
379
  }
@@ -421,7 +421,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
421
 
422
  // sum up partial sums and write back result
423
  #pragma unroll
424
- for (int mask = 16; mask > 0; mask >>= 1) {
425
  tmp +=
426
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
427
  }
@@ -470,7 +470,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
470
 
471
  // sum up partial sums and write back result
472
  #pragma unroll
473
- for (int mask = 16; mask > 0; mask >>= 1) {
474
  tmp +=
475
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
476
  }
@@ -495,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
495
  cgh.parallel_for(
496
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
497
  [=](sycl::nd_item<3> item_ct1)
498
- [[intel::reqd_sub_group_size(32)]] {
499
  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
500
  VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
501
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -519,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
519
  cgh.parallel_for(
520
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
521
  [=](sycl::nd_item<3> item_ct1)
522
- [[intel::reqd_sub_group_size(32)]] {
523
  mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
524
  VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
525
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -543,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
543
  cgh.parallel_for(
544
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
545
  [=](sycl::nd_item<3> item_ct1)
546
- [[intel::reqd_sub_group_size(32)]] {
547
  mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
548
  VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
549
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -567,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
567
  cgh.parallel_for(
568
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
569
  [=](sycl::nd_item<3> item_ct1)
570
- [[intel::reqd_sub_group_size(32)]] {
571
  mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
572
  VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
573
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -591,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
591
  cgh.parallel_for(
592
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
593
  [=](sycl::nd_item<3> item_ct1)
594
- [[intel::reqd_sub_group_size(32)]] {
595
  mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
596
  VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
597
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -615,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
615
  cgh.parallel_for(
616
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
617
  [=](sycl::nd_item<3> item_ct1)
618
- [[intel::reqd_sub_group_size(32)]] {
619
  mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
620
  VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
621
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -639,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
639
  cgh.parallel_for(
640
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
641
  [=](sycl::nd_item<3> item_ct1)
642
- [[intel::reqd_sub_group_size(32)]] {
643
  mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
644
  VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
645
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -663,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
663
  cgh.parallel_for(
664
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
665
  [=](sycl::nd_item<3> item_ct1)
666
- [[intel::reqd_sub_group_size(32)]] {
667
  mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
668
  VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
669
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -687,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
687
  cgh.parallel_for(
688
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
689
  [=](sycl::nd_item<3> item_ct1)
690
- [[intel::reqd_sub_group_size(32)]] {
691
  mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
692
  VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
693
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -711,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
711
  cgh.parallel_for(
712
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
713
  [=](sycl::nd_item<3> item_ct1)
714
- [[intel::reqd_sub_group_size(32)]] {
715
  mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
716
  VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
717
  vx, vy, dst, ncols, nrows, item_ct1);
@@ -734,8 +734,8 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
734
  cgh.parallel_for(
735
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
736
  [=](sycl::nd_item<3> item_ct1)
737
- [[intel::reqd_sub_group_size(32)]] {
738
- mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
739
  vx, vy, dst, ncols, nrows, item_ct1);
740
  });
741
  });
@@ -759,8 +759,8 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
759
  cgh.parallel_for(
760
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
761
  [=](sycl::nd_item<3> item_ct1)
762
- [[intel::reqd_sub_group_size(32)]] {
763
- mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
764
  vx, vy, dst, ncols, nrows, item_ct1);
765
  });
766
  });
@@ -784,8 +784,8 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
784
  cgh.parallel_for(
785
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
786
  [=](sycl::nd_item<3> item_ct1)
787
- [[intel::reqd_sub_group_size(32)]] {
788
- mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
789
  vx, vy, dst, ncols, nrows, item_ct1);
790
  });
791
  });
@@ -809,8 +809,8 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
809
  cgh.parallel_for(
810
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
811
  [=](sycl::nd_item<3> item_ct1)
812
- [[intel::reqd_sub_group_size(32)]] {
813
- mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
814
  vx, vy, dst, ncols, nrows, item_ct1);
815
  });
816
  });
@@ -833,8 +833,8 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
833
  cgh.parallel_for(
834
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
835
  [=](sycl::nd_item<3> item_ct1)
836
- [[intel::reqd_sub_group_size(32)]] {
837
- mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
838
  vx, vy, dst, ncols, nrows, item_ct1);
839
  });
840
  });
@@ -858,7 +858,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
858
  cgh.parallel_for(
859
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
860
  [=](sycl::nd_item<3> item_ct1)
861
- [[intel::reqd_sub_group_size(32)]] {
862
  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
863
  vx, vy, dst, ncols, nrows, item_ct1);
864
  });
@@ -879,7 +879,7 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
879
  cgh.parallel_for(
880
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
881
  [=](sycl::nd_item<3> item_ct1)
882
- [[intel::reqd_sub_group_size(32)]] {
883
  mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
884
  vx, vy, dst, ncols, nrows, item_ct1);
885
  });
@@ -901,7 +901,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
901
  cgh.parallel_for(
902
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
903
  [=](sycl::nd_item<3> item_ct1)
904
- [[intel::reqd_sub_group_size(32)]] {
905
  mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
906
  vx, vy, dst, ncols, nrows, item_ct1);
907
  });
@@ -923,8 +923,8 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
923
  cgh.parallel_for(
924
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
925
  [=](sycl::nd_item<3> item_ct1)
926
- [[intel::reqd_sub_group_size(32)]] {
927
- mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
928
  vx, vy, dst, ncols, nrows, item_ct1);
929
  });
930
  });
@@ -936,7 +936,7 @@ void ggml_sycl_op_mul_mat_vec_q(
936
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
937
  const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
938
  float *dst_dd_i, const int64_t row_low, const int64_t row_high,
939
- const int64_t src1_ncols, const int64_t src1_padded_row_size,
940
  const dpct::queue_ptr &stream) {
941
 
942
  const int64_t ne10 = src1->ne[0];
@@ -948,77 +948,80 @@ void ggml_sycl_op_mul_mat_vec_q(
948
  int id;
949
  SYCL_CHECK(
950
  CHECK_TRY_ERROR(id = get_current_device_id()));
951
-
 
952
  // the main device has a larger memory buffer to hold the results from all GPUs
953
  // nrows_dst == nrows of the matrix that the kernel writes into
954
  const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff;
955
-
956
- switch (src0->type) {
957
  case GGML_TYPE_Q4_0:
958
- mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
959
  break;
960
  case GGML_TYPE_Q4_1:
961
- mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
962
  break;
963
  case GGML_TYPE_Q5_0:
964
- mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
965
  break;
966
  case GGML_TYPE_Q5_1:
967
- mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
968
  break;
969
  case GGML_TYPE_Q8_0:
970
- mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
971
  break;
972
  case GGML_TYPE_Q2_K:
973
- mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
974
  break;
975
  case GGML_TYPE_Q3_K:
976
- mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
977
  break;
978
  case GGML_TYPE_Q4_K:
979
- mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
980
  break;
981
  case GGML_TYPE_Q5_K:
982
- mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
983
  break;
984
  case GGML_TYPE_Q6_K:
985
- mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
986
  break;
987
  case GGML_TYPE_IQ1_S:
988
- mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
989
  break;
990
  case GGML_TYPE_IQ1_M:
991
- mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
992
  break;
993
  case GGML_TYPE_IQ2_XXS:
994
- mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
995
  break;
996
  case GGML_TYPE_IQ2_XS:
997
- mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
998
  break;
999
  case GGML_TYPE_IQ2_S:
1000
- mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
1001
  break;
1002
  case GGML_TYPE_IQ3_XXS:
1003
- mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
1004
  break;
1005
  case GGML_TYPE_IQ3_S:
1006
- mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
1007
  break;
1008
  case GGML_TYPE_IQ4_NL:
1009
- mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
1010
  break;
1011
  case GGML_TYPE_IQ4_XS:
1012
- mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
1013
  break;
1014
  default:
1015
  GGML_ASSERT(false);
1016
  break;
 
1017
  }
1018
-
1019
  (void) src1;
1020
  (void) dst;
1021
  (void) src1_ddf_i;
1022
- (void) src1_ncols;
1023
- (void) src1_padded_row_size;
1024
  }
 
37
 
38
  // sum up partial sums and write back result
39
  #pragma unroll
40
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
41
  tmp +=
42
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
43
  }
 
85
 
86
  // sum up partial sums and write back result
87
  #pragma unroll
88
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
89
  tmp +=
90
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
91
  }
 
133
 
134
  // sum up partial sums and write back result
135
  #pragma unroll
136
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
137
  tmp +=
138
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
139
  }
 
181
 
182
  // sum up partial sums and write back result
183
  #pragma unroll
184
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
185
  tmp +=
186
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
187
  }
 
229
 
230
  // sum up partial sums and write back result
231
  #pragma unroll
232
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
233
  tmp +=
234
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
235
  }
 
277
 
278
  // sum up partial sums and write back result
279
  #pragma unroll
280
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
281
  tmp +=
282
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
283
  }
 
325
 
326
  // sum up partial sums and write back result
327
  #pragma unroll
328
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
329
  tmp +=
330
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
331
  }
 
373
 
374
  // sum up partial sums and write back result
375
  #pragma unroll
376
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
377
  tmp +=
378
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
379
  }
 
421
 
422
  // sum up partial sums and write back result
423
  #pragma unroll
424
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
425
  tmp +=
426
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
427
  }
 
470
 
471
  // sum up partial sums and write back result
472
  #pragma unroll
473
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
474
  tmp +=
475
  dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
476
  }
 
495
  cgh.parallel_for(
496
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
497
  [=](sycl::nd_item<3> item_ct1)
498
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
499
  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
500
  VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
501
  vx, vy, dst, ncols, nrows, item_ct1);
 
519
  cgh.parallel_for(
520
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
521
  [=](sycl::nd_item<3> item_ct1)
522
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
523
  mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
524
  VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
525
  vx, vy, dst, ncols, nrows, item_ct1);
 
543
  cgh.parallel_for(
544
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
545
  [=](sycl::nd_item<3> item_ct1)
546
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
547
  mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
548
  VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
549
  vx, vy, dst, ncols, nrows, item_ct1);
 
567
  cgh.parallel_for(
568
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
569
  [=](sycl::nd_item<3> item_ct1)
570
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
571
  mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
572
  VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
573
  vx, vy, dst, ncols, nrows, item_ct1);
 
591
  cgh.parallel_for(
592
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
593
  [=](sycl::nd_item<3> item_ct1)
594
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
595
  mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
596
  VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
597
  vx, vy, dst, ncols, nrows, item_ct1);
 
615
  cgh.parallel_for(
616
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
617
  [=](sycl::nd_item<3> item_ct1)
618
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
619
  mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
620
  VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
621
  vx, vy, dst, ncols, nrows, item_ct1);
 
639
  cgh.parallel_for(
640
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
641
  [=](sycl::nd_item<3> item_ct1)
642
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
643
  mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
644
  VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
645
  vx, vy, dst, ncols, nrows, item_ct1);
 
663
  cgh.parallel_for(
664
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
665
  [=](sycl::nd_item<3> item_ct1)
666
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
667
  mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
668
  VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
669
  vx, vy, dst, ncols, nrows, item_ct1);
 
687
  cgh.parallel_for(
688
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
689
  [=](sycl::nd_item<3> item_ct1)
690
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
691
  mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
692
  VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
693
  vx, vy, dst, ncols, nrows, item_ct1);
 
711
  cgh.parallel_for(
712
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
713
  [=](sycl::nd_item<3> item_ct1)
714
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
715
  mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
716
  VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
717
  vx, vy, dst, ncols, nrows, item_ct1);
 
734
  cgh.parallel_for(
735
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
736
  [=](sycl::nd_item<3> item_ct1)
737
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
738
+ mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
739
  vx, vy, dst, ncols, nrows, item_ct1);
740
  });
741
  });
 
759
  cgh.parallel_for(
760
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
761
  [=](sycl::nd_item<3> item_ct1)
762
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
763
+ mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
764
  vx, vy, dst, ncols, nrows, item_ct1);
765
  });
766
  });
 
784
  cgh.parallel_for(
785
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
786
  [=](sycl::nd_item<3> item_ct1)
787
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
788
+ mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
789
  vx, vy, dst, ncols, nrows, item_ct1);
790
  });
791
  });
 
809
  cgh.parallel_for(
810
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
811
  [=](sycl::nd_item<3> item_ct1)
812
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
813
+ mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
814
  vx, vy, dst, ncols, nrows, item_ct1);
815
  });
816
  });
 
833
  cgh.parallel_for(
834
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
835
  [=](sycl::nd_item<3> item_ct1)
836
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
837
+ mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
838
  vx, vy, dst, ncols, nrows, item_ct1);
839
  });
840
  });
 
858
  cgh.parallel_for(
859
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
860
  [=](sycl::nd_item<3> item_ct1)
861
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
862
  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
863
  vx, vy, dst, ncols, nrows, item_ct1);
864
  });
 
879
  cgh.parallel_for(
880
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
881
  [=](sycl::nd_item<3> item_ct1)
882
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
883
  mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
884
  vx, vy, dst, ncols, nrows, item_ct1);
885
  });
 
901
  cgh.parallel_for(
902
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
903
  [=](sycl::nd_item<3> item_ct1)
904
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
905
  mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
906
  vx, vy, dst, ncols, nrows, item_ct1);
907
  });
 
923
  cgh.parallel_for(
924
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
925
  [=](sycl::nd_item<3> item_ct1)
926
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
927
+ mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
928
  vx, vy, dst, ncols, nrows, item_ct1);
929
  });
930
  });
 
936
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
937
  const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
938
  float *dst_dd_i, const int64_t row_low, const int64_t row_high,
939
+ const int64_t src1_ncols, const int64_t src1_padded_col_size,
940
  const dpct::queue_ptr &stream) {
941
 
942
  const int64_t ne10 = src1->ne[0];
 
948
  int id;
949
  SYCL_CHECK(
950
  CHECK_TRY_ERROR(id = get_current_device_id()));
951
+ const size_t q8_1_ts = sizeof(block_q8_1);
952
+ const size_t q8_1_bs = QK8_1;
953
  // the main device has a larger memory buffer to hold the results from all GPUs
954
  // nrows_dst == nrows of the matrix that the kernel writes into
955
  const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff;
956
+ for (int i = 0; i < src1_ncols; i++)
957
+ {
958
+ const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
959
+ const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
960
+ float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
961
+ switch (src0->type) {
962
  case GGML_TYPE_Q4_0:
963
+ mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
964
  break;
965
  case GGML_TYPE_Q4_1:
966
+ mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
967
  break;
968
  case GGML_TYPE_Q5_0:
969
+ mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
970
  break;
971
  case GGML_TYPE_Q5_1:
972
+ mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
973
  break;
974
  case GGML_TYPE_Q8_0:
975
+ mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
976
  break;
977
  case GGML_TYPE_Q2_K:
978
+ mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
979
  break;
980
  case GGML_TYPE_Q3_K:
981
+ mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
982
  break;
983
  case GGML_TYPE_Q4_K:
984
+ mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
985
  break;
986
  case GGML_TYPE_Q5_K:
987
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
988
  break;
989
  case GGML_TYPE_Q6_K:
990
+ mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
991
  break;
992
  case GGML_TYPE_IQ1_S:
993
+ mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
994
  break;
995
  case GGML_TYPE_IQ1_M:
996
+ mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
997
  break;
998
  case GGML_TYPE_IQ2_XXS:
999
+ mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1000
  break;
1001
  case GGML_TYPE_IQ2_XS:
1002
+ mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1003
  break;
1004
  case GGML_TYPE_IQ2_S:
1005
+ mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1006
  break;
1007
  case GGML_TYPE_IQ3_XXS:
1008
+ mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1009
  break;
1010
  case GGML_TYPE_IQ3_S:
1011
+ mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1012
  break;
1013
  case GGML_TYPE_IQ4_NL:
1014
+ mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1015
  break;
1016
  case GGML_TYPE_IQ4_XS:
1017
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
1018
  break;
1019
  default:
1020
  GGML_ASSERT(false);
1021
  break;
1022
+ }
1023
  }
 
1024
  (void) src1;
1025
  (void) dst;
1026
  (void) src1_ddf_i;
 
 
1027
  }
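
ggml_sycl_op_mul_mat_vec_q now iterates over the src1 columns and offsets into the quantized buffer by whole q8_1 blocks: each column spans src1_padded_col_size / QK8_1 blocks of sizeof(block_q8_1) bytes. A stand-alone check of that offset arithmetic (the padded column length is an arbitrary example; the 36/32 block parameters mirror ggml's q8_1 layout and should be treated as assumptions here):

// Illustration only: reproduces the per-column offset computation used by the new
// loop in ggml_sycl_op_mul_mat_vec_q. Block parameters are assumed, not queried.
#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t q8_1_ts = 36; // sizeof(block_q8_1): 32 int8 quants + 2 * fp16 (assumed)
    const std::size_t q8_1_bs = 32; // QK8_1: quants per block (assumed)

    const std::size_t src1_padded_col_size = 4096; // example padded column length
    for (int i = 0; i < 4; ++i) {
        const std::size_t offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
        std::printf("column %d -> byte offset %zu\n", i, offset);
    }
    return 0;
}

Each mul_mat_vec_*_sycl launcher is then called once per column with the shifted src1 pointer and a dst pointer advanced by dst->ne[0], instead of once for the whole batch.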
ggml/src/ggml-sycl/norm.cpp ADDED
@@ -0,0 +1,374 @@
1
+ #include "norm.hpp"
2
+
3
+ static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
4
+ const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
5
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
6
+ item_ct1.get_local_id(1);
7
+ const int tid = item_ct1.get_local_id(2);
8
+
9
+ const int nthreads = item_ct1.get_local_range(2);
10
+ const int nwarps = nthreads / WARP_SIZE;
11
+ assert(nwarps % WARP_SIZE == 0);
12
+ sycl::float2 mean_var = sycl::float2(0.f, 0.f);
13
+
14
+ for (int col = tid; col < ncols; col += block_size) {
15
+ const float xi = x[row * ncols + col];
16
+ mean_var.x() += xi;
17
+ mean_var.y() += xi * xi;
18
+ }
19
+
20
+ // sum up partial sums
21
+ mean_var = warp_reduce_sum(mean_var, item_ct1);
22
+ if (block_size > WARP_SIZE) {
23
+
24
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
25
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
26
+ if (lane_id == 0) {
27
+ s_sum[warp_id] = mean_var;
28
+ }
29
+ /*
30
+ DPCT1118:0: SYCL group functions and algorithms must be encountered in
31
+ converged control flow. You may need to adjust the code.
32
+ */
33
+ item_ct1.barrier(sycl::access::fence_space::local_space);
34
+ mean_var = 0.f;
35
+ int nreduce = nwarps / WARP_SIZE;
36
+ for (size_t i = 0; i < nreduce; i += 1)
37
+ {
38
+ mean_var += s_sum[lane_id + i * WARP_SIZE];
39
+ }
40
+ mean_var = warp_reduce_sum(mean_var, item_ct1);
41
+ }
42
+
43
+ const float mean = mean_var.x() / ncols;
44
+ const float var = mean_var.y() / ncols - mean * mean;
45
+ const float inv_std = sycl::rsqrt(var + eps);
46
+
47
+ for (int col = tid; col < ncols; col += block_size) {
48
+ dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
49
+ }
50
+ }
51
+
52
+ static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
53
+ const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
54
+ int start = item_ct1.get_group(2) * group_size;
55
+ int end = start + group_size;
56
+ const int nthreads = item_ct1.get_local_range(2);
57
+ const int nwarps = nthreads / WARP_SIZE;
58
+ assert(nwarps % WARP_SIZE == 0);
59
+ start += item_ct1.get_local_id(2);
60
+ int nreduce = nwarps / WARP_SIZE;
61
+
62
+ if (end >= ne_elements) {
63
+ end = ne_elements;
64
+ }
65
+
66
+ float tmp = 0.0f; // partial sum for thread in warp
67
+
68
+ for (int j = start; j < end; j += block_size) {
69
+ tmp += x[j];
70
+ }
71
+
72
+ tmp = warp_reduce_sum(tmp, item_ct1);
73
+ if (block_size > WARP_SIZE) {
74
+
75
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
76
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
77
+ if (lane_id == 0) {
78
+ s_sum[warp_id] = tmp;
79
+ }
80
+ /*
81
+ DPCT1118:1: SYCL group functions and algorithms must be encountered in
82
+ converged control flow. You may need to adjust the code.
83
+ */
84
+ /*
85
+ DPCT1065:54: Consider replacing sycl::nd_item::barrier() with
86
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
87
+ better performance if there is no access to global memory.
88
+ */
89
+ item_ct1.barrier();
90
+ tmp = 0.f;
91
+ for (size_t i = 0; i < nreduce; i += 1)
92
+ {
93
+ tmp += s_sum[lane_id + i * WARP_SIZE];
94
+ }
95
+ tmp = warp_reduce_sum(tmp, item_ct1);
96
+ }
97
+
98
+ float mean = tmp / group_size;
99
+ tmp = 0.0f;
100
+
101
+ for (int j = start; j < end; j += block_size) {
102
+ float xi = x[j] - mean;
103
+ dst[j] = xi;
104
+ tmp += xi * xi;
105
+ }
106
+
107
+ tmp = warp_reduce_sum(tmp, item_ct1);
108
+ if (block_size > WARP_SIZE) {
109
+
110
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
111
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
112
+ if (lane_id == 0) {
113
+ s_sum[warp_id] = tmp;
114
+ }
115
+ /*
116
+ DPCT1118:2: SYCL group functions and algorithms must be encountered in
117
+ converged control flow. You may need to adjust the code.
118
+ */
119
+ /*
120
+ DPCT1065:55: Consider replacing sycl::nd_item::barrier() with
121
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
122
+ better performance if there is no access to global memory.
123
+ */
124
+ item_ct1.barrier();
125
+ tmp = 0.f;
126
+ for (size_t i = 0; i < nreduce; i += 1)
127
+ {
128
+ tmp += s_sum[lane_id + i * WARP_SIZE];
129
+ }
130
+ tmp = warp_reduce_sum(tmp, item_ct1);
131
+ }
132
+
133
+ float variance = tmp / group_size;
134
+ float scale = sycl::rsqrt(variance + eps);
135
+ for (int j = start; j < end; j += block_size) {
136
+ dst[j] *= scale;
137
+ }
138
+ }
139
+
140
+ static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
141
+ const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
142
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
143
+ item_ct1.get_local_id(1);
144
+ const int tid = item_ct1.get_local_id(2);
145
+ const int nthreads = item_ct1.get_local_range(2);
146
+ const int nwarps = nthreads / WARP_SIZE;
147
+ assert(nwarps % WARP_SIZE == 0);
148
+ float tmp = 0.0f; // partial sum for thread in warp
149
+
150
+ for (int col = tid; col < ncols; col += block_size) {
151
+ const float xi = x[row * ncols + col];
152
+ tmp += xi * xi;
153
+ }
154
+
155
+ // sum up partial sums
156
+ tmp = warp_reduce_sum(tmp, item_ct1);
157
+ if (block_size > WARP_SIZE) {
158
+
159
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
160
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
161
+ if (lane_id == 0) {
162
+ s_sum[warp_id] = tmp;
163
+ }
164
+ /*
165
+ DPCT1118:3: SYCL group functions and algorithms must be encountered in
166
+ converged control flow. You may need to adjust the code.
167
+ */
168
+ item_ct1.barrier(sycl::access::fence_space::local_space);
169
+ int nreduce = nwarps / WARP_SIZE;
170
+ tmp = 0.f;
171
+ for (size_t i = 0; i < nreduce; i += 1)
172
+ {
173
+ tmp += s_sum[lane_id + i * WARP_SIZE];
174
+ }
175
+ tmp = warp_reduce_sum(tmp, item_ct1);
176
+ }
177
+
178
+ const float mean = tmp / ncols;
179
+ const float scale = sycl::rsqrt(mean + eps);
180
+
181
+ for (int col = tid; col < ncols; col += block_size) {
182
+ dst[row * ncols + col] = scale * x[row * ncols + col];
183
+ }
184
+ }
185
+
186
+ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
187
+ const int nrows, const float eps,
188
+ queue_ptr stream, int device) {
189
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
190
+ if (ncols < 1024) {
191
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
192
+ stream->submit([&](sycl::handler& cgh) {
193
+ cgh.parallel_for(
194
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
195
+ block_dims),
196
+ [=](sycl::nd_item<3> item_ct1)
197
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
198
+ norm_f32(x, dst, ncols, eps, item_ct1,
199
+ nullptr, WARP_SIZE);
200
+ });
201
+ });
202
+ }
203
+ else {
204
+ const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
205
+ const sycl::range<3> block_dims(1, 1, work_group_size);
206
+ /*
207
+ DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
208
+ the limit. To get the device limit, query
209
+ info::device::max_work_group_size. Adjust the work-group size if needed.
210
+ */
211
+ stream->submit([&](sycl::handler& cgh) {
212
+ sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
213
+ sycl::range<1>(work_group_size / WARP_SIZE), cgh);
214
+
215
+ cgh.parallel_for(
216
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
217
+ block_dims),
218
+ [=](sycl::nd_item<3> item_ct1)
219
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
220
+ norm_f32(x, dst, ncols, eps, item_ct1,
221
+ s_sum_acc_ct1.get_pointer(), work_group_size);
222
+ });
223
+ });
224
+ }
225
+ }
226
+
227
+ static void group_norm_f32_sycl(const float* x, float* dst,
228
+ const int num_groups, const int group_size,
229
+ const int ne_elements, queue_ptr stream, int device) {
230
+ static const float eps = 1e-6f;
231
+ if (group_size < 1024) {
232
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
233
+ stream->submit([&](sycl::handler& cgh) {
234
+ const float eps_ct4 = eps;
235
+ cgh.parallel_for(
236
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
237
+ block_dims),
238
+ [=](sycl::nd_item<3> item_ct1)
239
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
240
+ group_norm_f32(
241
+ x, dst, group_size, ne_elements, eps_ct4, item_ct1,
242
+ nullptr, WARP_SIZE);
243
+ });
244
+ });
245
+ }
246
+ else {
247
+ const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
248
+ const sycl::range<3> block_dims(1, 1, work_group_size);
249
+ /*
250
+ DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
251
+ the limit. To get the device limit, query
252
+ info::device::max_work_group_size. Adjust the work-group size if needed.
253
+ */
254
+
255
+ stream->submit([&](sycl::handler& cgh) {
256
+ sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
257
+ cgh);
258
+
259
+ const float eps_ct4 = eps;
260
+
261
+ cgh.parallel_for(
262
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
263
+ block_dims),
264
+ [=](sycl::nd_item<3> item_ct1)
265
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
266
+ group_norm_f32(x, dst, group_size, ne_elements,
267
+ eps_ct4, item_ct1,
268
+ s_sum_acc_ct1.get_pointer(), work_group_size);
269
+ });
270
+ });
271
+ }
272
+ }
273
+
274
+ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
275
+ const int nrows, const float eps,
276
+ queue_ptr stream, int device) {
277
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
278
+ // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
279
+ if (ncols < 1024) {
280
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
281
+ stream->submit([&](sycl::handler& cgh) {
282
+ cgh.parallel_for(
283
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
284
+ block_dims),
285
+ [=](sycl::nd_item<3> item_ct1)
286
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
287
+ rms_norm_f32(x, dst, ncols, eps, item_ct1,
288
+ nullptr, WARP_SIZE);
289
+ });
290
+ });
291
+ }
292
+ else {
293
+ const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
294
+ const sycl::range<3> block_dims(1, 1, work_group_size);
295
+ /*
296
+ DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
297
+ the limit. To get the device limit, query
298
+ info::device::max_work_group_size. Adjust the work-group size if needed.
299
+ */
300
+ stream->submit([&](sycl::handler& cgh) {
301
+ sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
302
+ cgh);
303
+ cgh.parallel_for(
304
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
305
+ block_dims),
306
+ [=](sycl::nd_item<3> item_ct1)
307
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
308
+ rms_norm_f32(x, dst, ncols, eps, item_ct1,
309
+ s_sum_acc_ct1.get_pointer(), work_group_size);
310
+ });
311
+ });
312
+ }
313
+ }
314
+
315
+ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
316
+ ggml_tensor* dst, const float* src0_dd,
317
+ const float* src1_dd, float* dst_dd,
318
+ const queue_ptr& main_stream) {
319
+
320
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
321
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
322
+
323
+ const int64_t ne00 = src0->ne[0];
324
+ const int64_t nrows = ggml_nrows(src0);
325
+
326
+ float eps;
327
+ memcpy(&eps, dst->op_params, sizeof(float));
328
+
329
+ norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
330
+
331
+ (void)src1;
332
+ (void)dst;
333
+ (void)src1_dd;
334
+ }
335
+
336
+ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
337
+ const ggml_tensor* src1, ggml_tensor* dst,
338
+ const float* src0_dd, const float* src1_dd,
339
+ float* dst_dd,
340
+ const queue_ptr& main_stream) {
341
+
342
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
343
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
344
+
345
+ int num_groups = dst->op_params[0];
346
+ int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
347
+ group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
348
+
349
+ (void)src1;
350
+ (void)dst;
351
+ (void)src1_dd;
352
+ }
353
+
354
+ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
355
+ const ggml_tensor* src1, ggml_tensor* dst,
356
+ const float* src0_dd, const float* src1_dd,
357
+ float* dst_dd,
358
+ const queue_ptr& main_stream) {
359
+
360
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
361
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
362
+
363
+ const int64_t ne00 = src0->ne[0];
364
+ const int64_t nrows = ggml_nrows(src0);
365
+
366
+ float eps;
367
+ memcpy(&eps, dst->op_params, sizeof(float));
368
+
369
+ rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
370
+
371
+ (void)src1;
372
+ (void)dst;
373
+ (void)src1_dd;
374
+ }
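For readers of this hunk: rms_norm_f32 scales each row by 1/sqrt(mean(x^2) + eps), while group_norm_f32 first shifts each group to zero mean and then scales by 1/sqrt(variance + eps); the warp and work-group reductions above only parallelize those sums. A plain CPU sketch of the same per-row / per-group math, with illustrative helper names that are not part of this commit:

    #include <cmath>
    #include <cstddef>

    // Scalar reference for rms_norm_f32: scale one row by 1/sqrt(mean(x^2) + eps).
    static void rms_norm_row_ref(const float * x, float * dst, size_t ncols, float eps) {
        float sum_sq = 0.0f;
        for (size_t i = 0; i < ncols; ++i) {
            sum_sq += x[i] * x[i];
        }
        const float scale = 1.0f / std::sqrt(sum_sq / ncols + eps);
        for (size_t i = 0; i < ncols; ++i) {
            dst[i] = scale * x[i];
        }
    }

    // Scalar reference for group_norm_f32: zero-mean, unit-variance normalization of one group.
    static void group_norm_group_ref(const float * x, float * dst, size_t group_size, float eps) {
        float mean = 0.0f;
        for (size_t i = 0; i < group_size; ++i) {
            mean += x[i];
        }
        mean /= group_size;

        float var = 0.0f;
        for (size_t i = 0; i < group_size; ++i) {
            const float xi = x[i] - mean;
            dst[i] = xi;       // store the centered value, as the kernel does
            var += xi * xi;
        }
        var /= group_size;

        const float scale = 1.0f / std::sqrt(var + eps);
        for (size_t i = 0; i < group_size; ++i) {
            dst[i] *= scale;
        }
    }

The SYCL kernels split each row or group across a work-group, combine partial sums with warp_reduce_sum, and fall back to the s_sum scratch buffer only when the block is wider than a single sub-group.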
ggml/src/ggml-sycl/norm.hpp ADDED
@@ -0,0 +1,35 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_NORM_HPP
14
+ #define GGML_SYCL_NORM_HPP
15
+
16
+ #include "common.hpp"
17
+
18
+ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
19
+ ggml_tensor* dst, const float* src0_dd,
20
+ const float* src1_dd, float* dst_dd,
21
+ const queue_ptr& main_stream);
22
+
23
+ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
24
+ const ggml_tensor* src1, ggml_tensor* dst,
25
+ const float* src0_dd, const float* src1_dd,
26
+ float* dst_dd,
27
+ const queue_ptr& main_stream);
28
+
29
+ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
30
+ const ggml_tensor* src1, ggml_tensor* dst,
31
+ const float* src0_dd, const float* src1_dd,
32
+ float* dst_dd,
33
+ const queue_ptr& main_stream);
34
+
35
+ #endif // GGML_SYCL_NORM_HPP
ggml/src/ggml-sycl/presets.hpp CHANGED
@@ -15,10 +15,8 @@
15
 
16
  #define GGML_SYCL_MAX_STREAMS 8
17
  #define GGML_SYCL_MAX_BUFFERS 256
18
- #define GGML_SYCL_MAX_DEVICES 48
19
- #define GGML_SYCL_NAME "SYCL"
20
 
21
- #define WARP_SIZE 32
22
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
23
 
24
  #define SYCL_GELU_BLOCK_SIZE 256
@@ -64,4 +62,5 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
64
 
65
  #define MUL_MAT_SRC1_COL_STRIDE 128
66
 
 
67
  #endif // GGML_SYCL_PRESETS_HPP
 
15
 
16
  #define GGML_SYCL_MAX_STREAMS 8
17
  #define GGML_SYCL_MAX_BUFFERS 256
 
 
18
 
19
+ #define WARP_SIZE GGML_SYCL_WARP_SIZE
20
  #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
21
 
22
  #define SYCL_GELU_BLOCK_SIZE 256
 
62
 
63
  #define MUL_MAT_SRC1_COL_STRIDE 128
64
 
65
+ #define QK_WARP_SIZE 32
66
  #endif // GGML_SYCL_PRESETS_HPP
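With WARP_SIZE now aliasing GGML_SYCL_WARP_SIZE while QK_WARP_SIZE stays fixed at 32 for the quantized kernels, the value used in [[intel::reqd_sub_group_size(WARP_SIZE)]] has to be a sub-group size the target device actually supports. A small, hypothetical helper for that check (not part of this commit), using the standard SYCL 2020 device query:

    #include <sycl/sycl.hpp>

    #include <algorithm>
    #include <vector>

    // Hypothetical check: can this device run kernels that request
    // [[intel::reqd_sub_group_size(sg_size)]] ?
    static bool device_supports_sub_group_size(const sycl::device & dev, size_t sg_size) {
        const std::vector<size_t> sizes = dev.get_info<sycl::info::device::sub_group_sizes>();
        return std::find(sizes.begin(), sizes.end(), sg_size) != sizes.end();
    }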
ggml/src/ggml-sycl/rope.cpp ADDED
@@ -0,0 +1,275 @@
1
+ #include "rope.hpp"
2
+
3
+ struct rope_corr_dims {
4
+ float v[2];
5
+ };
6
+
7
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
8
+ const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
9
+ return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
10
+ }
11
+
12
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
13
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
14
+ static void rope_yarn(
15
+ float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
16
+ float * cos_theta, float * sin_theta) {
17
+ // Get n-d rotational scaling corrected for extrapolation
18
+ float theta_interp = freq_scale * theta_extrap;
19
+ float theta = theta_interp;
20
+ if (ext_factor != 0.0f) {
21
+ float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
22
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
23
+
24
+ // Get n-d magnitude scaling corrected for interpolation
25
+ mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
26
+ }
27
+ *cos_theta = sycl::cos(theta) * mscale;
28
+ *sin_theta = sycl::sin(theta) * mscale;
29
+ }
30
+
31
+ template<typename T, bool has_ff>
32
+ static void rope_norm(
33
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
34
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
35
+ const sycl::nd_item<3> &item_ct1) {
36
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
37
+ item_ct1.get_local_id(1));
38
+
39
+ if (i0 >= ne0) {
40
+ return;
41
+ }
42
+
43
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
44
+ item_ct1.get_local_id(2);
45
+
46
+ if (i0 >= n_dims) {
47
+ const int i = row*ne0 + i0;
48
+
49
+ dst[i + 0] = x[i + 0];
50
+ dst[i + 1] = x[i + 1];
51
+
52
+ return;
53
+ }
54
+
55
+ const int i = row*ne0 + i0;
56
+ const int i2 = row/p_delta_rows;
57
+
58
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
59
+
60
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
61
+
62
+ float cos_theta;
63
+ float sin_theta;
64
+
65
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
66
+
67
+ const float x0 = x[i + 0];
68
+ const float x1 = x[i + 1];
69
+
70
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
71
+ dst[i + 1] = x0*sin_theta + x1*cos_theta;
72
+ }
73
+
74
+ template<typename T, bool has_ff>
75
+ static void rope_neox(
76
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
77
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
78
+ const sycl::nd_item<3> &item_ct1) {
79
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
80
+ item_ct1.get_local_id(1));
81
+
82
+ if (i0 >= ne0) {
83
+ return;
84
+ }
85
+
86
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
87
+ item_ct1.get_local_id(2);
88
+
89
+ if (i0 >= n_dims) {
90
+ const int i = row*ne0 + i0;
91
+
92
+ dst[i + 0] = x[i + 0];
93
+ dst[i + 1] = x[i + 1];
94
+
95
+ return;
96
+ }
97
+
98
+ const int i = row*ne0 + i0/2;
99
+ const int i2 = row/p_delta_rows;
100
+
101
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
102
+
103
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
104
+
105
+ float cos_theta;
106
+ float sin_theta;
107
+
108
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
109
+
110
+ const float x0 = x[i + 0];
111
+ const float x1 = x[i + n_dims/2];
112
+
113
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
114
+ dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
115
+ }
116
+
117
+ template <typename T>
118
+ static void rope_norm_sycl(
119
+ const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
120
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
121
+ GGML_ASSERT(ne0 % 2 == 0);
122
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
123
+ const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
124
+ const sycl::range<3> block_nums(1, num_blocks_x, nr);
125
+
126
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
127
+
128
+ dpct::has_capability_or_fail(stream->get_device(),
129
+ {sycl::aspect::fp16});
130
+
131
+ if (freq_factors == nullptr) {
132
+ /*
133
+ DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
134
+ the limit. To get the device limit, query
135
+ info::device::max_work_group_size. Adjust the work-group size if needed.
136
+ */
137
+ stream->parallel_for(
138
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
139
+ [=](sycl::nd_item<3> item_ct1) {
140
+ rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
141
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
142
+ item_ct1);
143
+ });
144
+ } else {
145
+ /*
146
+ DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
147
+ the limit. To get the device limit, query
148
+ info::device::max_work_group_size. Adjust the work-group size if needed.
149
+ */
150
+ stream->parallel_for(
151
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
152
+ [=](sycl::nd_item<3> item_ct1) {
153
+ rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
154
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
155
+ item_ct1);
156
+ });
157
+ }
158
+ }
159
+
160
+ template <typename T>
161
+ static void rope_neox_sycl(
162
+ const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
163
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
164
+ GGML_ASSERT(ne0 % 2 == 0);
165
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
166
+ const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
167
+ const sycl::range<3> block_nums(1, num_blocks_x, nr);
168
+
169
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
170
+
171
+ dpct::has_capability_or_fail(stream->get_device(),
172
+ {sycl::aspect::fp16});
173
+
174
+ if (freq_factors == nullptr) {
175
+ stream->parallel_for(
176
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
177
+ [=](sycl::nd_item<3> item_ct1) {
178
+ rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
179
+ p_delta_rows, ext_factor, attn_factor,
180
+ corr_dims, theta_scale, freq_factors,
181
+ item_ct1);
182
+ });
183
+ } else {
184
+ stream->parallel_for(
185
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
186
+ [=](sycl::nd_item<3> item_ct1) {
187
+ rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
188
+ p_delta_rows, ext_factor, attn_factor,
189
+ corr_dims, theta_scale, freq_factors,
190
+ item_ct1);
191
+ });
192
+ }
193
+ }
194
+
195
+ void ggml_sycl_op_rope(
196
+ ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
197
+ const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
198
+ const ggml_tensor * src2 = dst->src[2];
199
+
200
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
201
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
202
+ GGML_ASSERT(src0->type == dst->type);
203
+
204
+ const int64_t ne00 = src0->ne[0];
205
+ const int64_t ne01 = src0->ne[1];
206
+ const int64_t nr = ggml_nrows(src0);
207
+
208
+ //const int n_past = ((int32_t *) dst->op_params)[0];
209
+ const int n_dims = ((int32_t *) dst->op_params)[1];
210
+ const int mode = ((int32_t *) dst->op_params)[2];
211
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
212
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
213
+
214
+ // RoPE alteration for extended context
215
+ float freq_base;
216
+ float freq_scale;
217
+ float ext_factor;
218
+ float attn_factor;
219
+ float beta_fast;
220
+ float beta_slow;
221
+
222
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
223
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
224
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
225
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
226
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
227
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
228
+
229
+ const bool is_neox = mode & 2;
230
+
231
+ const int32_t * pos = (const int32_t *) src1_dd;
232
+
233
+ const float * freq_factors = nullptr;
234
+ if (src2 != nullptr) {
235
+ freq_factors = (const float *) src2->data;
236
+ }
237
+
238
+ rope_corr_dims corr_dims;
239
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
240
+
241
+ // compute
242
+ if (is_neox) {
243
+ if (src0->type == GGML_TYPE_F32) {
244
+ rope_neox_sycl(
245
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
246
+ attn_factor, corr_dims, freq_factors, main_stream
247
+ );
248
+ } else if (src0->type == GGML_TYPE_F16) {
249
+ rope_neox_sycl(
250
+ (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
251
+ attn_factor, corr_dims, freq_factors, main_stream
252
+ );
253
+ } else {
254
+ GGML_ASSERT(false);
255
+ }
256
+ } else {
257
+ if (src0->type == GGML_TYPE_F32) {
258
+ rope_norm_sycl(
259
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
260
+ attn_factor, corr_dims, freq_factors, main_stream
261
+ );
262
+ } else if (src0->type == GGML_TYPE_F16) {
263
+ rope_norm_sycl(
264
+ (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
265
+ attn_factor, corr_dims, freq_factors, main_stream
266
+ );
267
+ } else {
268
+ GGML_ASSERT(false);
269
+ }
270
+ }
271
+
272
+ (void) src1;
273
+ (void) dst;
274
+ (void) src1_dd;
275
+ }
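For reference, rope_norm rotates consecutive value pairs (x[2k], x[2k+1]) within the first n_dims entries of a row by an angle that shrinks geometrically with k, while rope_neox pairs x[k] with x[k + n_dims/2]. Below is a plain C++ sketch of the non-YaRN path of rope_norm for one row (ext_factor == 0, no freq_factors); the helper name is illustrative and not part of this commit:

    #include <cmath>

    // Reference rotation for one row, matching rope_norm with ext_factor == 0 and no freq_factors.
    // theta_scale is powf(freq_base, -2.0f / n_dims), exactly as computed in rope_norm_sycl.
    static void rope_norm_row_ref(const float * x, float * dst, int ne0, int n_dims, int pos,
                                  float freq_scale, float attn_factor, float theta_scale) {
        for (int i0 = 0; i0 < ne0; i0 += 2) {
            if (i0 >= n_dims) {
                // dimensions past n_dims pass through unchanged
                dst[i0 + 0] = x[i0 + 0];
                dst[i0 + 1] = x[i0 + 1];
                continue;
            }
            const float theta = freq_scale * pos * std::pow(theta_scale, i0 / 2.0f);
            const float cos_t = std::cos(theta) * attn_factor;
            const float sin_t = std::sin(theta) * attn_factor;

            const float x0 = x[i0 + 0];
            const float x1 = x[i0 + 1];
            dst[i0 + 0] = x0 * cos_t - x1 * sin_t;
            dst[i0 + 1] = x0 * sin_t + x1 * cos_t;
        }
    }

When ext_factor is non-zero, rope_yarn additionally blends the interpolated and extrapolated angles via rope_yarn_ramp and rescales the magnitude, as in the kernel above.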
ggml/src/ggml-sycl/rope.hpp ADDED
@@ -0,0 +1,22 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_ROPE_HPP
14
+ #define GGML_SYCL_ROPE_HPP
15
+
16
+ #include "common.hpp"
17
+
18
+ void ggml_sycl_op_rope(
19
+ ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
20
+ const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
21
+
22
+ #endif // GGML_SYCL_ROPE_HPP
ggml/src/ggml-sycl/softmax.cpp ADDED
@@ -0,0 +1,250 @@
1
+ #include "norm.hpp"
2
+
3
+ template <bool vals_smem, int ncols_template, int block_size_template>
4
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
5
+ const int nrows_y, const float scale, const float max_bias, const float m0,
6
+ const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
7
+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
8
+
9
+ const int tid = item_ct1.get_local_id(2);
10
+ const int rowx = item_ct1.get_group(2);
11
+ const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
12
+
13
+ const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
14
+
15
+ const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
16
+ const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
17
+ const int nthreads = block_size;
18
+ const int nwarps = nthreads / WARP_SIZE;
19
+ int nreduce = nwarps / WARP_SIZE;
20
+ float slope = 1.0f;
21
+
22
+ // ALiBi
23
+ if (max_bias > 0.0f) {
24
+ const uint32_t h = rowx/nrows_y; // head index
25
+
26
+ const float base = h < n_head_log2 ? m0 : m1;
27
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
28
+
29
+ slope = sycl::pow(base, float(exp));
30
+ }
31
+
32
+ float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
33
+ float max_val = -INFINITY;
34
+
35
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
36
+ const int col = col0 + tid;
37
+
38
+ if (ncols_template == 0 && col >= ncols) {
39
+ break;
40
+ }
41
+
42
+ const int ix = rowx*ncols + col;
43
+ const int iy = rowy*ncols + col;
44
+
45
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
46
+
47
+ vals[col] = val;
48
+ max_val = sycl::max(max_val, val);
49
+ }
50
+
51
+ // find the max value in the block
52
+ max_val = warp_reduce_max(max_val, item_ct1);
53
+ if (block_size > WARP_SIZE) {
54
+ if (warp_id == 0) {
55
+ buf[lane_id] = -INFINITY;
56
+ for (size_t i = 1; i < nreduce; i += 1)
57
+ buf[lane_id + i * WARP_SIZE] = -INFINITY;
58
+ }
59
+ item_ct1.barrier(sycl::access::fence_space::local_space);
60
+
61
+ if (lane_id == 0) {
62
+ buf[warp_id] = max_val;
63
+ }
64
+ item_ct1.barrier(sycl::access::fence_space::local_space);
65
+ max_val = buf[lane_id];
66
+ for (size_t i = 1; i < nreduce; i += 1)
67
+ {
68
+ max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
69
+ }
70
+ max_val = warp_reduce_max(max_val, item_ct1);
71
+ }
72
+
73
+ float tmp = 0.f;
74
+ #pragma unroll
75
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
76
+ const int col = col0 + tid;
77
+ if (ncols_template == 0 && col >= ncols) {
78
+ break;
79
+ }
80
+
81
+ const float val = sycl::native::exp(vals[col] - max_val);
82
+ tmp += val;
83
+ vals[col] = val;
84
+ }
85
+
86
+ // find the sum of exps in the block
87
+ tmp = warp_reduce_sum(tmp, item_ct1);
88
+ if (block_size > WARP_SIZE) {
89
+ item_ct1.barrier(sycl::access::fence_space::local_space);
90
+ if (warp_id == 0) {
91
+ buf[lane_id] = 0.f;
92
+ for (size_t i = 1; i < nreduce; i += 1)
93
+ buf[lane_id + i * WARP_SIZE] = 0.f;
94
+ }
95
+ item_ct1.barrier(sycl::access::fence_space::local_space);
96
+
97
+ if (lane_id == 0) {
98
+ buf[warp_id] = tmp;
99
+ }
100
+ item_ct1.barrier(sycl::access::fence_space::local_space);
101
+
102
+ tmp = buf[lane_id];
103
+ for (size_t i = 1; i < nreduce; i += 1)
104
+ {
105
+ tmp += buf[lane_id + i * WARP_SIZE];
106
+ }
107
+ tmp = warp_reduce_sum(tmp, item_ct1);
108
+ }
109
+
110
+ const float inv_sum = 1.f / tmp;
111
+
112
+ #pragma unroll
113
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
114
+ const int col = col0 + tid;
115
+
116
+ if (ncols_template == 0 && col >= ncols) {
117
+ return;
118
+ }
119
+
120
+ const int idst = rowx*ncols + col;
121
+ dst[idst] = vals[col] * inv_sum;
122
+ }
123
+ }
124
+
125
+ template <bool vals_smem, int ncols_template, int block_size_template>
126
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
127
+ const int nrows_y, const float scale, const float max_bias, const float m0,
128
+ const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
129
+ const size_t n_local_scratch, queue_ptr stream) {
130
+ stream->submit([&](sycl::handler &cgh) {
131
+ sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
132
+
133
+ cgh.parallel_for(
134
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
135
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
136
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
137
+ nrows_y, scale, max_bias, m0,
138
+ m1, n_head_log2, item_ct1,
139
+ local_buf_acc.get_pointer());
140
+ });
141
+ });
142
+ }
143
+
144
+ static void soft_max_f32_sycl(const float * x, const float * mask,
145
+ float * dst, const int ncols_x, const int nrows_x,
146
+ const int nrows_y, const float scale, const float max_bias,
147
+ queue_ptr stream, int device) {
148
+ int nth = WARP_SIZE;
149
+ int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
150
+ while (nth < ncols_x && nth < max_block_size) nth *= 2;
151
+ if (nth>max_block_size) nth = max_block_size;
152
+
153
+ const sycl::range<3> block_dims(1, 1, nth);
154
+ const sycl::range<3> block_nums(1, 1, nrows_x);
155
+ const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
156
+
157
+ const uint32_t n_head_kv = nrows_x/nrows_y;
158
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
159
+
160
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
161
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
162
+
163
+ const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
164
+ if (n_local_scratch*sizeof(float) < local_mem_size) {
165
+ if (ncols_x > max_block_size) {
166
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
167
+ max_bias, m0, m1, n_head_log2, block_nums,
168
+ block_dims, n_local_scratch, stream);
169
+ return;
170
+ }
171
+ switch (ncols_x) {
172
+ case 32:
173
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
174
+ max_bias, m0, m1, n_head_log2, block_nums,
175
+ block_dims, n_local_scratch, stream);
176
+ break;
177
+ case 64:
178
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
179
+ max_bias, m0, m1, n_head_log2, block_nums,
180
+ block_dims, n_local_scratch, stream);
181
+ break;
182
+ case 128:
183
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
184
+ max_bias, m0, m1, n_head_log2, block_nums,
185
+ block_dims, n_local_scratch, stream);
186
+ break;
187
+ case 256:
188
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
189
+ max_bias, m0, m1, n_head_log2, block_nums,
190
+ block_dims, n_local_scratch, stream);
191
+ break;
192
+ case 512:
193
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
194
+ max_bias, m0, m1, n_head_log2, block_nums,
195
+ block_dims, n_local_scratch, stream);
196
+ break;
197
+ case 1024:
198
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
199
+ max_bias, m0, m1, n_head_log2, block_nums,
200
+ block_dims, n_local_scratch, stream);
201
+ break;
202
+ case 2048:
203
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
204
+ max_bias, m0, m1, n_head_log2, block_nums,
205
+ block_dims, n_local_scratch, stream);
206
+ break;
207
+ case 4096:
208
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
209
+ max_bias, m0, m1, n_head_log2, block_nums,
210
+ block_dims, n_local_scratch, stream);
211
+ break;
212
+ default:
213
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
214
+ max_bias, m0, m1, n_head_log2, block_nums,
215
+ block_dims, n_local_scratch, stream);
216
+ break;
217
+ }
218
+ } else {
219
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
220
+ max_bias, m0, m1, n_head_log2, block_nums,
221
+ block_dims, WARP_SIZE, stream);
222
+ }
223
+ }
224
+
225
+ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
226
+ const ggml_tensor *src1, ggml_tensor *dst,
227
+ const float *src0_dd, const float *src1_dd,
228
+ float *dst_dd,
229
+ const queue_ptr &main_stream) {
230
+
231
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
232
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
233
+
234
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
235
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
236
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
237
+
238
+ const int64_t ne00 = src0->ne[0];
239
+ const int64_t nrows_x = ggml_nrows(src0);
240
+ const int64_t nrows_y = src0->ne[1];
241
+
242
+ float scale = 1.0f;
243
+ float max_bias = 0.0f;
244
+
245
+ memcpy(&scale, dst->op_params + 0, sizeof(float));
246
+ memcpy(&max_bias, dst->op_params + 1, sizeof(float));
247
+
248
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
249
+ nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
250
+ }
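The kernel above is the usual max-subtracted softmax, with the optional mask and an ALiBi slope folded into the logits before normalization. A scalar sketch of the per-row math (illustrative only, not part of this commit):

    #include <algorithm>
    #include <cmath>

    // Scalar reference for one row of soft_max_f32: dst = softmax(x*scale + slope*mask).
    // slope is 1.0f unless max_bias > 0, in which case soft_max_f32 derives it from the
    // head index as m0^(h+1) or m1^(2*(h - n_head_log2) + 1).
    static void soft_max_row_ref(const float * x, const float * mask, float * dst,
                                 int ncols, float scale, float slope) {
        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            const float v = x[c] * scale + (mask ? slope * mask[c] : 0.0f);
            dst[c] = v;
            max_val = std::max(max_val, v);
        }

        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            dst[c] = std::exp(dst[c] - max_val);
            sum += dst[c];
        }

        const float inv_sum = 1.0f / sum;
        for (int c = 0; c < ncols; ++c) {
            dst[c] *= inv_sum;
        }
    }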
ggml/src/ggml-sycl/softmax.hpp ADDED
@@ -0,0 +1,24 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_SOFTMAX_HPP
14
+ #define GGML_SYCL_SOFTMAX_HPP
15
+
16
+ #include "common.hpp"
17
+
18
+ void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
19
+ const ggml_tensor *src1, ggml_tensor *dst,
20
+ const float *src0_dd, const float *src1_dd,
21
+ float *dst_dd,
22
+ const queue_ptr &main_stream);
23
+
24
+ #endif // GGML_SYCL_SOFTMAX_HPP
ggml/src/ggml-sycl/vecdotq.hpp CHANGED
@@ -820,7 +820,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
820
  #if QK_K == 256
821
  const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
822
 
823
- #if QR2_XXS == 8
824
  const int ib32 = iqs;
825
  const uint16_t * q2 = bq2->qs + 4*ib32;
826
  const uint8_t * aux8 = (const uint8_t *)q2;
@@ -838,26 +837,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
838
  }
839
  const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
840
  return d * sumi;
841
- #else
842
- // iqs is 0...15
843
- const int ib32 = iqs/2;
844
- const int il = iqs%2;
845
- const uint16_t * q2 = bq2->qs + 4*ib32;
846
- const uint8_t * aux8 = (const uint8_t *)q2;
847
- const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
848
- const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
849
- const uint32_t aux32 = q2[2] | (q2[3] << 16);
850
- const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * bq8_1[ib32].ds[0] * 0.25f;
851
- const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
852
- const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
853
- const int8_t * q8 = bq8_1[ib32].qs + 16*il;
854
- int sumi1 = 0, sumi2 = 0;
855
- for (int j = 0; j < 8; ++j) {
856
- sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
857
- sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
858
- }
859
- return d * (sumi1 + sumi2);
860
- #endif
861
  #else
862
  assert(false);
863
  return 0.f;
 
820
  #if QK_K == 256
821
  const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
822
 
 
823
  const int ib32 = iqs;
824
  const uint16_t * q2 = bq2->qs + 4*ib32;
825
  const uint8_t * aux8 = (const uint8_t *)q2;
 
837
  }
838
  const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
839
  return d * sumi;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
840
  #else
841
  assert(false);
842
  return 0.f;