ggerganov committed
Commit 385a521 · 1 Parent(s): 0001327

metal : minor code formatting
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -1959,316 +1959,316 @@ static void ggml_metal_encode_node(
 }
 #endif
 
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if ([device supportsFamily:MTLGPUFamilyApple7] &&
- !ggml_is_transposed(src0) &&
- !ggml_is_transposed(src1) &&
- src1t == GGML_TYPE_F32 &&
- ne00 % 32 == 0 && ne00 >= 64 &&
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
- // some Metal matrix data types require aligned pointers
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
- case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
- default: break;
- }
 
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
- default: GGML_ABORT("MUL MAT-MAT not implemented");
- }
 
- ggml_metal_kargs_mul_mm args = {
- /*.ne00 =*/ ne00,
- /*.ne02 =*/ ne02,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.nb03 =*/ nb03,
- /*.ne12 =*/ ne12,
- /*.nb10 =*/ nb10,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.nb13 =*/ nb13,
- /*.ne0 =*/ ne0,
- /*.ne1 =*/ ne1,
- /*.r2 =*/ r2,
- /*.r3 =*/ r3,
- };
 
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
- } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
- id<MTLComputePipelineState> pipeline = nil;
 
- // use custom matrix x vector kernel
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
  nrows = 4;
- } break;
- case GGML_TYPE_F16:
- {
- nth0 = 32;
- nth1 = 1;
- if (src1t == GGML_TYPE_F32) {
- if (ne11 * ne12 < 4) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
- nrows = ne11;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
- nrows = 4;
- }
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
- nrows = 4;
- }
- } break;
- case GGML_TYPE_BF16:
- {
- nth0 = 32;
- nth1 = 1;
- if (src1t == GGML_TYPE_F32) {
- if (ne11 * ne12 < 4) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
- nrows = ne11;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
- nrows = 4;
- }
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
- nrows = 4;
- }
- } break;
- case GGML_TYPE_Q4_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q8_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q2_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q3_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_K:
- {
- nth0 = 4; //1;
- nth1 = 8; //32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q6_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_M:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_NL:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
- } break;
- default:
- {
- GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
- GGML_ABORT("not implemented");
  }
- };
 
- ggml_metal_kargs_mul_mv args = {
- /*.ne00 =*/ ne00,
- /*.ne01 =*/ ne01,
- /*.ne02 =*/ ne02,
- /*.nb00 =*/ nb00,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.nb03 =*/ nb03,
- /*.ne10 =*/ ne10,
- /*.ne11 =*/ ne11,
- /*.ne12 =*/ ne12,
- /*.nb10 =*/ nb10,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.nb13 =*/ nb13,
- /*.ne0 =*/ ne0,
- /*.ne1 =*/ ne1,
- /*.r2 =*/ r2,
- /*.r3 =*/ r3,
- };
 
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
- const int mem_size = 32*sizeof(float);
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q3_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- }
  } break;
  case GGML_OP_MUL_MAT_ID:
  {

 }
 #endif
 
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
+ !ggml_is_transposed(src0) &&
+ !ggml_is_transposed(src1) &&
+ src1t == GGML_TYPE_F32 &&
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
+ default: break;
+ }
 
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
+ default: GGML_ABORT("MUL MAT-MAT not implemented");
+ }
+
+ ggml_metal_kargs_mul_mm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne02 =*/ ne02,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
 
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
+ id<MTLComputePipelineState> pipeline = nil;
 
+ // use custom matrix x vector kernel
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+ nrows = 4;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ nth0 = 32;
+ nth1 = 1;
+ if (src1t == GGML_TYPE_F32) {
+ if (ne11 * ne12 < 4) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
+ nrows = ne11;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
  nrows = 4;
  }
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
+ nrows = 4;
+ }
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ nth0 = 32;
+ nth1 = 1;
+ if (src1t == GGML_TYPE_F32) {
+ if (ne11 * ne12 < 4) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
+ nrows = ne11;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
+ nrows = 4;
+ }
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
+ nrows = 4;
+ }
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_M:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
+ } break;
+ default:
+ {
+ GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_ABORT("not implemented");
+ }
+ };
 
+ ggml_metal_kargs_mul_mv args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.r2 =*/ r2,
+ /*.r3 =*/ r3,
+ };
 
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+ const int mem_size = 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q3_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
  } break;
  case GGML_OP_MUL_MAT_ID:
  {
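
For context, a minimal host-side sketch in plain C of the dispatch arithmetic shown above (a hypothetical helper, not part of the commit): the matrix-matrix path covers the output with 32x64 tiles, one 128-thread threadgroup per tile, which is where the (ne11 + 31)/32 and (ne01 + 63)/64 threadgroup counts come from.

#include <stdint.h>

// Hypothetical helper mirroring the mat-mat dispatch in ggml_metal_encode_node.
// Assumes the same 32x64 output tiling and 128 threads per threadgroup as the
// kernel_mul_mm dispatch in the hunk above.
static void mul_mm_grid(int64_t ne01, int64_t ne11, int64_t ne12, int64_t ne13, int64_t tg[3]) {
    tg[0] = (ne11 + 31)/32; // tiles along the src1 rows (output columns)
    tg[1] = (ne01 + 63)/64; // tiles along the src0 rows (output rows)
    tg[2] = ne12*ne13;      // one grid slice per batch element
}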
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -5577,12 +5577,12 @@ kernel void kernel_mul_mm(
 const int im = tgpig.z;
 
 // if this block is of 64x32 shape or smaller
- short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
- short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
 // a thread shouldn't load data outside of the matrix
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
 simdgroup_T8x8 ma[4];
 simdgroup_float8x8 mb[2];
@@ -5597,20 +5597,23 @@ kernel void kernel_mul_mm(
 const int i12 = im%args.ne12;
 const int i13 = im/args.ne12;
 
- uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- short offset1 = il/nl;
 
- device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
 device const float * y = (device const float *)(src1
 + args.nb13*i13
 + args.nb12*i12
- + args.nb11*(r1 * BLOCK_SIZE_N + thread_col)
 + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
 for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
 // load data and store to threadgroup memory
 T4x4 temp_a;
 dequantize_func(x, il, temp_a);
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 #pragma unroll(16)
@@ -5620,44 +5623,46 @@ kernel void kernel_mul_mm(
 + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
 }
 
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
 
 il = (il + 2 < nl) ? il + 2 : il % 2;
- x = (il < 2) ? x + (2+nl-1)/nl : x;
 y += BLOCK_SIZE_K;
 
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 // load matrices from threadgroup memory and conduct outer products
- threadgroup T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
- threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
 
 #pragma unroll(4)
- for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
 #pragma unroll(4)
 for (short i = 0; i < 4; i++) {
 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
 }
 simdgroup_barrier(mem_flags::mem_none);
 #pragma unroll(2)
 for (short i = 0; i < 2; i++) {
 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
 }
 
- lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE;
- lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE;
-
 #pragma unroll(8)
 for (short i = 0; i < 8; i++){
 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
 }
 }
 }
 
 if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
 device float * C = (device float *) dst +
- (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \
- (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
 
 for (short i = 0; i < 8; i++) {
 simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
@@ -5666,7 +5671,7 @@ kernel void kernel_mul_mm(
 // block is smaller than 64x32, we should avoid writing data outside of the matrix
 threadgroup_barrier(mem_flags::mem_threadgroup);
 threadgroup float * temp_str = ((threadgroup float *) shmem) \
- + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
 for (short i = 0; i < 8; i++) {
 simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
 }

 const int im = tgpig.z;
 
 // if this block is of 64x32 shape or smaller
+ const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
 // a thread shouldn't load data outside of the matrix
+ const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
 simdgroup_T8x8 ma[4];
 simdgroup_float8x8 mb[2];
 const int i12 = im%args.ne12;
 const int i13 = im/args.ne12;
 
+ const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const short offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0
+ + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
 
 device const float * y = (device const float *)(src1
 + args.nb13*i13
 + args.nb12*i12
+ + args.nb11*(r1*BLOCK_SIZE_N + thread_col)
 + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
 for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
 // load data and store to threadgroup memory
 T4x4 temp_a;
 dequantize_func(x, il, temp_a);
+
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 #pragma unroll(16)
 + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
 }
 
+ *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
 
 il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2 + nl - 1)/nl : x;
 y += BLOCK_SIZE_K;
 
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 // load matrices from threadgroup memory and conduct outer products
+ threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+ threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
 
 #pragma unroll(4)
+ for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
 #pragma unroll(4)
 for (short i = 0; i < 4; i++) {
 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
 }
+
 simdgroup_barrier(mem_flags::mem_none);
+
 #pragma unroll(2)
 for (short i = 0; i < 2; i++) {
 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
 }
 
 #pragma unroll(8)
 for (short i = 0; i < 8; i++){
 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
 }
+
+ lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+ lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
 }
 }
 
 if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
 device float * C = (device float *) dst +
+ (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
+ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
 
 for (short i = 0; i < 8; i++) {
 simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
 // block is smaller than 64x32, we should avoid writing data outside of the matrix
 threadgroup_barrier(mem_flags::mem_threadgroup);
 threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
 for (short i = 0; i < 8; i++) {
 simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
 }
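
For context, a rough plain-C sketch (illustrative values only, not the Metal kernel itself) of the edge clamping that the n_rows/n_cols/thread_row/thread_col lines above implement: a partial tile at the matrix edge clamps each thread to the last valid row and column so nothing is read outside the ne0 x ne1 output.

// Illustrative stand-ins for the kernel's tile constants; the real values are
// compile-time parameters of kernel_mul_mm and may differ.
enum { BLOCK_SIZE_M = 64, BLOCK_SIZE_N = 32, THREAD_PER_ROW = 4, THREAD_PER_COL = 4 };

static void clamp_tile_coords(int ne0, int ne1, int r0, int r1, int tiitg,
                              int *thread_row, int *thread_col) {
    // size of this tile; smaller than 64x32 only at the right/bottom edge
    const int n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
    const int n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // clamp so a thread never addresses a row/column outside the matrix
    *thread_row = (tiitg/THREAD_PER_ROW < n_rows) ? tiitg/THREAD_PER_ROW : n_rows - 1;
    *thread_col = (tiitg/THREAD_PER_COL < n_cols) ? tiitg/THREAD_PER_COL : n_cols - 1;
}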