rgerganov ggerganov committed on
Commit
630d713
·
1 Parent(s): f158bc0

yolo : add backend support (ggml/924)

Browse files

* yolo : add backend support

* metal : add sub and sqrt kernels

---------

Co-authored-by: Georgi Gerganov <[email protected]>

ggml/src/ggml-cuda.cu CHANGED
@@ -2181,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2181
  case GGML_OP_ADD:
2182
  ggml_cuda_op_add(ctx, dst);
2183
  break;
 
 
 
2184
  case GGML_OP_ACC:
2185
  ggml_cuda_op_acc(ctx, dst);
2186
  break;
@@ -2859,6 +2862,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
2859
  case GGML_OP_TRANSPOSE:
2860
  case GGML_OP_NORM:
2861
  case GGML_OP_ADD:
 
2862
  case GGML_OP_MUL:
2863
  case GGML_OP_DIV:
2864
  case GGML_OP_RMS_NORM:
 
2181
  case GGML_OP_ADD:
2182
  ggml_cuda_op_add(ctx, dst);
2183
  break;
2184
+ case GGML_OP_SUB:
2185
+ ggml_cuda_op_sub(ctx, dst);
2186
+ break;
2187
  case GGML_OP_ACC:
2188
  ggml_cuda_op_acc(ctx, dst);
2189
  break;
 
2862
  case GGML_OP_TRANSPOSE:
2863
  case GGML_OP_NORM:
2864
  case GGML_OP_ADD:
2865
+ case GGML_OP_SUB:
2866
  case GGML_OP_MUL:
2867
  case GGML_OP_DIV:
2868
  case GGML_OP_RMS_NORM:
ggml/src/ggml-cuda/binbcast.cu CHANGED
@@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
9
  return a + b;
10
  }
11
 
 
 
 
 
12
  static __device__ __forceinline__ float op_mul(const float a, const float b) {
13
  return a * b;
14
  }
@@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
271
  ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
272
  }
273
 
 
 
 
 
274
  void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
275
  ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
276
  }
 
9
  return a + b;
10
  }
11
 
12
+ static __device__ __forceinline__ float op_sub(const float a, const float b) {
13
+ return a - b;
14
+ }
15
+
16
  static __device__ __forceinline__ float op_mul(const float a, const float b) {
17
  return a * b;
18
  }
 
275
  ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
276
  }
277
 
278
+ void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
279
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
280
+ }
281
+
282
  void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
283
  ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
284
  }
ggml/src/ggml-cuda/binbcast.cuh CHANGED
@@ -2,5 +2,6 @@
2
 
3
  void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
4
  void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
5
  void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6
  void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
2
 
3
  void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
4
  void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5
+ void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6
  void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
7
  void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-metal.m CHANGED
@@ -31,6 +31,8 @@ struct ggml_metal_kernel {
31
  enum ggml_metal_kernel_type {
32
  GGML_METAL_KERNEL_TYPE_ADD,
33
  GGML_METAL_KERNEL_TYPE_ADD_ROW,
 
 
34
  GGML_METAL_KERNEL_TYPE_MUL,
35
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
  GGML_METAL_KERNEL_TYPE_DIV,
@@ -205,6 +207,7 @@ enum ggml_metal_kernel_type {
205
  GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
206
  GGML_METAL_KERNEL_TYPE_CONCAT,
207
  GGML_METAL_KERNEL_TYPE_SQR,
 
208
  GGML_METAL_KERNEL_TYPE_SIN,
209
  GGML_METAL_KERNEL_TYPE_COS,
210
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -493,6 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
493
 
494
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
 
 
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
@@ -667,6 +672,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
667
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
668
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
669
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
 
670
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
671
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
672
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -769,6 +775,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
769
  case GGML_OP_PERMUTE:
770
  case GGML_OP_CONCAT:
771
  case GGML_OP_ADD:
 
772
  case GGML_OP_ACC:
773
  case GGML_OP_MUL:
774
  case GGML_OP_DIV:
@@ -777,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
777
  case GGML_OP_CLAMP:
778
  return true;
779
  case GGML_OP_SQR:
 
780
  case GGML_OP_SIN:
781
  case GGML_OP_COS:
782
  return ggml_is_contiguous(op->src[0]);
@@ -1057,6 +1065,7 @@ static enum ggml_status ggml_metal_graph_compute(
1057
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1058
  } break;
1059
  case GGML_OP_ADD:
 
1060
  case GGML_OP_MUL:
1061
  case GGML_OP_DIV:
1062
  {
@@ -1080,6 +1089,7 @@ static enum ggml_status ggml_metal_graph_compute(
1080
  nb = ne00 / 4;
1081
  switch (dst->op) {
1082
  case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
 
1083
  case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
1084
  case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
1085
  default: GGML_ABORT("fatal error");
@@ -1089,6 +1099,7 @@ static enum ggml_status ggml_metal_graph_compute(
1089
  } else {
1090
  switch (dst->op) {
1091
  case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
 
1092
  case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
1093
  case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
1094
  default: GGML_ABORT("fatal error");
@@ -1416,6 +1427,20 @@ static enum ggml_status ggml_metal_graph_compute(
1416
 
1417
  const int64_t n = ggml_nelements(dst);
1418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1419
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1420
  } break;
1421
  case GGML_OP_SIN:
 
31
  enum ggml_metal_kernel_type {
32
  GGML_METAL_KERNEL_TYPE_ADD,
33
  GGML_METAL_KERNEL_TYPE_ADD_ROW,
34
+ GGML_METAL_KERNEL_TYPE_SUB,
35
+ GGML_METAL_KERNEL_TYPE_SUB_ROW,
36
  GGML_METAL_KERNEL_TYPE_MUL,
37
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
38
  GGML_METAL_KERNEL_TYPE_DIV,
 
207
  GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
208
  GGML_METAL_KERNEL_TYPE_CONCAT,
209
  GGML_METAL_KERNEL_TYPE_SQR,
210
+ GGML_METAL_KERNEL_TYPE_SQRT,
211
  GGML_METAL_KERNEL_TYPE_SIN,
212
  GGML_METAL_KERNEL_TYPE_COS,
213
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
 
496
 
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
499
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
500
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
502
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
503
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
 
672
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
673
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
674
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
675
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
676
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
677
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
678
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
 
775
  case GGML_OP_PERMUTE:
776
  case GGML_OP_CONCAT:
777
  case GGML_OP_ADD:
778
+ case GGML_OP_SUB:
779
  case GGML_OP_ACC:
780
  case GGML_OP_MUL:
781
  case GGML_OP_DIV:
 
784
  case GGML_OP_CLAMP:
785
  return true;
786
  case GGML_OP_SQR:
787
+ case GGML_OP_SQRT:
788
  case GGML_OP_SIN:
789
  case GGML_OP_COS:
790
  return ggml_is_contiguous(op->src[0]);
 
1065
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1066
  } break;
1067
  case GGML_OP_ADD:
1068
+ case GGML_OP_SUB:
1069
  case GGML_OP_MUL:
1070
  case GGML_OP_DIV:
1071
  {
 
1089
  nb = ne00 / 4;
1090
  switch (dst->op) {
1091
  case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
1092
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
1093
  case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
1094
  case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
1095
  default: GGML_ABORT("fatal error");
 
1099
  } else {
1100
  switch (dst->op) {
1101
  case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
1102
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
1103
  case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
1104
  case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
1105
  default: GGML_ABORT("fatal error");
 
1427
 
1428
  const int64_t n = ggml_nelements(dst);
1429
 
1430
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1431
+ } break;
1432
+ case GGML_OP_SQRT:
1433
+ {
1434
+ GGML_ASSERT(ggml_is_contiguous(src0));
1435
+
1436
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
1437
+
1438
+ [encoder setComputePipelineState:pipeline];
1439
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1440
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1441
+
1442
+ const int64_t n = ggml_nelements(dst);
1443
+
1444
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1445
  } break;
1446
  case GGML_OP_SIN:
ggml/src/ggml-metal.metal CHANGED
@@ -17,7 +17,7 @@ enum ggml_sort_order {
17
  GGML_SORT_ORDER_DESC,
18
  };
19
 
20
- // general-purpose kernel for addition, multiplication and division of two tensors
21
  // pros: works for non-contiguous tensors, supports broadcast across all dims
22
  // cons: not very efficient
23
  kernel void kernel_add(
@@ -70,6 +70,56 @@ kernel void kernel_add(
70
  }
71
  }
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  kernel void kernel_mul(
74
  device const char * src0,
75
  device const char * src1,
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
226
  dst[tpig] = src0[tpig] + src1[tpig % nb];
227
  }
228
 
 
 
 
 
 
 
 
 
 
229
  kernel void kernel_mul_row(
230
  device const float4 * src0,
231
  device const float4 * src1,
@@ -358,6 +417,13 @@ kernel void kernel_sqr(
358
  dst[tpig] = src0[tpig] * src0[tpig];
359
  }
360
 
 
 
 
 
 
 
 
361
  kernel void kernel_sin(
362
  device const float * src0,
363
  device float * dst,
 
17
  GGML_SORT_ORDER_DESC,
18
  };
19
 
20
+ // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
21
  // pros: works for non-contiguous tensors, supports broadcast across all dims
22
  // cons: not very efficient
23
  kernel void kernel_add(
 
70
  }
71
  }
72
 
73
+ kernel void kernel_sub(
74
+ device const char * src0,
75
+ device const char * src1,
76
+ device char * dst,
77
+ constant int64_t & ne00,
78
+ constant int64_t & ne01,
79
+ constant int64_t & ne02,
80
+ constant int64_t & ne03,
81
+ constant uint64_t & nb00,
82
+ constant uint64_t & nb01,
83
+ constant uint64_t & nb02,
84
+ constant uint64_t & nb03,
85
+ constant int64_t & ne10,
86
+ constant int64_t & ne11,
87
+ constant int64_t & ne12,
88
+ constant int64_t & ne13,
89
+ constant uint64_t & nb10,
90
+ constant uint64_t & nb11,
91
+ constant uint64_t & nb12,
92
+ constant uint64_t & nb13,
93
+ constant int64_t & ne0,
94
+ constant int64_t & ne1,
95
+ constant int64_t & ne2,
96
+ constant int64_t & ne3,
97
+ constant uint64_t & nb0,
98
+ constant uint64_t & nb1,
99
+ constant uint64_t & nb2,
100
+ constant uint64_t & nb3,
101
+ constant int64_t & offs,
102
+ uint3 tgpig[[threadgroup_position_in_grid]],
103
+ uint3 tpitg[[thread_position_in_threadgroup]],
104
+ uint3 ntg[[threads_per_threadgroup]]) {
105
+ const int64_t i03 = tgpig.z;
106
+ const int64_t i02 = tgpig.y;
107
+ const int64_t i01 = tgpig.x;
108
+
109
+ const int64_t i13 = i03 % ne13;
110
+ const int64_t i12 = i02 % ne12;
111
+ const int64_t i11 = i01 % ne11;
112
+
113
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
114
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
115
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
116
+
117
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
118
+ const int i10 = i0 % ne10;
119
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
120
+ }
121
+ }
122
+
123
  kernel void kernel_mul(
124
  device const char * src0,
125
  device const char * src1,
 
276
  dst[tpig] = src0[tpig] + src1[tpig % nb];
277
  }
278
 
279
+ kernel void kernel_sub_row(
280
+ device const float4 * src0,
281
+ device const float4 * src1,
282
+ device float4 * dst,
283
+ constant uint64_t & nb [[buffer(28)]],
284
+ uint tpig[[thread_position_in_grid]]) {
285
+ dst[tpig] = src0[tpig] - src1[tpig % nb];
286
+ }
287
+
288
  kernel void kernel_mul_row(
289
  device const float4 * src0,
290
  device const float4 * src1,
 
417
  dst[tpig] = src0[tpig] * src0[tpig];
418
  }
419
 
420
+ kernel void kernel_sqrt(
421
+ device const float * src0,
422
+ device float * dst,
423
+ uint tpig[[thread_position_in_grid]]) {
424
+ dst[tpig] = sqrt(src0[tpig]);
425
+ }
426
+
427
  kernel void kernel_sin(
428
  device const float * src0,
429
  device float * dst,