compilade committed on
Commit
1b4087e
1 Parent(s): 05351ac

llama : initial Mamba-2 support (llama/9126)


* llama : initial Mamba-2 support

* ggml : SIMD ggml_ssm_scan for Mamba-2

* ggml : improve ggml_mul speed when masking recurrent states

* llama : support running Mamba-Codestral-7B-v0.1

* llama : fix Mamba-2 conv state saving

* ggml : make the ggml_mul fast broadcast path more consistently formatted

* llama : remove unused variable

* llama : add missing break

* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present

The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires
workarounds to work correctly.

* llama : avoid redundant state copy for Mamba 1 and 2

* metal : attempt to adapt SSM_SCAN for Mamba-2

* metal : fix SSM_SCAN pipeline scope

* metal : use log and exp instead of log1pf and expf in SSM_SCAN

* metal : remove unused arguments for SSM_SCAN

The max index is 31, so trimming the arguments is necessary.

* metal : add back n_seqs to SSM_SCAN args

Whoops, this is needed for the offset in the concatenated output.

* metal : fix SSM_SCAN state head offset

* metal : fix wrong number of tokens per sequence in SSM_SCAN

* ggml : remove unused fast broadcast path in GGML_MUL

This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.

* ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models
to avoid some reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks

* convert : fix flake8 lint

* metal : fix confusion between ; and ,

* metal : add missing args for nb references in ssm_scan_f32_group

* metal : single-user mamba2 inference works

* kv-cache : remove const_cast when setting inputs for s_copy

And also fix multi-user inference for recurrent models
by using cell_id instead of i as the kv cell index
when populating s_copy.

* convert : avoid AutoConfig for Mamba and Mamba2 hparams

* kv-cache : allow context shift for recurrent models

* graph : fix recurrent state copies when avoiding copies

Works, but using lambda functions might not be that clean.

* ggml : fix mamba2 ssm scan when compiled with SVE

* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches

* cuda : implement ssm scan for Mamba2

There is still room for improvement, but it works!

* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2

* mamba : fix mismatched new and delete size for llm_build_mamba

Subclasses of llm_graph_context cannot have extra fields,
because the called destructor is not the one from the subclass.
This would otherwise cause problems when running Mamba-(1|2) inference
when compiled with -DGGML_SANITIZE_ADDRESS=ON (see the sketch after this list).

* cuda : graceful fallback for Mamba-1 models with weird embd size
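
To make the new/delete size mismatch above concrete, here is a minimal standalone sketch of that failure mode, using hypothetical class names (these are not the llama.cpp types): destroying a larger subclass through a base pointer whose destructor is not virtual is what AddressSanitizer flags as a new-delete-type-mismatch.

    // Minimal sketch (hypothetical types): a base class without a virtual
    // destructor, and a subclass that adds extra fields.
    #include <cstdio>

    struct graph_context_base {
        int n_nodes = 0;
        ~graph_context_base() {}          // not virtual
    };

    struct graph_context_mamba : graph_context_base {
        float extra_state[16] = {};       // extra fields enlarge the object
    };

    int main() {
        // allocation uses sizeof(graph_context_mamba) ...
        graph_context_base * ctx = new graph_context_mamba();
        // ... but this delete runs ~graph_context_base() and, with sized
        // deallocation, passes sizeof(graph_context_base); under
        // -fsanitize=address this is typically reported as a
        // "new-delete-type-mismatch" (and it is undefined behaviour anyway).
        delete ctx;
        std::printf("done\n");
        return 0;
    }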

ggml/include/ggml.h CHANGED
@@ -2028,7 +2028,8 @@ extern "C" {
2028
  struct ggml_tensor * dt,
2029
  struct ggml_tensor * A,
2030
  struct ggml_tensor * B,
2031
- struct ggml_tensor * C);
 
2032
 
2033
  // partition into non-overlapping windows with padding if needed
2034
  // example:
 
2028
  struct ggml_tensor * dt,
2029
  struct ggml_tensor * A,
2030
  struct ggml_tensor * B,
2031
+ struct ggml_tensor * C,
2032
+ struct ggml_tensor * ids);
2033
 
2034
  // partition into non-overlapping windows with padding if needed
2035
  // example:
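
The hunk above extends ggml_ssm_scan with an ids tensor that selects which state slot each sequence reads from. As a hedged usage sketch (tensor shapes follow the comments in the CPU implementation further below; the leading context/s/x parameters and the example dimensions are assumptions for illustration, not taken from this commit):

    // Sketch: building an SSM_SCAN node with the new `ids` argument.
    // Dimensions are illustrative only.
    #include "ggml.h"

    static struct ggml_tensor * build_ssm_scan_example(struct ggml_context * ctx) {
        const int64_t d_state = 128, head_dim = 64, n_head = 8;
        const int64_t n_group = 1, n_seq_tokens = 4, n_seqs = 2;

        struct ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
        struct ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
        struct ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
        struct ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head); // {1, n_head} selects the Mamba-2 path
        struct ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
        struct ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
        struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); // per-sequence state slot indices

        // the result concatenates y {head_dim, n_head, n_seq_tokens, n_seqs} with the updated states
        return ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
    }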
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -8337,120 +8337,210 @@ void ggml_compute_forward_ssm_conv(
8337
  static void ggml_compute_forward_ssm_scan_f32(
8338
  const ggml_compute_params * params,
8339
  ggml_tensor * dst) {
8340
- const ggml_tensor * src0 = dst->src[0]; // s
8341
- const ggml_tensor * src1 = dst->src[1]; // x
8342
- const ggml_tensor * src2 = dst->src[2]; // dt
8343
- const ggml_tensor * src3 = dst->src[3]; // A
8344
- const ggml_tensor * src4 = dst->src[4]; // B
8345
- const ggml_tensor * src5 = dst->src[5]; // C
 
8346
 
8347
  const int ith = params->ith;
8348
  const int nth = params->nth;
8349
 
8350
- const int64_t nc = src0->ne[0]; // d_state
8351
- const int64_t nr = src0->ne[1]; // d_inner
8352
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
8353
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
 
 
8354
 
8355
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
 
 
 
8356
  GGML_ASSERT(src0->nb[0] == sizeof(float));
8357
  GGML_ASSERT(src1->nb[0] == sizeof(float));
8358
  GGML_ASSERT(src2->nb[0] == sizeof(float));
8359
  GGML_ASSERT(src3->nb[0] == sizeof(float));
8360
  GGML_ASSERT(src4->nb[0] == sizeof(float));
8361
  GGML_ASSERT(src5->nb[0] == sizeof(float));
8362
- // required for the dot product between s and C
8363
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
8364
- // required for per-sequence offsets for states
8365
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
8366
- // required to get correct offset for state destination (i.e. src1->nb[3])
8367
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
8368
 
8369
- // rows per thread
8370
- const int dr = (nr + nth - 1)/nth;
8371
 
8372
- // row range for this thread
8373
- const int ir0 = dr*ith;
8374
- const int ir1 = MIN(ir0 + dr, nr);
8375
- const int ir = ir1 - ir0;
8376
 
8377
- #ifdef __ARM_FEATURE_SVE
8378
- for (int i3 = 0; i3 < n_s; ++i3) {
8379
- for (int i2 = 0; i2 < n_t; ++i2) {
8380
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
8381
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8382
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
8383
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
8384
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
8385
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
8386
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8387
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
8388
-
8389
- // use the output as the source for the next token-wise iterations
8390
- if (i2 > 0) { s0 = s; }
8391
-
8392
- // d_inner
8393
- for (int i1 = 0; i1 < ir; ++i1) {
8394
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
8395
- float x_dt = x[i1] * dt_soft_plus;
8396
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
8397
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
8398
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
8399
-
8400
- for (int64_t k = 0; k < nc; k += svcntw()) {
8401
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
8402
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
8403
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
8404
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
8405
-
8406
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8407
- t1 = exp_ps_sve(svptrue_b32(), t1);
8408
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
8409
-
8410
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
8411
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8412
-
8413
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
8414
  }
8415
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
8416
  }
8417
- }
8418
- }
8419
- #else
8420
- for (int i3 = 0; i3 < n_s; ++i3) {
8421
- for (int i2 = 0; i2 < n_t; ++i2) {
8422
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
8423
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8424
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
8425
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
8426
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
8427
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
8428
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
8429
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
8430
-
8431
- // use the output as the source for the next token-wise iterations
8432
- if (i2 > 0) { s0 = s; }
8433
-
8434
- // d_inner
8435
- for (int i1 = 0; i1 < ir; ++i1) {
8436
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
8437
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
8438
- float x_dt = x[i1] * dt_soft_plus;
8439
- float sumf = 0.0f;
8440
- // d_state
8441
- for (int i0 = 0; i0 < nc; ++i0) {
8442
- int i = i0 + i1*nc;
8443
- // state = prev_state * dA + dB * x
8444
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
8445
- // y = rowwise_dotprod(state, C)
8446
- sumf += state * C[i0];
8447
- s[i] = state;
8448
  }
8449
- y[i1] = sumf;
8450
  }
8451
  }
 
 
8452
  }
8453
- #endif
8454
  }
8455
 
8456
  void ggml_compute_forward_ssm_scan(
 
8337
  static void ggml_compute_forward_ssm_scan_f32(
8338
  const ggml_compute_params * params,
8339
  ggml_tensor * dst) {
8340
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
8341
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
8342
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
8343
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
8344
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
8345
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
8346
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
8347
 
8348
  const int ith = params->ith;
8349
  const int nth = params->nth;
8350
 
8351
+ const int64_t nc = src0->ne[0]; // d_state
8352
+ const int64_t nr = src0->ne[1]; // dim
8353
+ const int64_t nh = src1->ne[1]; // n_head
8354
+ const int64_t ng = src4->ne[1];
8355
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
8356
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
8357
 
8358
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
8359
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
8360
+
8361
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
8362
  GGML_ASSERT(src0->nb[0] == sizeof(float));
8363
  GGML_ASSERT(src1->nb[0] == sizeof(float));
8364
  GGML_ASSERT(src2->nb[0] == sizeof(float));
8365
  GGML_ASSERT(src3->nb[0] == sizeof(float));
8366
  GGML_ASSERT(src4->nb[0] == sizeof(float));
8367
  GGML_ASSERT(src5->nb[0] == sizeof(float));
8368
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
8369
+ // allows optimizing the modulo since n_group should be a power of 2
8370
+ GGML_ASSERT((ng & -ng) == ng);
 
 
 
8371
 
8372
+ // heads per thread
8373
+ const int dh = (nh + nth - 1)/nth;
8374
 
8375
+ // head range for this thread
8376
+ const int ih0 = dh*ith;
8377
+ const int ih1 = MIN(ih0 + dh, nh);
8378
+
8379
+ const int32_t * ids = (const int32_t *) src6->data;
8380
+
8381
+ for (int i3 = 0; i3 < ns; ++i3) {
8382
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
8383
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
8384
+
8385
+ for (int i2 = 0; i2 < nt; ++i2) {
8386
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
8387
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
8388
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
8389
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
8390
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
8391
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
8392
+
8393
+ if (src3->ne[0] == 1) {
8394
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
8395
+
8396
+ // n_head
8397
+ for (int h = ih0; h < ih1; ++h) {
8398
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8399
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8400
+ const float dA = expf(dt_soft_plus * A[h]);
8401
+
8402
+ // dim
8403
+ for (int i1 = 0; i1 < nr; ++i1) {
8404
+ const int ii = i1 + h*nr;
8405
+ const float x_dt = x[ii] * dt_soft_plus;
8406
+ float sumf = 0.0f;
8407
+ #if defined(GGML_SIMD)
8408
+ #if defined(__ARM_FEATURE_SVE)
8409
+ const int ggml_f32_epr = svcntw();
8410
+ const int ggml_f32_step = 1 * ggml_f32_epr;
8411
+
8412
+ const int np = (nc & ~(ggml_f32_step - 1));
8413
+
8414
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
8415
+
8416
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8417
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8418
+
8419
+ for (int i = 0; i < np; i += ggml_f32_step) {
8420
+ // TODO: maybe unroll more?
8421
+ for (int j = 0; j < 1; j++) {
8422
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
8423
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
8424
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
8425
+
8426
+ t0 = GGML_F32_VEC_MUL(t0, adA);
8427
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
8428
+
8429
+ t0 = GGML_F32_VEC_ADD(t0, t1);
8430
+
8431
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
8432
+
8433
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
8434
+ }
8435
+ }
8436
+
8437
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
8438
+ #else
8439
+ const int np = (nc & ~(GGML_F32_STEP - 1));
8440
+
8441
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8442
+
8443
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8444
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8445
+
8446
+ GGML_F32_VEC ax[GGML_F32_ARR];
8447
+ GGML_F32_VEC ay[GGML_F32_ARR];
8448
+ GGML_F32_VEC az[GGML_F32_ARR];
8449
+
8450
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
8451
+ for (int j = 0; j < GGML_F32_ARR; j++) {
8452
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
8453
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
8454
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
8455
+
8456
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
8457
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
8458
+
8459
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
8460
 
8461
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
8462
+
8463
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
8464
+ }
8465
+ }
8466
+
8467
+ // reduce sum0..sum3 to sum0
8468
+ GGML_F32_VEC_REDUCE(sumf, sum);
8469
+ #endif
8470
+ #else
8471
+ const int np = 0;
8472
+ #endif
8473
+ // d_state
8474
+ for (int i0 = np; i0 < nc; ++i0) {
8475
+ const int i = i0 + ii*nc;
8476
+ const int ig = i0 + (h & (ng - 1))*nc;
8477
+ // state = prev_state * dA + dB * x
8478
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
8479
+ // y = rowwise_dotprod(state, C)
8480
+ sumf += state * C[ig];
8481
+ s[i] = state;
8482
+ }
8483
+ y[ii] = sumf;
8484
  }
 
8485
  }
8486
+ } else {
8487
+ // Mamba-1 has an element-wise decay factor for the states
8488
+
8489
+ // n_head
8490
+ for (int h = ih0; h < ih1; ++h) {
8491
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8492
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
8493
+
8494
+ // dim
8495
+ for (int i1 = 0; i1 < nr; ++i1) {
8496
+ const int ii = i1 + h*nr;
8497
+ const float x_dt = x[ii] * dt_soft_plus;
8498
+ #if defined(__ARM_FEATURE_SVE)
8499
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
8500
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
8501
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
8502
+
8503
+ // d_state
8504
+ // TODO: what happens when (d_state % svcntw()) != 0?
8505
+ for (int64_t k = 0; k < nc; k += svcntw()) {
8506
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
8507
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
8508
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
8509
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
8510
+
8511
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8512
+ t1 = exp_ps_sve(svptrue_b32(), t1);
8513
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
8514
+
8515
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
8516
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8517
+
8518
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
8519
+ }
8520
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
8521
+ #else
8522
+ float sumf = 0.0f;
8523
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
8524
+ // and also because expf is used within the loop.
8525
+ // d_state
8526
+ for (int i0 = 0; i0 < nc; ++i0) {
8527
+ const int i = i0 + ii*nc;
8528
+ const int ig = i0 + (h & (ng - 1))*nc;
8529
+ // state = prev_state * dA + dB * x
8530
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
8531
+ // y = rowwise_dotprod(state, C)
8532
+ sumf += state * C[ig];
8533
+ s[i] = state;
8534
+ }
8535
+ y[ii] = sumf;
8536
+ #endif
8537
  }
 
8538
  }
8539
  }
8540
+ // use the output as the source when it's not the first token-wise iteration
8541
+ s0 = s;
8542
  }
8543
+ }
8544
  }
8545
 
8546
  void ggml_compute_forward_ssm_scan(
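
For reference, the recurrence implemented by the rewritten scan above is just a compact restatement of the loop bodies, per head h, inner index d and state index n, with g(h) = h mod n_group:

    \delta_h   = \mathrm{softplus}(dt_h)   \quad (\text{taken as the identity for } dt_h > 20)
    s'_{h,d,n} = s_{h,d,n}\, e^{\delta_h A_{h,n}} + B_{g(h),n}\, \delta_h x_{h,d}
    y_{h,d}    = \sum_n s'_{h,d,n}\, C_{g(h),n}

Mamba-2 uses a single scalar A_h per head, so the decay factor e^{\delta_h A_h} is hoisted out of the state loop as dA, while Mamba-1 keeps a per-state A_{h,n} inside the exponential.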
ggml/src/ggml-cpu/simd-mappings.h CHANGED
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
189
  #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
190
  #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
191
  #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
192
- #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
193
  #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
194
  #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
195
  #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
 
189
  #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
190
  #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
191
  #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
192
+ #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
193
  #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
194
  #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
195
  #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
ggml/src/ggml-cpu/vec.cpp CHANGED
@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
37
  for (int i = 0; i < np; i += ggml_f32_step) {
38
  ax1 = GGML_F32_VEC_LOAD(x + i);
39
  ay1 = GGML_F32_VEC_LOAD(y + i);
40
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
41
 
42
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
43
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
44
- sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
45
 
46
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
47
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
48
- sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
49
 
50
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
51
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
52
- sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
53
 
54
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
55
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
56
- sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
57
 
58
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
59
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
60
- sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
61
 
62
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
63
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
64
- sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
65
 
66
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
67
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
68
- sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
69
  }
70
  // leftovers
71
  // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
73
  for (int i = np; i < np2; i += ggml_f32_epr) {
74
  ax1 = GGML_F32_VEC_LOAD(x + i);
75
  ay1 = GGML_F32_VEC_LOAD(y + i);
76
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
77
  }
78
  // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
79
  if (np2 < n) {
 
37
  for (int i = 0; i < np; i += ggml_f32_step) {
38
  ax1 = GGML_F32_VEC_LOAD(x + i);
39
  ay1 = GGML_F32_VEC_LOAD(y + i);
40
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
41
 
42
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
43
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
44
+ sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
45
 
46
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
47
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
48
+ sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
49
 
50
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
51
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
52
+ sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
53
 
54
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
55
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
56
+ sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
57
 
58
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
59
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
60
+ sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
61
 
62
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
63
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
64
+ sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
65
 
66
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
67
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
68
+ sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
69
  }
70
  // leftovers
71
  // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
 
73
  for (int i = np; i < np2; i += ggml_f32_epr) {
74
  ax1 = GGML_F32_VEC_LOAD(x + i);
75
  ay1 = GGML_F32_VEC_LOAD(y + i);
76
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
77
  }
78
  // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
79
  if (np2 < n) {
ggml/src/ggml-cpu/vec.h CHANGED
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
163
 
164
  ax1 = GGML_F32_VEC_LOAD(x + i);
165
  ay1 = GGML_F32_VEC_LOAD(y + i);
166
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
167
 
168
  GGML_F32_VEC_STORE(y + i, ay1);
169
 
170
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
171
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
172
- ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
173
 
174
  GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
175
 
176
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
177
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
178
- ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
179
 
180
  GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
181
 
182
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
183
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
184
- ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
185
 
186
  GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
187
 
188
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
189
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
190
- ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
191
 
192
  GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
193
 
194
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
195
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
196
- ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
197
 
198
  GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
199
 
200
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
201
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
202
- ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
203
 
204
  GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
205
 
206
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
207
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
208
- ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
209
 
210
  GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
211
  }
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
215
  for (int i = np; i < np2; i += ggml_f32_epr) {
216
  ax1 = GGML_F32_VEC_LOAD(x + i);
217
  ay1 = GGML_F32_VEC_LOAD(y + i);
218
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
219
 
220
  GGML_F32_VEC_STORE(y + i, ay1);
221
  }
 
163
 
164
  ax1 = GGML_F32_VEC_LOAD(x + i);
165
  ay1 = GGML_F32_VEC_LOAD(y + i);
166
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
167
 
168
  GGML_F32_VEC_STORE(y + i, ay1);
169
 
170
  ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
171
  ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
172
+ ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
173
 
174
  GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
175
 
176
  ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
177
  ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
178
+ ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
179
 
180
  GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
181
 
182
  ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
183
  ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
184
+ ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
185
 
186
  GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
187
 
188
  ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
189
  ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
190
+ ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
191
 
192
  GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
193
 
194
  ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
195
  ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
196
+ ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
197
 
198
  GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
199
 
200
  ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
201
  ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
202
+ ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
203
 
204
  GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
205
 
206
  ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
207
  ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
208
+ ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
209
 
210
  GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
211
  }
 
215
  for (int i = np; i < np2; i += ggml_f32_epr) {
216
  ax1 = GGML_F32_VEC_LOAD(x + i);
217
  ay1 = GGML_F32_VEC_LOAD(y + i);
218
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
219
 
220
  GGML_F32_VEC_STORE(y + i, ay1);
221
  }
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3321,9 +3321,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3321
  case GGML_OP_COS:
3322
  case GGML_OP_CLAMP:
3323
  case GGML_OP_LOG:
3324
- case GGML_OP_SSM_SCAN:
3325
- case GGML_OP_SSM_CONV:
3326
  return true;
3327
  case GGML_OP_CONT:
3328
  return op->src[0]->type != GGML_TYPE_BF16;
3329
  case GGML_OP_DIAG_MASK_INF:
 
3321
  case GGML_OP_COS:
3322
  case GGML_OP_CLAMP:
3323
  case GGML_OP_LOG:
 
 
3324
  return true;
3325
+ case GGML_OP_SSM_SCAN: {
3326
+ if (op->src[3]->ne[0] == 1) {
3327
+ // Mamba2
3328
+ // (kernel only supports d_state == 128 && d_head % 16 == 0)
3329
+ return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
3330
+ } else {
3331
+ // Mamba
3332
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
3333
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
3334
+ }
3335
+ }
3336
+ case GGML_OP_SSM_CONV: {
3337
+ // assumes d_inner % threads == 0
3338
+ return op->src[0]->ne[1] % 128 == 0;
3339
+ }
3340
  case GGML_OP_CONT:
3341
  return op->src[0]->type != GGML_TYPE_BF16;
3342
  case GGML_OP_DIAG_MASK_INF:
ggml/src/ggml-cuda/ssm-scan.cu CHANGED
@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
4
  __global__ void __launch_bounds__(splitD, 2)
5
  ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
6
  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
7
- const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
8
- const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
9
- const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
10
- float * __restrict__ dst, const int64_t L) {
11
- GGML_UNUSED(src1_nb0);
12
- GGML_UNUSED(src2_nb0);
13
 
14
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
15
- const int bidx = blockIdx.x; // split along B
16
- const int bidy = blockIdx.y; // split along D
17
  const int tid = threadIdx.x;
18
  const int wid = tid / 32;
19
  const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
24
  float * smem_A = smem;
25
  float * smem_s0 = smem_A + splitD * stride_sA;
26
 
27
- const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
28
- const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
29
  const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
30
  const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
31
- const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
32
- const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
33
- float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
34
- float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
35
 
36
- const int stride_s0 = src0_nb1 / sizeof(float);
37
- const int stride_x = src1_nb1 / sizeof(float);
38
  const int stride_dt = src2_nb1 / sizeof(float);
39
  const int stride_A = src3_nb1 / sizeof(float);
40
- const int stride_B = src4_nb1 / sizeof(float);
41
- const int stride_C = src5_nb1 / sizeof(float);
42
  const int stride_s = stride_s0;
43
- const int stride_y = stride_x;
44
 
45
  // can N not be 16? for example 32?
46
  if (N == 16) {
@@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2)
84
  }
85
  }
86
 
87
  static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
88
- const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
89
- const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
90
- const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
91
- const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
92
- float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
93
  cudaStream_t stream) {
94
  const int threads = 128;
95
- // todo: consider D cannot be divided,does this situation exist?
96
- GGML_ASSERT(D % threads == 0);
97
- const dim3 blocks(B, (D + threads - 1) / threads, 1);
98
- const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
99
- if (N == 16) {
100
- ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
101
- src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
102
- src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
103
  } else {
104
- GGML_ABORT("doesn't support N!=16.");
105
  }
106
  }
107
 
@@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
112
  const struct ggml_tensor * src3 = dst->src[3]; // A
113
  const struct ggml_tensor * src4 = dst->src[4]; // B
114
  const struct ggml_tensor * src5 = dst->src[5]; // C
115
-
116
- // const int64_t d_state = src0->ne[0];
117
- // const int64_t d_inner = src0->ne[1];
118
- // const int64_t l = src1->ne[1];
119
- // const int64_t b = src0->ne[2];
120
 
121
  const int64_t nc = src0->ne[0]; // d_state
122
- const int64_t nr = src0->ne[1]; // d_inner
123
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
124
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
 
 
 
 
125
 
126
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
127
  GGML_ASSERT(src0->nb[0] == sizeof(float));
128
  GGML_ASSERT(src1->nb[0] == sizeof(float));
129
  GGML_ASSERT(src2->nb[0] == sizeof(float));
130
  GGML_ASSERT(src3->nb[0] == sizeof(float));
131
  GGML_ASSERT(src4->nb[0] == sizeof(float));
132
  GGML_ASSERT(src5->nb[0] == sizeof(float));
133
- // required for the dot product between s and C
134
- GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
135
- // required for per-sequence offsets for states
136
- GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
137
- // required to get correct offset for state destination (i.e. src1->nb[3])
138
- GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
139
 
140
  const float * src0_d = (const float *) src0->data;
141
  const float * src1_d = (const float *) src1->data;
@@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
143
  const float * src3_d = (const float *) src3->data;
144
  const float * src4_d = (const float *) src4->data;
145
  const float * src5_d = (const float *) src5->data;
 
146
  float * dst_d = (float *) dst->data;
147
  cudaStream_t stream = ctx.stream();
148
 
149
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
150
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
151
 
152
- ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
153
- src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
154
- src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
 
155
  }
 
4
  __global__ void __launch_bounds__(splitD, 2)
5
  ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
6
  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
7
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
8
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
9
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
10
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
11
+ const int64_t s_off, const int64_t d_inner, const int64_t L) {
 
12
 
13
  constexpr int warp_size = ggml_cuda_get_physical_warp_size();
14
+ const int bidx = blockIdx.x; // split along B (sequences)
15
+ const int bidy = blockIdx.y; // split along D (d_inner)
16
  const int tid = threadIdx.x;
17
  const int wid = tid / 32;
18
  const int wtid = tid % 32;
 
23
  float * smem_A = smem;
24
  float * smem_s0 = smem_A + splitD * stride_sA;
25
 
26
+ const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
27
+ const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
28
  const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
29
  const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
30
+ const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
31
+ const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
32
+ float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
33
+ float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
34
 
35
+ const int stride_s0 = src0_nb2 / sizeof(float);
36
+ const int stride_x = src1_nb2 / sizeof(float);
37
  const int stride_dt = src2_nb1 / sizeof(float);
38
  const int stride_A = src3_nb1 / sizeof(float);
39
+ const int stride_B = src4_nb2 / sizeof(float);
40
+ const int stride_C = src5_nb2 / sizeof(float);
41
  const int stride_s = stride_s0;
42
+ const int stride_y = d_inner;
43
 
44
  // can N not be 16? for example 32?
45
  if (N == 16) {
 
83
  }
84
  }
85
 
86
+ // assumes as many threads as d_state
87
+ template <int splitH, int d_state>
88
+ __global__ void __launch_bounds__(d_state, 1)
89
+ ssm_scan_f32_group(
90
+ const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
91
+ const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
92
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
93
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
94
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
95
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
96
+ const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
97
+
98
+ const int head_idx = (blockIdx.x * splitH) / d_head;
99
+ const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
100
+ const int seq_idx = blockIdx.y;
101
+
102
+ const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
103
+
104
+ const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
105
+ const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
106
+ const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
107
+ const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
108
+ const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
109
+ const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
110
+ float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
111
+ float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
112
+
113
+ // strides across n_seq_tokens
114
+ const int stride_x = src1_nb2 / sizeof(float);
115
+ const int stride_dt = src2_nb1 / sizeof(float);
116
+ const int stride_B = src4_nb2 / sizeof(float);
117
+ const int stride_C = src5_nb2 / sizeof(float);
118
+ const int stride_y = n_head * d_head;
119
+
120
+ float state[splitH];
121
+ // for the parallel accumulation
122
+ __shared__ float stateC[splitH * d_state];
123
+
124
+ #pragma unroll
125
+ for (int j = 0; j < splitH; j++) {
126
+ state[j] = s0_block[j * d_state + threadIdx.x];
127
+ }
128
+
129
+ for (int64_t i = 0; i < n_tok; i++) {
130
+ // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
131
+ // TODO: only calculate B and C once per head group
132
+ // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
133
+ float dt_soft_plus = dt_block[i * stride_dt];
134
+ if (dt_soft_plus <= 20.0f) {
135
+ dt_soft_plus = log1pf(expf(dt_soft_plus));
136
+ }
137
+ const float dA = expf(dt_soft_plus * A_block[0]);
138
+ const float B = B_block[i * stride_B + threadIdx.x];
139
+ const float C = C_block[i * stride_C + threadIdx.x];
140
+
141
+ // across d_head
142
+ #pragma unroll
143
+ for (int j = 0; j < splitH; j++) {
144
+ const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
145
+
146
+ state[j] = (state[j] * dA) + (B * x_dt);
147
+
148
+ stateC[j * d_state + threadIdx.x] = state[j] * C;
149
+ }
150
+
151
+ __syncthreads();
152
+
153
+ // parallel accumulation for stateC
154
+ // TODO: simplify
155
+ {
156
+ static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
157
+ static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
158
+
159
+ // reduce until w matches the warp size
160
+ // TODO: does this work even when the physical warp size is 64?
161
+ #pragma unroll
162
+ for (int w = d_state; w > WARP_SIZE; w >>= 1) {
163
+ // (assuming there are d_state threads)
164
+ #pragma unroll
165
+ for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
166
+ // TODO: check for bank conflicts
167
+ const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
168
+ stateC[k] += stateC[k + (w >> 1)];
169
+
170
+ }
171
+ __syncthreads();
172
+ }
173
+
174
+ static_assert(splitH >= d_state / WARP_SIZE);
175
+
176
+ #pragma unroll
177
+ for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
178
+ float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
179
+ y = warp_reduce_sum(y);
180
+
181
+ // store the above accumulations
182
+ if (threadIdx.x % WARP_SIZE == 0) {
183
+ const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
184
+ y_block[i * stride_y + k] = y;
185
+ }
186
+ }
187
+ }
188
+ }
189
+
190
+ // write back the state
191
+ #pragma unroll
192
+ for (int j = 0; j < splitH; j++) {
193
+ s_block[j * d_state + threadIdx.x] = state[j];
194
+ }
195
+ }
196
+
197
  static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
198
+ const float * src4, const float * src5, const int32_t * src6, float * dst,
199
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
200
+ const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
201
+ const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202
+ const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203
  cudaStream_t stream) {
204
  const int threads = 128;
205
+ // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
206
+ if (src3_nb1 == sizeof(float)) {
207
+ // Mamba-2
208
+ if (d_state == 128) {
209
+ GGML_ASSERT(d_state % threads == 0);
210
+ // NOTE: can be any power of two between 4 and 64
211
+ const int splitH = 16;
212
+ GGML_ASSERT(head_dim % splitH == 0);
213
+ const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
214
+ ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
215
+ src0, src1, src2, src3, src4, src5, src6, dst,
216
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
217
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218
+ } else {
219
+ GGML_ABORT("doesn't support d_state!=128.");
220
+ }
221
  } else {
222
+ // Mamba-1
223
+ GGML_ASSERT(n_head % threads == 0);
224
+ GGML_ASSERT(head_dim == 1);
225
+ GGML_ASSERT(n_group == 1);
226
+ const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
227
+ const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
228
+ if (d_state == 16) {
229
+ ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
230
+ src0, src1, src2, src3, src4, src5, src6, dst,
231
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
232
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
233
+ } else {
234
+ GGML_ABORT("doesn't support d_state!=16.");
235
+ }
236
  }
237
  }
238
 
 
243
  const struct ggml_tensor * src3 = dst->src[3]; // A
244
  const struct ggml_tensor * src4 = dst->src[4]; // B
245
  const struct ggml_tensor * src5 = dst->src[5]; // C
246
+ const struct ggml_tensor * src6 = dst->src[6]; // ids
 
 
 
 
247
 
248
  const int64_t nc = src0->ne[0]; // d_state
249
+ const int64_t nr = src0->ne[1]; // head_dim or 1
250
+ const int64_t nh = src1->ne[1]; // n_head
251
+ const int64_t ng = src4->ne[1]; // n_group
252
+ const int64_t n_t = src1->ne[2]; // number of tokens per sequence
253
+ const int64_t n_s = src1->ne[3]; // number of sequences in the batch
254
+
255
+ const int64_t s_off = ggml_nelements(src1) * sizeof(float);
256
 
257
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
258
  GGML_ASSERT(src0->nb[0] == sizeof(float));
259
  GGML_ASSERT(src1->nb[0] == sizeof(float));
260
  GGML_ASSERT(src2->nb[0] == sizeof(float));
261
  GGML_ASSERT(src3->nb[0] == sizeof(float));
262
  GGML_ASSERT(src4->nb[0] == sizeof(float));
263
  GGML_ASSERT(src5->nb[0] == sizeof(float));
264
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
 
 
 
 
 
265
 
266
  const float * src0_d = (const float *) src0->data;
267
  const float * src1_d = (const float *) src1->data;
 
269
  const float * src3_d = (const float *) src3->data;
270
  const float * src4_d = (const float *) src4->data;
271
  const float * src5_d = (const float *) src5->data;
272
+ const int32_t * src6_d = (const int32_t *) src6->data;
273
  float * dst_d = (float *) dst->data;
274
  cudaStream_t stream = ctx.stream();
275
 
276
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
277
+ GGML_ASSERT(src6->type == GGML_TYPE_I32);
278
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
279
 
280
+ ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
281
+ src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
282
+ src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
283
+ s_off, nc, nr, nh, ng, n_t, n_s, stream);
284
  }
ggml/src/ggml-metal/ggml-metal-impl.h CHANGED
@@ -513,26 +513,25 @@ typedef struct {
513
  typedef struct {
514
  int64_t d_state;
515
  int64_t d_inner;
 
 
516
  int64_t n_seq_tokens;
517
  int64_t n_seqs;
518
- uint64_t nb00;
519
  uint64_t nb01;
520
  uint64_t nb02;
521
- uint64_t nb10;
522
  uint64_t nb11;
523
  uint64_t nb12;
524
  uint64_t nb13;
525
- uint64_t nb20;
526
  uint64_t nb21;
527
  uint64_t nb22;
528
- uint64_t nb30;
529
  uint64_t nb31;
530
- uint64_t nb40;
531
  uint64_t nb41;
532
  uint64_t nb42;
533
- uint64_t nb50;
534
  uint64_t nb51;
535
  uint64_t nb52;
 
536
  } ggml_metal_kargs_ssm_scan;
537
 
538
  typedef struct {
 
513
  typedef struct {
514
  int64_t d_state;
515
  int64_t d_inner;
516
+ int64_t n_head;
517
+ int64_t n_group;
518
  int64_t n_seq_tokens;
519
  int64_t n_seqs;
 
520
  uint64_t nb01;
521
  uint64_t nb02;
522
+ uint64_t nb03;
523
  uint64_t nb11;
524
  uint64_t nb12;
525
  uint64_t nb13;
 
526
  uint64_t nb21;
527
  uint64_t nb22;
 
528
  uint64_t nb31;
 
529
  uint64_t nb41;
530
  uint64_t nb42;
531
+ uint64_t nb43;
532
  uint64_t nb51;
533
  uint64_t nb52;
534
+ uint64_t nb53;
535
  } ggml_metal_kargs_ssm_scan;
536
 
537
  typedef struct {
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
217
  GGML_METAL_KERNEL_TYPE_NORM,
218
  GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
219
  GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
 
220
  GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
221
  GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
222
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -1196,6 +1197,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1196
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
1197
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
1198
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
 
1199
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1201
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -2809,71 +2811,91 @@ static bool ggml_metal_encode_node(
2809
  struct ggml_tensor * src3 = node->src[3];
2810
  struct ggml_tensor * src4 = node->src[4];
2811
  struct ggml_tensor * src5 = node->src[5];
 
2812
 
2813
  GGML_ASSERT(src3);
2814
  GGML_ASSERT(src4);
2815
  GGML_ASSERT(src5);
 
2816
 
2817
  size_t offs_src3 = 0;
2818
  size_t offs_src4 = 0;
2819
  size_t offs_src5 = 0;
 
2820
 
2821
  id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
2822
  id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
2823
  id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
 
2824
 
2825
- const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
2826
  const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
2827
 
2828
- const uint64_t nb30 = src3->nb[0];
2829
  const uint64_t nb31 = src3->nb[1];
2830
 
2831
  const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
2832
- const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
2833
  const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
 
2834
 
2835
- const uint64_t nb40 = src4->nb[0];
2836
  const uint64_t nb41 = src4->nb[1];
2837
  const uint64_t nb42 = src4->nb[2];
 
2838
 
2839
  const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
2840
  const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
2841
  const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
 
2842
 
2843
- const uint64_t nb50 = src5->nb[0];
2844
  const uint64_t nb51 = src5->nb[1];
2845
  const uint64_t nb52 = src5->nb[2];
 
 
 
 
 
2846
 
2847
  const int64_t d_state = ne00;
2848
  const int64_t d_inner = ne01;
2849
- const int64_t n_seq_tokens = ne11;
2850
- const int64_t n_seqs = ne02;
 
 
 
 
2851
 
2852
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
 
 
 
 
 
2853
 
2854
  ggml_metal_kargs_ssm_scan args = {
2855
- /*.d_state =*/ d_state,
2856
- /*.d_inner =*/ d_inner,
 
 
2857
  /*.n_seq_tokens =*/ n_seq_tokens,
2858
- /*.n_seqs =*/ n_seqs,
2859
- /*.nb00 =*/ nb00,
2860
- /*.nb01 =*/ nb01,
2861
- /*.nb02 =*/ nb02,
2862
- /*.nb10 =*/ nb10,
2863
- /*.nb11 =*/ nb11,
2864
- /*.nb12 =*/ nb12,
2865
- /*.nb13 =*/ nb13,
2866
- /*.nb20 =*/ nb20,
2867
- /*.nb21 =*/ nb21,
2868
- /*.nb22 =*/ nb22,
2869
- /*.nb30 =*/ nb30,
2870
- /*.nb31 =*/ nb31,
2871
- /*.nb40 =*/ nb40,
2872
- /*.nb41 =*/ nb41,
2873
- /*.nb42 =*/ nb42,
2874
- /*.nb50 =*/ nb50,
2875
- /*.nb51 =*/ nb51,
2876
- /*.nb52 =*/ nb52,
2877
  };
2878
 
2879
  [encoder setComputePipelineState:pipeline];
@@ -2883,10 +2905,17 @@ static bool ggml_metal_encode_node(
2883
  [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2884
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2885
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2886
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2887
- [encoder setBytes:&args length:sizeof(args) atIndex:7];
 
2888
 
2889
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 
 
 
 
 
 
2890
  } break;
2891
  case GGML_OP_RWKV_WKV6:
2892
  {
 
217
  GGML_METAL_KERNEL_TYPE_NORM,
218
  GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
219
  GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
220
+ GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
221
  GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
222
  GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
223
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
 
1197
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
1198
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
1199
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
1200
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
1201
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1202
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
 
  struct ggml_tensor * src3 = node->src[3];
  struct ggml_tensor * src4 = node->src[4];
  struct ggml_tensor * src5 = node->src[5];
+ struct ggml_tensor * src6 = node->src[6];
 
  GGML_ASSERT(src3);
  GGML_ASSERT(src4);
  GGML_ASSERT(src5);
+ GGML_ASSERT(src6);
 
  size_t offs_src3 = 0;
  size_t offs_src4 = 0;
  size_t offs_src5 = 0;
+ size_t offs_src6 = 0;
 
  id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
  id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
  id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+ id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
 
+ const int64_t ne30 = src3->ne[0];
  const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
 
+ const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
  const uint64_t nb31 = src3->nb[1];
 
  const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
+ const int64_t ne41 = src4->ne[1];
  const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+ const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
 
+ const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
  const uint64_t nb41 = src4->nb[1];
  const uint64_t nb42 = src4->nb[2];
+ const uint64_t nb43 = src4->nb[3];
 
  const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
  const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
  const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+ const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
 
+ const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
  const uint64_t nb51 = src5->nb[1];
  const uint64_t nb52 = src5->nb[2];
+ const uint64_t nb53 = src5->nb[3];
+
+ const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+ const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
 
  const int64_t d_state = ne00;
  const int64_t d_inner = ne01;
+ const int64_t n_head = ne02;
+ const int64_t n_group = ne41;
+ const int64_t n_seq_tokens = ne12;
+ const int64_t n_seqs = ne13;
+
+ id<MTLComputePipelineState> pipeline = nil;
 
+ if (ne30 == 1) {
+ // Mamba-2
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+ }
 
  ggml_metal_kargs_ssm_scan args = {
+ /*.d_state =*/ d_state,
+ /*.d_inner =*/ d_inner,
+ /*.n_head =*/ n_head,
+ /*.n_group =*/ n_group,
  /*.n_seq_tokens =*/ n_seq_tokens,
+ /*.n_seqs =*/ n_seqs,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb31 =*/ nb31,
+ /*.nb41 =*/ nb41,
+ /*.nb42 =*/ nb42,
+ /*.nb43 =*/ nb43,
+ /*.nb51 =*/ nb51,
+ /*.nb52 =*/ nb52,
+ /*.nb53 =*/ nb53,
  };
 
  [encoder setComputePipelineState:pipeline];
 
  [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+ [encoder setBytes:&args length:sizeof(args) atIndex:8];
 
+ if (ne30 == 1) {
+ // Mamba-2
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } else {
+ GGML_ASSERT(d_inner == 1);
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
  } break;
  case GGML_OP_RWKV_WKV6:
  {
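Editor's note: the ne30 == 1 check above keys off the first dimension of A (src3): Mamba-2 stores a single decay scalar per head, while Mamba-1 keeps one decay per state, so the two paths pick different kernels and different dispatch shapes. A minimal host-side sketch of that selection logic, in plain C rather than the actual ggml-metal.m code (grid3d and ssm_scan_grid are made-up names for illustration):

#include <stdint.h>

typedef struct { int64_t x, y, z; } grid3d;

// One single-thread threadgroup per output element, mirroring the dispatch above.
static grid3d ssm_scan_grid(int64_t ne30, int64_t d_inner, int64_t n_head, int64_t n_seqs) {
    if (ne30 == 1) {
        // Mamba-2: grid over (element within head, head, sequence)
        return (grid3d){ d_inner, n_head, n_seqs };
    }
    // Mamba-1: the head dimension is 1, so d_inner must be 1 and the grid is (head, sequence)
    return (grid3d){ n_head, n_seqs, 1 };
}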
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -1596,7 +1596,7 @@ kernel void kernel_ssm_conv_f32(
  x[0] = sumf;
  }
 
- // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
  kernel void kernel_ssm_scan_f32(
  device const void * src0,
  device const void * src1,
@@ -1604,46 +1604,119 @@ kernel void kernel_ssm_scan_f32(
  device const void * src3,
  device const void * src4,
  device const void * src5,
+ device const void * src6,
  device float * dst,
  constant ggml_metal_kargs_ssm_scan & args,
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint3 tpitg[[thread_position_in_threadgroup]],
  uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t ir = tgpig.x;
- const int64_t i3 = tgpig.y;
+ const int64_t i1 = 0;
+ const int64_t ir = tgpig.x; // current head
+ const int64_t i3 = tgpig.y; // current seq
+
+ const uint64_t nb00 = sizeof(float);
+ const uint64_t nb10 = sizeof(float);
+ const uint64_t nb20 = sizeof(float);
 
  const int64_t nc = args.d_state;
- // const int64_t nr = args.d_inner;
+ const int64_t nr = args.d_inner;
+ const int64_t nh = args.n_head;
+ const int64_t ng = args.n_group;
  const int64_t n_t = args.n_seq_tokens;
- // const int64_t n_s = args.n_seqs;
+
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+ device const int32_t * ids = (device const int32_t *) src6;
+
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
 
  for (int64_t i2 = 0; i2 < n_t; ++i2) {
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
- device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
- device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
- device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
- device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
- device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
- device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
-
- if (i2 > 0) {
- s0 = s;
- }
-
- // i1 == 0
- float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
- float x_dt = x[0] * dt_soft_plus;
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+ const float x_dt = x[0] * dt_soft_plus;
  float sumf = 0.0f;
 
  for (int64_t i0 = 0; i0 < nc; ++i0) {
- int64_t i = i0;
- float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ const int64_t i = i0 + i1*nc;
+ const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
  sumf += state * C[i0];
  s[i] = state;
  }
 
  y[0] = sumf;
+
+ // recurse
+ s0 = s;
+ }
+ }
+
+ // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+ // TODO: optimize (e.g. by parallelizing over d_state)
+ kernel void kernel_ssm_scan_f32_group(
+ device const void * src0,
+ device const void * src1,
+ device const void * src2,
+ device const void * src3,
+ device const void * src4,
+ device const void * src5,
+ device const void * src6,
+ device float * dst,
+ constant ggml_metal_kargs_ssm_scan & args,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i1 = tgpig.x;
+ const int64_t ir = tgpig.y; // current head
+ const int64_t i3 = tgpig.z; // current seq
+
+ const uint64_t nb00 = sizeof(float);
+ const uint64_t nb10 = sizeof(float);
+ const uint64_t nb20 = sizeof(float);
+
+ const int64_t nc = args.d_state;
+ const int64_t nr = args.d_inner;
+ const int64_t nh = args.n_head;
+ const int64_t ng = args.n_group;
+ const int64_t n_t = args.n_seq_tokens;
+
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+ device const int32_t * ids = (device const int32_t *) src6;
+
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+
+ for (int64_t i2 = 0; i2 < n_t; ++i2) {
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+ const float x_dt = x[0] * dt_soft_plus;
+ const float dA = exp(dt_soft_plus * A[0]);
+ float sumf = 0.0f;
+
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
+ const int64_t i = i0 + i1*nc;
+ const float state = (s0[i] * dA) + (B[i0] * x_dt);
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+
+ y[0] = sumf;
+
+ // recurse
+ s0 = s;
  }
  }
 
 
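Editor's note: both kernels write into a single concatenated destination buffer: the y outputs come first (one float per element of x), and the updated states follow at byte offset s_off = nr*nh*n_t*n_seqs*sizeof(float). Per token they perform the same state-space update; the only difference is that Mamba-2 uses one per-head decay A[0], so exp(dt_soft_plus * A[0]) is hoisted out of the inner loop as dA. A plain-C reference sketch of the per-token step, as an illustration rather than the ggml CPU implementation (ssm_scan_step is a made-up helper):

#include <math.h>
#include <stdbool.h>
#include <stdint.h>

// Computes one token for one (head element, head, sequence) triple:
//   dt' = softplus(dt)  (for dt > 20, softplus(dt) ~= dt, so the exp is skipped)
//   s[i] = s[i]*exp(dt'*A_i) + B[i]*(x*dt')
//   y    = sum_i s[i]*C[i]
// Mamba-1 has one A entry per state; Mamba-2 has a single per-head scalar A[0].
static float ssm_scan_step(float * s, const float * A, const float * B, const float * C,
                           float x, float dt, int64_t d_state, bool mamba2) {
    const float dt_soft_plus = dt <= 20.0f ? log1pf(expf(dt)) : dt;
    const float x_dt = x * dt_soft_plus;
    const float dA = mamba2 ? expf(dt_soft_plus * A[0]) : 0.0f; // hoisted decay (Mamba-2 only)
    float y = 0.0f;
    for (int64_t i = 0; i < d_state; ++i) {
        const float decay = mamba2 ? dA : expf(dt_soft_plus * A[i]);
        s[i] = s[i]*decay + B[i]*x_dt; // state update
        y   += s[i]*C[i];              // output accumulation
    }
    return y;
}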
ggml/src/ggml.c CHANGED
@@ -4829,7 +4829,6 @@ struct ggml_tensor * ggml_ssm_conv(
  const int64_t n_s = sx->ne[2];
 
  // TODO: maybe support other strides than 1?
- // FIXME: this is always true?
  GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
  GGML_ASSERT(sx->ne[1] == d_inner);
  GGML_ASSERT(n_t >= 0);
@@ -4852,36 +4851,49 @@ struct ggml_tensor * ggml_ssm_scan(
  struct ggml_tensor * dt,
  struct ggml_tensor * A,
  struct ggml_tensor * B,
- struct ggml_tensor * C) {
+ struct ggml_tensor * C,
+ struct ggml_tensor * ids) {
  GGML_ASSERT(ggml_is_contiguous(s));
- GGML_ASSERT(ggml_is_contiguous(x));
  GGML_ASSERT(ggml_is_contiguous(dt));
  GGML_ASSERT(ggml_is_contiguous(A));
- GGML_ASSERT(ggml_is_matrix(A));
- GGML_ASSERT(ggml_is_3d(B));
- GGML_ASSERT(ggml_is_3d(s));
+ GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
  GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
  GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
- GGML_ASSERT(ggml_are_same_shape(x, dt));
+ GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+ GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+ GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
  GGML_ASSERT(ggml_are_same_shape(B, C));
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
  {
  const int64_t d_state = s->ne[0];
- const int64_t d_inner = s->ne[1];
- const int64_t n_seq_tokens = x->ne[1];
- const int64_t n_seqs = x->ne[2];
-
- GGML_ASSERT(s->ne[2] == n_seqs);
- GGML_ASSERT(x->ne[0] == d_inner);
- GGML_ASSERT(A->ne[0] == d_state);
- GGML_ASSERT(A->ne[1] == d_inner);
+ const int64_t head_dim = x->ne[0];
+ const int64_t n_head = x->ne[1];
+ const int64_t n_seq_tokens = x->ne[2];
+ const int64_t n_seqs = x->ne[3];
+
+ GGML_ASSERT(dt->ne[0] == n_head);
+ GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+ GGML_ASSERT(dt->ne[2] == n_seqs);
+ GGML_ASSERT(ggml_is_3d(dt));
+ GGML_ASSERT(s->ne[1] == head_dim);
+ GGML_ASSERT(s->ne[2] == n_head);
  GGML_ASSERT(B->ne[0] == d_state);
- GGML_ASSERT(B->ne[1] == n_seq_tokens);
- GGML_ASSERT(B->ne[2] == n_seqs);
+ GGML_ASSERT(B->ne[2] == n_seq_tokens);
+ GGML_ASSERT(B->ne[3] == n_seqs);
+ GGML_ASSERT(ids->ne[0] == n_seqs);
+ GGML_ASSERT(ggml_is_vector(ids));
+ GGML_ASSERT(A->ne[1] == n_head);
+ GGML_ASSERT(ggml_is_matrix(A));
+
+ if (A->ne[0] != 1) {
+ // Mamba-1 has more granular decay factors
+ GGML_ASSERT(A->ne[0] == d_state);
+ }
  }
 
  // concatenated y + ssm_states
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
 
  result->op = GGML_OP_SSM_SCAN;
  result->src[0] = s;
@@ -4890,6 +4902,7 @@ struct ggml_tensor * ggml_ssm_scan(
  result->src[3] = A;
  result->src[4] = B;
  result->src[5] = C;
+ result->src[6] = ids;
 
  return result;
  }
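Editor's note: with the extra ids argument, the caller passes the indices of the recurrent states inside s that each sequence of the batch starts from, and the result now carries only the n_seqs updated states selected by ids (after the y outputs) instead of all of s. A hypothetical usage sketch that satisfies the asserts above; all dimension values and the n_rs name are illustrative, not taken from the commit:

#include "ggml.h"

struct ggml_tensor * build_ssm_scan_example(struct ggml_context * ctx) {
    const int64_t d_state = 128, head_dim = 64, n_head = 16, n_group = 1;
    const int64_t n_seq_tokens = 32, n_seqs = 2, n_rs = 4; // n_rs: recurrent states kept in the cache

    struct ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_rs);
    struct ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
    struct ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head); // {1, n_head} for Mamba-2, {d_state, n_head} for Mamba-1
    struct ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
    struct ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
    struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); // which cached states to start from

    // 1-D result: ggml_nelements(x) outputs (y) followed by the n_seqs updated states
    return ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
}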