llama : initial Mamba-2 support (llama/9126)
* llama : initial Mamba-2 support
* ggml : SIMD ggml_ssm_scan for Mamba-2
* ggml : improve ggml_mul speed when masking recurrent states
* llama : support running Mamba-Codestral-7B-v0.1
* llama : fix Mamba-2 conv state saving
* ggml : make the ggml_mul fast broadcast path more consistently formatted
* llama : remove unused variable
* llama : add missing break
* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present
The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires
workarounds to work correctly.
* llama : avoid redundant state copy for Mamba 1 and 2
* metal : attempt to adapt SSM_SCAN for Mamba-2
* metal : fix SSM_SCAN pipeline scope
* metal : use log and exp instead of log1pf and expf in SSM_SCAN
* metal : remove unused arguments for SSM_SCAN
The maximum argument index is 31, so trimming the arguments is necessary.
* metal : add back n_seqs to SSM_SCAN args
Whoops, this is needed for the offset in the concatenated output.
* metal : fix SSM_SCAN state head offset
* metal : fix wrong number of tokens per sequence in SSM_SCAN
* ggml : remove unused fast broadcast path in GGML_MUL
This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.
* ggml : avoid multiply by D in GGML_OP_SSM_SCAN
This makes the weight buft detection in src/llama.cpp simpler.
* convert : transpose Mamba-2 A, D and reshape SSM_NORM
This breaks existing conversions of Mamba-2 models
to avoid some reshapes.
Not sure if it's a good idea,
but it makes the graph slightly cleaner.
* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
* convert : fix flake8 lint
* metal : fix confusion between ; and ,
* metal : add missing args for nb references in ssm_scan_f32_group
* metal : single-user mamba2 inference works
* kv-cache : remove const_cast when setting inputs for s_copy
And also fix multi-user inference for recurrent models
by using cell_id instead of i as the kv cell index
when populating s_copy.
* convert : avoid AutoConfig for Mamba and Mamba2 hparams
* kv-cache : allow context shift for recurrent models
* graph : fix recurrent state copies when avoiding copies
Works, but using lambda functions might not be that clean.
* ggml : fix mamba2 ssm scan when compiled with SVE
* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches
* cuda : implement ssm scan for Mamba2
There is still room for improvement, but it works!
* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2
* mamba : fix mismatched new and delete size for llm_build_mamba
Subclasses of llm_graph_context cannot have extra fields,
because the called destructor is not the one from the subclass.
This would otherwise cause problems when running Mamba-(1|2) inference
when compiled with -DGGML_SANITIZE_ADDRESS=ON (see the sketch after this list).
* cuda : graceful fallback for Mamba-1 models with weird embd size
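
For readers unfamiliar with the new/delete size mismatch mentioned in the llm_build_mamba item above, here is a minimal standalone sketch of the pattern; the type names (base_ctx, mamba_ctx) are illustrative assumptions, not taken from the codebase:

    // Hypothetical reduction of the issue: deleting a larger subclass through a
    // base pointer whose destructor is not virtual.
    struct base_ctx {
        int n_nodes = 0;
        ~base_ctx() {}            // non-virtual: delete through base_ctx* is undefined behaviour
    };

    struct mamba_ctx : base_ctx {
        float extra_state = 0.0f; // extra field makes the subclass larger than the base
    };

    int main() {
        base_ctx * ctx = new mamba_ctx(); // allocates sizeof(mamba_ctx) bytes
        delete ctx;                       // sized delete is given sizeof(base_ctx) bytes;
                                          // AddressSanitizer typically reports a
                                          // new-delete-type-mismatch here
        return 0;
    }

Keeping the extra fields out of the subclass (or giving the base a virtual destructor) avoids the mismatch; per the commit message, the fix takes the first route.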
- ggml/include/ggml.h +2 -1
- ggml/src/ggml-cpu/ops.cpp +184 -94
- ggml/src/ggml-cpu/simd-mappings.h +1 -1
- ggml/src/ggml-cpu/vec.cpp +9 -9
- ggml/src/ggml-cpu/vec.h +9 -9
- ggml/src/ggml-cuda/ggml-cuda.cu +15 -2
- ggml/src/ggml-cuda/ssm-scan.cu +180 -51
- ggml/src/ggml-metal/ggml-metal-impl.h +5 -6
- ggml/src/ggml-metal/ggml-metal.m +61 -32
- ggml/src/ggml-metal/ggml-metal.metal +96 -23
- ggml/src/ggml.c +31 -18

ggml/include/ggml.h

@@ -2028,7 +2028,8 @@ extern "C" {
             struct ggml_tensor * dt,
             struct ggml_tensor * A,
             struct ggml_tensor * B,
-            struct ggml_tensor * C);
+            struct ggml_tensor * C,
+            struct ggml_tensor * ids);

     // partition into non-overlapping windows with padding if needed
     // example:
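
A hedged sketch of what a call site looks like after this signature change; the leading parameters and the variable names (ctx0, ssm_states, x, dt, A, B, C, state_ids) are assumptions for illustration and are not part of this hunk:

    // ids is an I32 tensor of shape {n_seqs}; it selects which saved recurrent state
    // each sequence in the batch starts from. The result packs the per-token outputs
    // first, followed by the updated states (see s_off in the CPU code below).
    struct ggml_tensor * out = ggml_ssm_scan(ctx0, ssm_states, x, dt, A, B, C, state_ids);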

ggml/src/ggml-cpu/ops.cpp

@@ -8337,120 +8337,210 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // s   {d_state, dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // x   {dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // dt  {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // A   {d_state, n_head} or {1, n_head}
+    const ggml_tensor * src4 = dst->src[4]; // B   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // C   {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}

     const int ith = params->ith;
     const int nth = params->nth;

+    const int64_t nc = src0->ne[0]; // d_state
+    const int64_t nr = src0->ne[1]; // dim
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1];
+    const int64_t nt = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch

+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
+
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+    // allows optimizing the modulo since n_group should be a power of 2
+    GGML_ASSERT((ng & -ng) == ng);

+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;

+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
+
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+              float * s  = (      float *) ((      char *)  dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+                  float * y  = (      float *) ((      char *)  dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+            if (src3->ne[0] == 1) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+#if defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+                        const int ggml_f32_epr  = svcntw();
+                        const int ggml_f32_step = 1 * ggml_f32_epr;
+
+                        const int np = (nc & ~(ggml_f32_step - 1));
+
+                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+                        GGML_F32_VEC adA  = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        for (int i = 0; i < np; i += ggml_f32_step) {
+                            // TODO: maybe unroll more?
+                            for (int j = 0; j < 1; j++) {
+                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C  + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+                                t0 = GGML_F32_VEC_MUL(t0, adA);
+                                t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+                                t0 = GGML_F32_VEC_ADD(t0, t1);
+
+                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+                            }
+                        }
+
+                        sumf = GGML_F32xt_REDUCE_ONE(sum);
+    #else
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
+
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+                        GGML_F32_VEC adA  = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
+
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C  + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
+
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+    #endif
+#else
+                        const int np = 0;
+#endif
+                        // d_state
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+                    }
+                }
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+                        svfloat32_t vx_dt         = GGML_F32_VEC_SET1(x_dt);
+                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                        svfloat32_t r1_vector     = GGML_F32_VEC_ZERO;
+
+                        // d_state
+                        // TODO: what happens when (d_state % svcntw()) != 0?
+                        for (int64_t k = 0; k < nc; k += svcntw()) {
+                            svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+                            svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                            t1 = exp_ps_sve(svptrue_b32(), t1);
+                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+                        }
+                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+                        float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int i  = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+#endif
+                    }
+                }
+            }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
+        }
+    }
 }

 void ggml_compute_forward_ssm_scan(
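
For reference, the recurrence both branches above implement, written per head h and inner channel (this restates the code, it is not an additional definition; $\Delta_t$ is the softplus-transformed dt):

$$ \Delta_t = \begin{cases} \log(1 + e^{\,dt_t}) & dt_t \le 20 \\ dt_t & \text{otherwise} \end{cases} $$

$$ s_t = s_{t-1}\,e^{\Delta_t A} + (\Delta_t\,x_t)\,B_t, \qquad y_t = \sum_{i=1}^{d_\text{state}} s_t[i]\,C_t[i] $$

In the Mamba-2 branch A is a single scalar per head, so the decay $e^{\Delta_t A}$ is the same for every state element and is hoisted out of the d_state loop; in the Mamba-1 branch A has one value per state element and the exponential stays inside that loop.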

ggml/src/ggml-cpu/simd-mappings.h

@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 #define GGML_F32xt_LOAD(...)             GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
 #define GGML_F32xt_STORE_IMPL(pg,a,b)    svst1_f32(pg, a, b)
 #define GGML_F32xt_STORE(...)            GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
-#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
+#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
 #define GGML_F32xt_FMA(...)              GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
 #define GGML_F32xt_ADD_IMPL(pg, a, b)    svadd_f32_m(pg, a, b)
 #define GGML_F32xt_ADD(...)              GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
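
In words: after this change the SVE fused multiply-add helper follows the same accumulator-first convention as the other SIMD branches, which is what the call-site updates below rely on. A minimal sketch of the new convention, using the macro names from this file:

    // svmad_f32_m(pg, b, c, a) computes b*c + a, so with the new argument order:
    //     GGML_F32_VEC_FMA(acc, x, y) == acc + x*y   (element-wise)
    GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
    sum = GGML_F32_VEC_FMA(sum, ax1, ay1);   // accumulate the element-wise product of ax1 and ay1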

ggml/src/ggml-cpu/vec.cpp

@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
     for (int i = 0; i < np; i += ggml_f32_step) {
         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
+        sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

         ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
         ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
+        sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

         ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
         ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
+        sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

         ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
         ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
+        sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

         ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
         ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
+        sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

         ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
         ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
+        sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

         ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
         ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
+        sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

         ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
         ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
+        sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
     }
     // leftovers
     // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop

@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
     for (int i = np; i < np2; i += ggml_f32_epr) {
         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
+        sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
     }
     // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
     if (np2 < n) {

ggml/src/ggml-cpu/vec.h

@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
+        ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

         GGML_F32_VEC_STORE(y + i, ay1);

         ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
         ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
+        ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);

         GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);

         ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
         ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
+        ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);

         GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);

         ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
         ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
+        ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);

         GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);

         ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
         ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
+        ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);

         GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);

         ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
         ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
+        ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);

         GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);

         ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
         ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
+        ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);

         GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);

         ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
         ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
+        ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);

         GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
     }

@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
     for (int i = np; i < np2; i += ggml_f32_epr) {
         ax1 = GGML_F32_VEC_LOAD(x + i);
         ay1 = GGML_F32_VEC_LOAD(y + i);
+        ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);

         GGML_F32_VEC_STORE(y + i, ay1);
     }

ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3321,9 +3321,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
-        case GGML_OP_SSM_SCAN:
-        case GGML_OP_SSM_CONV:
             return true;
+        case GGML_OP_SSM_SCAN: {
+            if (op->src[3]->ne[0] == 1) {
+                // Mamba2
+                // (kernel only supports d_state == 128 && d_head % 16 == 0)
+                return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+            } else {
+                // Mamba
+                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+            }
+        }
+        case GGML_OP_SSM_CONV: {
+            // assumes d_inner % threads == 0
+            return op->src[0]->ne[1] % 128 == 0;
+        }
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
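
To make the new checks concrete, a hedged reading with illustrative shapes (the concrete numbers are assumptions; only the conditions themselves come from the code above):

    // Mamba-2 path: src[3] (A) has ne[0] == 1 and the state tensor src[0] is
    // {d_state, d_head, n_head, n_seqs}; e.g. {128, 64, 128, 4} is accepted
    // because d_state == 128 and d_head % 16 == 0.
    // Mamba-1 path: src[0] is {16, 1, d_inner, n_seqs}; it is accepted only when
    // d_inner % 128 == 0 and src[4] (B) has n_group == 1 (ne[1] == 1).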

ggml/src/ggml-cuda/ssm-scan.cu

@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
 __global__ void __launch_bounds__(splitD, 2)
     ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+                 const int32_t * __restrict__ src6, float * __restrict__ dst,
+                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+                 const int64_t s_off, const int64_t d_inner, const int64_t L) {

     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int bidx = blockIdx.x; // split along B (sequences)
+    const int bidy = blockIdx.y; // split along D (d_inner)
     const int tid  = threadIdx.x;
     const int wid  = tid / 32;
     const int wtid = tid % 32;

@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
     float * smem_A  = smem;
     float * smem_s0 = smem_A + splitD * stride_sA;

+    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
+    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
     const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
+    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
+    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
+    float * y_block        = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
+    float * s_block        = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);

+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_x  = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
     const int stride_A  = src3_nb1 / sizeof(float);
+    const int stride_B  = src4_nb2 / sizeof(float);
+    const int stride_C  = src5_nb2 / sizeof(float);
     const int stride_s  = stride_s0;
+    const int stride_y  = d_inner;

     // can N not be 16? for example 32?
     if (N == 16) {

@@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2)
     }
 }

+// assumes as many threads as d_state
+template <int splitH, int d_state>
+__global__ void __launch_bounds__(d_state, 1)
+    ssm_scan_f32_group(
+        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+        const int32_t * __restrict__ src6, float * __restrict__ dst,
+        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+        const int src2_nb1, const int src2_nb2, const int src3_nb1,
+        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
+
+    const int head_idx = (blockIdx.x * splitH) / d_head;
+    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
+    const int seq_idx  = blockIdx.y;
+
+    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
+    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float * y_block        = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
+    float * s_block        = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+    // strides across n_seq_tokens
+    const int stride_x  = src1_nb2 / sizeof(float);
+    const int stride_dt = src2_nb1 / sizeof(float);
+    const int stride_B  = src4_nb2 / sizeof(float);
+    const int stride_C  = src5_nb2 / sizeof(float);
+    const int stride_y  = n_head * d_head;
+
+    float state[splitH];
+    // for the parallel accumulation
+    __shared__ float stateC[splitH * d_state];
+
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        state[j] = s0_block[j * d_state + threadIdx.x];
+    }
+
+    for (int64_t i = 0; i < n_tok; i++) {
+        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+        // TODO: only calculate B and C once per head group
+        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+        float dt_soft_plus = dt_block[i * stride_dt];
+        if (dt_soft_plus <= 20.0f) {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
+        }
+        const float dA = expf(dt_soft_plus * A_block[0]);
+        const float B  = B_block[i * stride_B + threadIdx.x];
+        const float C  = C_block[i * stride_C + threadIdx.x];
+
+        // across d_head
+#pragma unroll
+        for (int j = 0; j < splitH; j++) {
+            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+            state[j] = (state[j] * dA) + (B * x_dt);
+
+            stateC[j * d_state + threadIdx.x] = state[j] * C;
+        }
+
+        __syncthreads();
+
+        // parallel accumulation for stateC
+        // TODO: simplify
+        {
+            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+            // reduce until w matches the warp size
+            // TODO: does this work even when the physical warp size is 64?
+#pragma unroll
+            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+                // (assuming there are d_state threads)
+#pragma unroll
+                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+                    // TODO: check for bank conflicts
+                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+                    stateC[k] += stateC[k + (w >> 1)];
+                }
+                __syncthreads();
+            }
+
+            static_assert(splitH >= d_state / WARP_SIZE);
+
+#pragma unroll
+            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+                y = warp_reduce_sum(y);
+
+                // store the above accumulations
+                if (threadIdx.x % WARP_SIZE == 0) {
+                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+                    y_block[i * stride_y + k] = y;
+                }
+            }
+        }
+    }
+
+    // write back the state
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        s_block[j * d_state + threadIdx.x] = state[j];
+    }
+}
+
 static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
+                              const float * src4, const float * src5, const int32_t * src6, float * dst,
+                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
     const int threads = 128;
+    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+    if (src3_nb1 == sizeof(float)) {
+        // Mamba-2
+        if (d_state == 128) {
+            GGML_ASSERT(d_state % threads == 0);
+            // NOTE: can be any power of two between 4 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=128.");
+        }
     } else {
+        // Mamba-1
+        GGML_ASSERT(n_head % threads == 0);
+        GGML_ASSERT(head_dim == 1);
+        GGML_ASSERT(n_group == 1);
+        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
+        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
+        if (d_state == 16) {
+            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=16.");
+        }
     }
 }

@@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
+    const struct ggml_tensor * src6 = dst->src[6]; // ids

     const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // head_dim or 1
+    const int64_t nh  = src1->ne[1]; // n_head
+    const int64_t ng  = src4->ne[1]; // n_group
+    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
+    const int64_t n_s = src1->ne[3]; // number of sequences in the batch
+
+    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;

@@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float * src3_d = (const float *) src3->data;
     const float * src4_d = (const float *) src4->data;
     const float * src5_d = (const float *) src5->data;
+    const int32_t * src6_d = (const int32_t *) src6->data;
     float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src6->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);

+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
 }
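
As a quick sanity check of the Mamba-2 launch configuration above, with assumed Mamba-Codestral-like dimensions (illustrative numbers only; the formula is the one used in ssm_scan_f32_cuda):

    //   n_head = 128, head_dim = 64, splitH = 16, d_state = 128, threads = 128
    //   blocks.x = (n_head*head_dim + splitH - 1)/splitH = (8192 + 15)/16 = 512
    //   blocks.y = n_seq
    // i.e. each block owns 16 rows of the inner dimension and runs one thread per state element.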

ggml/src/ggml-metal/ggml-metal-impl.h

@@ -513,26 +513,25 @@ typedef struct {
 typedef struct {
     int64_t  d_state;
     int64_t  d_inner;
+    int64_t  n_head;
+    int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
-    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
+    uint64_t nb03;
-    uint64_t nb10;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
-    uint64_t nb20;
     uint64_t nb21;
     uint64_t nb22;
-    uint64_t nb30;
     uint64_t nb31;
-    uint64_t nb40;
     uint64_t nb41;
     uint64_t nb42;
+    uint64_t nb43;
-    uint64_t nb50;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t nb53;
 } ggml_metal_kargs_ssm_scan;

 typedef struct {
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
|
|
| 217 |
GGML_METAL_KERNEL_TYPE_NORM,
|
| 218 |
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
| 219 |
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
|
|
|
| 220 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
| 221 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
| 222 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
@@ -1196,6 +1197,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1196 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 1197 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
| 1198 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
|
|
|
| 1199 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
| 1200 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
| 1201 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
@@ -2809,71 +2811,91 @@ static bool ggml_metal_encode_node(
|
|
| 2809 |
struct ggml_tensor * src3 = node->src[3];
|
| 2810 |
struct ggml_tensor * src4 = node->src[4];
|
| 2811 |
struct ggml_tensor * src5 = node->src[5];
|
|
|
|
| 2812 |
|
| 2813 |
GGML_ASSERT(src3);
|
| 2814 |
GGML_ASSERT(src4);
|
| 2815 |
GGML_ASSERT(src5);
|
|
|
|
| 2816 |
|
| 2817 |
size_t offs_src3 = 0;
|
| 2818 |
size_t offs_src4 = 0;
|
| 2819 |
size_t offs_src5 = 0;
|
|
|
|
| 2820 |
|
| 2821 |
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
| 2822 |
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
| 2823 |
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
|
|
|
| 2824 |
|
| 2825 |
-
const int64_t ne30 = src3->ne[0];
|
| 2826 |
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
| 2827 |
|
| 2828 |
-
const uint64_t nb30 = src3->nb[0];
|
| 2829 |
const uint64_t nb31 = src3->nb[1];
|
| 2830 |
|
| 2831 |
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
| 2832 |
-
const int64_t ne41 = src4->ne[1];
|
| 2833 |
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
|
|
|
| 2834 |
|
| 2835 |
-
const uint64_t nb40 = src4->nb[0];
|
| 2836 |
const uint64_t nb41 = src4->nb[1];
|
| 2837 |
const uint64_t nb42 = src4->nb[2];
|
|
|
|
| 2838 |
|
| 2839 |
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
| 2840 |
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
| 2841 |
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
|
|
|
| 2842 |
|
| 2843 |
-
const uint64_t nb50 = src5->nb[0];
|
| 2844 |
const uint64_t nb51 = src5->nb[1];
|
| 2845 |
const uint64_t nb52 = src5->nb[2];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2846 |
|
| 2847 |
const int64_t d_state = ne00;
|
| 2848 |
const int64_t d_inner = ne01;
|
| 2849 |
-
const int64_t
|
| 2850 |
-
const int64_t
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2851 |
|
| 2852 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2853 |
|
| 2854 |
ggml_metal_kargs_ssm_scan args = {
|
| 2855 |
-
/*.d_state
|
| 2856 |
-
/*.d_inner
|
|
|
|
|
|
|
| 2857 |
/*.n_seq_tokens =*/ n_seq_tokens,
|
| 2858 |
-
/*.n_seqs
|
| 2859 |
-
/*.
|
| 2860 |
-
/*.
|
| 2861 |
-
/*.
|
| 2862 |
-
/*.
|
| 2863 |
-
/*.
|
| 2864 |
-
/*.
|
| 2865 |
-
/*.
|
| 2866 |
-
/*.
|
| 2867 |
-
/*.
|
| 2868 |
-
/*.
|
| 2869 |
-
/*.
|
| 2870 |
-
/*.
|
| 2871 |
-
/*.
|
| 2872 |
-
/*.
|
| 2873 |
-
/*.
|
| 2874 |
-
/*.nb50 =*/ nb50,
|
| 2875 |
-
/*.nb51 =*/ nb51,
|
| 2876 |
-
/*.nb52 =*/ nb52,
|
| 2877 |
};
|
| 2878 |
|
| 2879 |
[encoder setComputePipelineState:pipeline];
|
|
@@ -2883,10 +2905,17 @@ static bool ggml_metal_encode_node(
|
|
| 2883 |
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
| 2884 |
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
| 2885 |
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
| 2886 |
-
[encoder setBuffer:
|
| 2887 |
-
[encoder
|
|
|
|
| 2888 |
|
| 2889 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2890 |
} break;
|
| 2891 |
case GGML_OP_RWKV_WKV6:
|
| 2892 |
{
|
|
|
|
| 217 |
GGML_METAL_KERNEL_TYPE_NORM,
|
| 218 |
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
| 219 |
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
| 220 |
+
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
|
| 221 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
| 222 |
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
| 223 |
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
|
|
| 1197 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 1198 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
| 1199 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
| 1200 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
|
| 1201 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
| 1202 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
| 1203 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
|
|
| 2811 |         struct ggml_tensor * src3 = node->src[3];
| 2812 |         struct ggml_tensor * src4 = node->src[4];
| 2813 |         struct ggml_tensor * src5 = node->src[5];
| 2814 | +       struct ggml_tensor * src6 = node->src[6];
| 2815 |
| 2816 |         GGML_ASSERT(src3);
| 2817 |         GGML_ASSERT(src4);
| 2818 |         GGML_ASSERT(src5);
| 2819 | +       GGML_ASSERT(src6);
| 2820 |
| 2821 |         size_t offs_src3 = 0;
| 2822 |         size_t offs_src4 = 0;
| 2823 |         size_t offs_src5 = 0;
| 2824 | +       size_t offs_src6 = 0;
| 2825 |
| 2826 |         id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
| 2827 |         id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
| 2828 |         id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
| 2829 | +       id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
| 2830 |
| 2831 | +       const int64_t ne30 = src3->ne[0];
| 2832 |         const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
| 2833 |
| 2834 | +       const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
| 2835 |         const uint64_t nb31 = src3->nb[1];
| 2836 |
| 2837 |         const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
| 2838 | +       const int64_t ne41 = src4->ne[1];
| 2839 |         const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
| 2840 | +       const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
| 2841 |
| 2842 | +       const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
| 2843 |         const uint64_t nb41 = src4->nb[1];
| 2844 |         const uint64_t nb42 = src4->nb[2];
| 2845 | +       const uint64_t nb43 = src4->nb[3];
| 2846 |
| 2847 |         const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
| 2848 |         const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
| 2849 |         const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
| 2850 | +       const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
| 2851 |
| 2852 | +       const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
| 2853 |         const uint64_t nb51 = src5->nb[1];
| 2854 |         const uint64_t nb52 = src5->nb[2];
| 2855 | +       const uint64_t nb53 = src5->nb[3];
| 2856 | +
| 2857 | +       const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
| 2858 | +
| 2859 | +       const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
| 2860 |
| 2861 |         const int64_t d_state      = ne00;
| 2862 |         const int64_t d_inner      = ne01;
| 2863 | +       const int64_t n_head       = ne02;
| 2864 | +       const int64_t n_group      = ne41;
| 2865 | +       const int64_t n_seq_tokens = ne12;
| 2866 | +       const int64_t n_seqs       = ne13;
| 2867 | +
| 2868 | +       id<MTLComputePipelineState> pipeline = nil;
| 2869 |
| 2870 | +       if (ne30 == 1) {
| 2871 | +           // Mamba-2
| 2872 | +           pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
| 2873 | +       } else {
| 2874 | +           pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
| 2875 | +       }
| 2876 |
| 2877 |         ggml_metal_kargs_ssm_scan args = {
| 2878 | +           /*.d_state      =*/ d_state,
| 2879 | +           /*.d_inner      =*/ d_inner,
| 2880 | +           /*.n_head       =*/ n_head,
| 2881 | +           /*.n_group      =*/ n_group,
| 2882 |             /*.n_seq_tokens =*/ n_seq_tokens,
| 2883 | +           /*.n_seqs       =*/ n_seqs,
| 2884 | +           /*.nb01         =*/ nb01,
| 2885 | +           /*.nb02         =*/ nb02,
| 2886 | +           /*.nb03         =*/ nb03,
| 2887 | +           /*.nb11         =*/ nb11,
| 2888 | +           /*.nb12         =*/ nb12,
| 2889 | +           /*.nb13         =*/ nb13,
| 2890 | +           /*.nb21         =*/ nb21,
| 2891 | +           /*.nb22         =*/ nb22,
| 2892 | +           /*.nb31         =*/ nb31,
| 2893 | +           /*.nb41         =*/ nb41,
| 2894 | +           /*.nb42         =*/ nb42,
| 2895 | +           /*.nb43         =*/ nb43,
| 2896 | +           /*.nb51         =*/ nb51,
| 2897 | +           /*.nb52         =*/ nb52,
| 2898 | +           /*.nb53         =*/ nb53,
| 2899 |         };
| 2900 |
| 2901 |         [encoder setComputePipelineState:pipeline];
| 2905 |         [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
| 2906 |         [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
| 2907 |         [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
| 2908 | +       [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
| 2909 | +       [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
| 2910 | +       [encoder setBytes:&args length:sizeof(args) atIndex:8];
| 2911 |
| 2912 | +       if (ne30 == 1) {
| 2913 | +           // Mamba-2
| 2914 | +           [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
| 2915 | +       } else {
| 2916 | +           GGML_ASSERT(d_inner == 1);
| 2917 | +           [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
| 2918 | +       }
| 2919 |     } break;
| 2920 | case GGML_OP_RWKV_WKV6:
| 2921 |     {
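The encoder above binds the new ids tensor as src6 and picks the pipeline purely from the shape of A: ne30 == 1 means a single decay scalar per head, which is the Mamba-2 layout, so ssm_scan_f32_group is dispatched with one threadgroup per (head-dim element, head, sequence); otherwise the Mamba-1 kernel runs with one threadgroup per (head, sequence) and expects d_inner == 1. A minimal sketch of that selection logic, with illustrative names (grid3d and ssm_scan_grid are not from the patch):

    #include <stdbool.h>
    #include <stdint.h>

    typedef struct { int64_t x, y, z; } grid3d; // stand-in for MTLSize in this sketch

    // Mirrors the discrimination in the Metal encoder above:
    // Mamba-2 keeps one decay scalar per head (A is {1, n_head}),
    // while Mamba-1 keeps a full {d_state, n_head} decay matrix.
    static grid3d ssm_scan_grid(int64_t ne30, int64_t d_inner, int64_t n_head, int64_t n_seqs, bool * is_mamba2) {
        *is_mamba2 = (ne30 == 1);
        if (*is_mamba2) {
            // one threadgroup per (element of the head dimension, head, sequence)
            return (grid3d) { d_inner, n_head, n_seqs };
        }
        // Mamba-1 path: one threadgroup per (head, sequence); d_inner is expected to be 1
        return (grid3d) { n_head, n_seqs, 1 };
    }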
@@ -1596,7 +1596,7 @@ kernel void kernel_ssm_conv_f32(
     x[0] = sumf;
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
 kernel void kernel_ssm_scan_f32(
         device const void * src0,
         device const void * src1,
@@ -1604,46 +1604,119 @@ kernel void kernel_ssm_scan_f32(
         device const void * src3,
         device const void * src4,
         device const void * src5,
+        device const void * src6,
         device       float * dst,
         constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
-    const int64_t
-    const int64_t
+    const int64_t i1 = 0;
+    const int64_t ir = tgpig.x; // current head
+    const int64_t i3 = tgpig.y; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
 
     const int64_t nc  = args.d_state;
-
+    const int64_t nr  = args.d_inner;
+    const int64_t nh  = args.n_head;
+    const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
-
+
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float *
-        device const float *
-        device const float *
-        device const float *
-        device const float *
-        device
-
-
-
-        if (i2 > 0) {
-            s0 = s;
-        }
-
-        // i1 == 0
-        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        float x_dt = x[0] * dt_soft_plus;
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
         float sumf = 0.0f;
 
         for (int64_t i0 = 0; i0 < nc; ++i0) {
-            int64_t i = i0;
-            float state = (s0[i] * exp(dt_soft_plus * A[
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
             sumf += state * C[i0];
             s[i] = state;
         }
 
         y[0] = sumf;
+
+        // recurse
+        s0 = s;
+    }
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// TODO: optimize (e.g. by parallelizing over d_state)
+kernel void kernel_ssm_scan_f32_group(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device const void * src6,
+        device       float * dst,
+        constant ggml_metal_kargs_ssm_scan & args,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i1 = tgpig.x;
+    const int64_t ir = tgpig.y; // current head
+    const int64_t i3 = tgpig.z; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
+
+    const int64_t nc  = args.d_state;
+    const int64_t nr  = args.d_inner;
+    const int64_t nh  = args.n_head;
+    const int64_t ng  = args.n_group;
+    const int64_t n_t = args.n_seq_tokens;
+
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
+        const float dA = exp(dt_soft_plus * A[0]);
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * dA) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+
+        // recurse
+        s0 = s;
     }
 }
 
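Both kernels above (ggml-metal.metal) walk the tokens of one sequence serially; for each token they apply a softplus to dt, decay the per-head state, inject the input through B and read the output out through C. The starting state for a sequence is selected through the ids buffer, and the updated states are written into dst at byte offset s_off = d_inner*n_head*n_seq_tokens*n_seqs*sizeof(float), i.e. right after all of the y values. The Mamba-2 variant differs only in using a single scalar decay A[0] per head. A small CPU reference of the per-token, per-(head, head-dim element) update of the Mamba-2 kernel, as an illustrative sketch rather than code from the patch:

    #include <math.h>

    // One Mamba-2 scan step over a d_state-long state slice, mirroring the
    // inner loop of kernel_ssm_scan_f32_group. Names and layout are simplified.
    static float ssm_scan_step_ref(
            float       * state,   // [d_state] running state for this (head, head-dim element)
            const float * B,       // [d_state] input projection for this token/group
            const float * C,       // [d_state] output projection for this token/group
            float         A,       // scalar decay for this head (Mamba-2: A is {1, n_head})
            float         dt,      // raw time step for this head/token
            float         x,       // input activation for this head-dim element
            int           d_state) {
        // softplus with the same large-input cutoff as the kernels
        const float dt_soft_plus = dt <= 20.0f ? logf(1.0f + expf(dt)) : dt;
        const float x_dt = x * dt_soft_plus;
        const float dA   = expf(dt_soft_plus * A);

        float y = 0.0f;
        for (int i0 = 0; i0 < d_state; ++i0) {
            const float s = state[i0]*dA + B[i0]*x_dt;
            y += s * C[i0];
            state[i0] = s;
        }
        return y;
    }

The Mamba-1 kernel is the same loop with exp(dt_soft_plus * A[i0]) in place of the shared dA.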
@@ -4829,7 +4829,6 @@ struct ggml_tensor * ggml_ssm_conv(
     const int64_t n_s = sx->ne[2];
 
     // TODO: maybe support other strides than 1?
-    // FIXME: this is always true?
     GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
     GGML_ASSERT(sx->ne[1] == d_inner);
     GGML_ASSERT(n_t >= 0);
@@ -4852,36 +4851,49 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor * dt,
         struct ggml_tensor * A,
         struct ggml_tensor * B,
-        struct ggml_tensor * C
+        struct ggml_tensor * C,
+        struct ggml_tensor * ids) {
     GGML_ASSERT(ggml_is_contiguous(s));
-    GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(
-    GGML_ASSERT(ggml_is_3d(B));
-    GGML_ASSERT(ggml_is_3d(s));
+    GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
-    GGML_ASSERT(
+    GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+    GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+    GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     GGML_ASSERT(ggml_are_same_shape(B, C));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     {
         const int64_t d_state = s->ne[0];
-        const int64_t
-        const int64_t
-        const int64_t
-
-
-        GGML_ASSERT(
-        GGML_ASSERT(
-        GGML_ASSERT(
+        const int64_t head_dim = x->ne[0];
+        const int64_t n_head = x->ne[1];
+        const int64_t n_seq_tokens = x->ne[2];
+        const int64_t n_seqs = x->ne[3];
+
+        GGML_ASSERT(dt->ne[0] == n_head);
+        GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+        GGML_ASSERT(dt->ne[2] == n_seqs);
+        GGML_ASSERT(ggml_is_3d(dt));
+        GGML_ASSERT(s->ne[1] == head_dim);
+        GGML_ASSERT(s->ne[2] == n_head);
        GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[
-        GGML_ASSERT(B->ne[
+        GGML_ASSERT(B->ne[2] == n_seq_tokens);
+        GGML_ASSERT(B->ne[3] == n_seqs);
+        GGML_ASSERT(ids->ne[0] == n_seqs);
+        GGML_ASSERT(ggml_is_vector(ids));
+        GGML_ASSERT(A->ne[1] == n_head);
+        GGML_ASSERT(ggml_is_matrix(A));
+
+        if (A->ne[0] != 1) {
+            // Mamba-1 has more granular decay factors
+            GGML_ASSERT(A->ne[0] == d_state);
+        }
     }
 
     // concatenated y + ssm_states
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) +
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
 
     result->op     = GGML_OP_SSM_SCAN;
     result->src[0] = s;
@@ -4890,6 +4902,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
+    result->src[6] = ids;
 
     return result;
 }
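With the updated ggml_ssm_scan signature in ggml/src/ggml.c, C and the new ids tensor become explicit sources (src[5] and src[6]) of the GGML_OP_SSM_SCAN node, and the asserts above pin down the expected shapes. A hedged usage sketch follows; the helper name and dimension values are made up for illustration, and n_rs stands for the size of the recurrent-state pool that ids indexes into:

    #include "ggml.h"

    // Build an SSM_SCAN node with shapes matching the asserts above (Mamba-2 case:
    // A is {1, n_head}; a Mamba-1 graph would use {d_state, n_head} instead).
    static struct ggml_tensor * build_ssm_scan_example(
            struct ggml_context * ctx,
            int64_t d_state, int64_t head_dim, int64_t n_head, int64_t n_group,
            int64_t n_seq_tokens, int64_t n_seqs, int64_t n_rs) {
        struct ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_rs);
        struct ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
        struct ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
        struct ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head);
        struct ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
        struct ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);

        // result = y for every token, followed by the updated states of the selected sequences
        return ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
    }

The returned tensor is flat F32 holding ggml_nelements(x) output values followed by the updated states for the ids->ne[0] selected sequences, which matches the size expression in the hunk above.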