yolo : add backend support (ggml/924)
* yolo : add backend support
* metal : add sub and sqrt kernels
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml/src/ggml-cuda.cu +4 -0
- ggml/src/ggml-cuda/binbcast.cu +8 -0
- ggml/src/ggml-cuda/binbcast.cuh +1 -0
- ggml/src/ggml-metal.m +25 -0
- ggml/src/ggml-metal.metal +67 -1
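
These changes wire GGML_OP_SUB into the CUDA backend and add sub and sqrt kernels to Metal, so graphs using those ops no longer fall back to the CPU on these backends. For orientation, a minimal sketch of the user-facing API this serves (illustrative only, not part of the commit; CPU compute is shown via ggml_graph_compute_with_ctx, but with a CUDA or Metal backend the same GGML_OP_SUB / GGML_OP_SQRT nodes now map to the new kernels):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // two 4-element f32 tensors: a = 5, b = 1 everywhere
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        ggml_set_f32(a, 5.0f);
        ggml_set_f32(b, 1.0f);

        // y = sqrt(a - b): one GGML_OP_SUB node and one GGML_OP_SQRT node
        struct ggml_tensor * y = ggml_sqrt(ctx, ggml_sub(ctx, a, b));

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        printf("y[0] = %f\n", ggml_get_f32_1d(y, 0)); // sqrt(5 - 1) = 2.0
        ggml_free(ctx);
        return 0;
    }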
ggml/src/ggml-cuda.cu
CHANGED

@@ -2181,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2859,6 +2862,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
ggml/src/ggml-cuda/binbcast.cu
CHANGED

@@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
 
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
 static __device__ __forceinline__ float op_mul(const float a, const float b) {
     return a * b;
 }
@@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
ggml/src/ggml-cuda/binbcast.cuh
CHANGED

@@ -2,5 +2,6 @@
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
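
Taken together, the CUDA changes follow the existing binbcast recipe: one __device__ functor (op_sub), one wrapper that instantiates the shared bin_bcast_cuda machinery, one header declaration, and two case labels in ggml-cuda.cu. The broadcast itself is done with modulo indexing on src1. As a reference for the semantics, a CPU sketch (illustrative only, assuming contiguous f32 tensors; sub_bcast_ref and its parameters are hypothetical names, not code from the repo):

    #include <stdint.h>

    // sub_bcast_ref: hypothetical CPU reference for bin_bcast_cuda<op_sub>.
    // ne[4] are src0/dst dims (ne[0] fastest-varying, as in ggml), ne1[4]
    // are src1 dims; each src1 index wraps modulo its dimension, so a src1
    // dim of size 1 is repeated across the corresponding src0 dim.
    static void sub_bcast_ref(const float * src0, const float * src1, float * dst,
                              const int64_t ne[4], const int64_t ne1[4]) {
        for (int64_t i3 = 0; i3 < ne[3]; ++i3) {
        for (int64_t i2 = 0; i2 < ne[2]; ++i2) {
        for (int64_t i1 = 0; i1 < ne[1]; ++i1) {
        for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
            const int64_t i = ((i3*ne[2] + i2)*ne[1] + i1)*ne[0] + i0;
            const int64_t j = (((i3 % ne1[3])*ne1[2] + (i2 % ne1[2]))*ne1[1]
                              + (i1 % ne1[1]))*ne1[0] + (i0 % ne1[0]);
            dst[i] = src0[i] - src1[j];
        }}}}
    }

Any src1 dimension of size 1 wraps to index 0, which gives the usual repeat-style broadcast.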
ggml/src/ggml-metal.m
CHANGED

@@ -31,6 +31,8 @@ struct ggml_metal_kernel {
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
@@ -205,6 +207,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -493,6 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,     add,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,     sub,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,     mul,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,     div,     true);
@@ -667,6 +672,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,         concat,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,            sqr,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,           sqrt,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,            sin,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,            cos,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,       sum_rows,       true);
@@ -769,6 +775,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -777,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
             return ggml_is_contiguous(op->src[0]);
@@ -1057,6 +1065,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
             {
@@ -1080,6 +1089,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     nb = ne00 / 4;
                     switch (dst->op) {
                         case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
                         case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
                         case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
                         default: GGML_ABORT("fatal error");
@@ -1089,6 +1099,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 } else {
                     switch (dst->op) {
                         case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
                         case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
                         case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
                         default: GGML_ABORT("fatal error");
@@ -1416,6 +1427,20 @@ static enum ggml_status ggml_metal_graph_compute(
 
                 const int64_t n = ggml_nelements(dst);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SQRT:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SIN:
ggml/src/ggml-metal.metal
CHANGED

@@ -17,7 +17,7 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, multiplication and division of two tensors
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
@@ -70,6 +70,56 @@ kernel void kernel_add(
     }
 }
 
+kernel void kernel_sub(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
 kernel void kernel_mul(
         device const char * src0,
         device const char * src1,
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
+kernel void kernel_sub_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
 kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
@@ -358,6 +417,13 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
 kernel void kernel_sin(
         device const float * src0,
         device float * dst,
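
On the Metal side, SUB mirrors ADD/MUL/DIV in getting both the general kernel_sub (non-contiguous tensors, full 4-D broadcast) and a vectorized kernel_sub_row fast path, which ggml_metal_graph_compute selects when src1 is a single contiguous row; that kernel works on float4 lanes with nb = ne00/4 and wraps the row index with tpig % nb. In scalar terms the fast path computes the following (a sketch of the semantics with hypothetical names, not repo code):

    #include <stdint.h>

    // sub_row_ref: hypothetical scalar reference for kernel_sub_row. One
    // src1 row of length row_len is subtracted from every row of src0
    // (n elements total); the Metal kernel does the same per float4 lane.
    static void sub_row_ref(const float * src0, const float * src1, float * dst,
                            int64_t n, int64_t row_len) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src0[i] - src1[i % row_len];
        }
    }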
|