Romain Biessy committed
Commit 7b868ed · Parent: 345810b

sycl: Fix and disable more configurations of mul_mat (llama/15151)


* sycl: Fix and disable more configurations of mul_mat

* Disable more configurations

Files changed (1)
  1. ggml/src/ggml-sycl/ggml-sycl.cpp +28 -7
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -2705,9 +2705,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                 " : converting src1 to fp16");
 
     // iterate tensor dims and find the slowest moving dim and stride
-    int64_t last_dim=0;
-    int64_t last_str=0;
-    int64_t largest_str=0;
+    int last_dim=0;
+    int last_str=0;
+    size_t largest_str=0;
     for(int i = 0; i< 4; i++){
         // last stride is always the largest
         if(src1->nb[i] == largest_str){
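Note on the type change: ggml stores byte strides in nb[] as size_t, so tracking the maximum in an int64_t sets up signed/unsigned comparisons; switching largest_str to size_t matches the data being scanned. Below is a minimal standalone sketch of such a slowest-dimension scan (tensor_view and slowest_dim are illustrative names, not from the patch):

#include <cstddef>
#include <cstdio>

// Hypothetical ggml-like view: nb[i] is the byte stride of dimension i.
struct tensor_view {
    size_t nb[4];
};

// Find the slowest-moving dimension, i.e. the one with the largest byte
// stride. Tracking the maximum in a size_t matches the type of nb[], which
// is the same signedness fix the patch applies to largest_str.
static int slowest_dim(const tensor_view & t) {
    int    last_dim    = 0;
    size_t largest_str = 0;
    for (int i = 0; i < 4; i++) {
        if (t.nb[i] >= largest_str) {
            largest_str = t.nb[i];
            last_dim    = i;
        }
    }
    return last_dim;
}

int main() {
    // contiguous 4 x 3 x 2 x 1 fp16 tensor: strides grow with the dimension
    tensor_view t = { { 2, 8, 24, 48 } };
    printf("slowest dim: %d\n", slowest_dim(t)); // prints 3
}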
@@ -2783,7 +2783,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
                                                  const sycl::half *src1, float *dst,
                                                  int64_t a0, int64_t a1, int64_t batcha,
-                                                 int64_t b0, int64_t b1, int64_t batchb,
+                                                 int64_t /*b0*/, int64_t b1, int64_t batchb,
                                                  int64_t sa0, int64_t sa1, int64_t sa2,
                                                  int64_t sb0, int64_t sb1, int64_t sb2,
                                                  int64_t sd2) {
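The only change here is commenting out the name of a parameter the lambda body no longer reads: the signature, and therefore every call site, stays the same while -Wunused-parameter is silenced. A tiny sketch of the idiom, with made-up names:

#include <cassert>

// The first argument is accepted only to keep the signature stable; the
// commented-out name documents that it is intentionally unused.
static long scaled(long /*offset*/, long value, long factor) {
    return value * factor;
}

int main() {
    assert(scaled(0, 2, 3) == 6);
    return 0;
}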
@@ -2832,14 +2832,26 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         }
     };
 
-    bool cont_batches_a = nb02 * ne02 == nb03;
-    bool cont_batches_b = nb12 * ne12 == nb13;
-    if (cont_batches_a && cont_batches_b) {
+    const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
+    const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
+    const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
+    const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
+    if (cont_batches_dim2_a && cont_batches_dim2_b) {
+        // A batch is considered contiguous if the dimension 2 is not strided
         int64_t batches0 = ne02 * ne03;
         int64_t batches1 = ne12 * ne13;
         launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
                                 ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
                                 str_b2, nb2 / sizeof(float));
+    } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
+        // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
+        int64_t batches0 = ne02 * ne03;
+        int64_t batches1 = ne12 * ne13;
+        int64_t str_a3 = nb03 / type_size_src0;
+        int64_t str_b3 = nb13 / type_size_src1;
+        launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                                ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
+                                str_b3, nb2 / sizeof(float));
     } else {
         for (int64_t b_a = 0; b_a < ne03; b_a++) {
             const sycl::half *src0_f16_shifted
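Both fast paths collapse all batches into a single strided GEMM launch; they differ only in which dimension carries the batches. The conditions can be restated as standalone predicates over a ggml-style shape; the sketch below uses illustrative names and a hand-built permuted example, neither of which comes from the patch:

#include <cstddef>
#include <cstdint>

struct tensor_dims {
    int64_t ne[4]; // elements per dimension
    size_t  nb[4]; // byte stride per dimension
};

// Fast path 1 (cont_batches_dim2_*): dimension 2 is not strided, so walking
// ne[2] batches of stride nb[2] lands exactly on nb[3].
static bool cont_batches_dim2(const tensor_dims & t) {
    return t.nb[2] * (size_t) t.ne[2] == t.nb[3];
}

// Fast path 2 (cont_batches_dim3_*): dimension 2 has size 1, so only
// dimension 3 carries batches, and nb[3] must follow directly after ne[1]
// rows of stride nb[2].
static bool cont_batches_dim3(const tensor_dims & t) {
    return t.ne[2] == 1 && t.nb[2] * (size_t) t.ne[1] == t.nb[3];
}

int main() {
    // (16, 8, 1, 4) fp32 view made by swapping dims 1 and 2 of a contiguous
    // (16, 1, 8, 4) tensor; only the new dim-3 fast path accepts it.
    tensor_dims t = { { 16, 8, 1, 4 }, { 4, 64, 64, 512 } };
    return (!cont_batches_dim2(t) && cont_batches_dim3(t)) ? 0 : 1;
}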
@@ -4215,6 +4227,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
                 return false;
             }
+            // TODO: The configuration below needs more work to be supported with oneDNN
+            if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
+                return false;
+            }
+            // TODO: This specific configuration can fail with oneDNN and needs more debugging
+            if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
+                a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
+                return false;
+            }
             return true;
         }
         case GGML_OP_OUT_PROD:
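Both new guards lean on ggml's layout predicates: roughly, ggml_is_permuted reports byte strides out of non-decreasing order, and ggml_is_contiguous reports the canonical packed layout. The sketch below restates the first disabled configuration under those simplified definitions (the names, the simplified is_contiguous, and the example are illustrative, not the backend's actual code):

#include <cstddef>
#include <cstdint>

struct tensor_dims {
    int64_t ne[4]; // elements per dimension
    size_t  nb[4]; // byte stride per dimension
};

// Same shape of test as ggml_is_permuted: any byte stride out of
// non-decreasing order means the tensor is a permuted view.
static bool is_permuted(const tensor_dims & t) {
    return t.nb[0] > t.nb[1] || t.nb[1] > t.nb[2] || t.nb[2] > t.nb[3];
}

// Simplified contiguity test for non-quantized types (real ggml also
// accounts for blocked quantized layouts).
static bool is_contiguous(const tensor_dims & t, size_t type_size) {
    return t.nb[0] == type_size &&
           t.nb[1] == t.nb[0] * (size_t) t.ne[0] &&
           t.nb[2] == t.nb[1] * (size_t) t.ne[1] &&
           t.nb[3] == t.nb[2] * (size_t) t.ne[2];
}

// First guard above: a permuted, non-contiguous src0 with batches in both
// dimensions 2 and 3 is declined until the oneDNN path supports it.
static bool first_guard_allows(const tensor_dims & a, size_t type_size) {
    return !(is_permuted(a) && !is_contiguous(a, type_size) && a.ne[2] > 1 && a.ne[3] > 1);
}

int main() {
    // contiguous (4, 4, 2, 2) fp32 tensor: not permuted, so it passes
    tensor_dims a = { { 4, 4, 2, 2 }, { 4, 16, 64, 128 } };
    return first_guard_allows(a, 4) ? 0 : 1;
}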
 