Spaces:
Sleeping
Sleeping
metal : minor code formatting
Browse files- ggml/src/ggml-metal/ggml-metal.m +299 -299
- ggml/src/ggml-metal/ggml-metal.metal +24 -19
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -1959,316 +1959,316 @@ static void ggml_metal_encode_node(
|
|
| 1959 |
}
|
| 1960 |
#endif
|
| 1961 |
|
| 1962 |
-
|
| 1963 |
-
|
| 1964 |
-
|
| 1965 |
-
|
| 1966 |
-
|
| 1967 |
-
|
| 1968 |
-
|
| 1969 |
-
|
| 1970 |
-
|
| 1971 |
-
|
| 1972 |
-
// some Metal matrix data types require aligned pointers
|
| 1973 |
-
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
| 1974 |
-
switch (src0->type) {
|
| 1975 |
-
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
| 1976 |
-
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 1977 |
-
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 1978 |
-
default: break;
|
| 1979 |
-
}
|
| 1980 |
|
| 1981 |
-
|
| 1982 |
-
|
| 1983 |
-
|
| 1984 |
-
|
| 1985 |
-
|
| 1986 |
-
|
| 1987 |
-
|
| 1988 |
-
|
| 1989 |
-
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
| 1990 |
-
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
| 1991 |
-
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
| 1992 |
-
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
| 1993 |
-
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
| 1994 |
-
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
| 1995 |
-
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
| 1996 |
-
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
| 1997 |
-
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1998 |
-
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
| 1999 |
-
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
| 2000 |
-
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
| 2001 |
-
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
| 2002 |
-
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
| 2003 |
-
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
| 2004 |
-
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
| 2005 |
-
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
| 2006 |
-
default: GGML_ABORT("MUL MAT-MAT not implemented");
|
| 2007 |
-
}
|
| 2008 |
|
| 2009 |
-
|
| 2010 |
-
|
| 2011 |
-
|
| 2012 |
-
|
| 2013 |
-
|
| 2014 |
-
|
| 2015 |
-
|
| 2016 |
-
|
| 2017 |
-
|
| 2018 |
-
|
| 2019 |
-
|
| 2020 |
-
|
| 2021 |
-
|
| 2022 |
-
|
| 2023 |
-
|
| 2024 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2025 |
|
| 2026 |
-
|
| 2027 |
-
|
| 2028 |
-
|
| 2029 |
-
|
| 2030 |
-
|
| 2031 |
|
| 2032 |
-
|
| 2033 |
-
|
| 2034 |
-
|
| 2035 |
-
|
| 2036 |
-
|
| 2037 |
-
|
| 2038 |
-
|
| 2039 |
|
| 2040 |
-
|
| 2041 |
|
| 2042 |
-
|
| 2043 |
-
|
| 2044 |
-
|
| 2045 |
-
|
| 2046 |
-
|
| 2047 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2048 |
nrows = 4;
|
| 2049 |
-
} break;
|
| 2050 |
-
case GGML_TYPE_F16:
|
| 2051 |
-
{
|
| 2052 |
-
nth0 = 32;
|
| 2053 |
-
nth1 = 1;
|
| 2054 |
-
if (src1t == GGML_TYPE_F32) {
|
| 2055 |
-
if (ne11 * ne12 < 4) {
|
| 2056 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
| 2057 |
-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 2058 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
| 2059 |
-
nrows = ne11;
|
| 2060 |
-
} else {
|
| 2061 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
| 2062 |
-
nrows = 4;
|
| 2063 |
-
}
|
| 2064 |
-
} else {
|
| 2065 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
| 2066 |
-
nrows = 4;
|
| 2067 |
-
}
|
| 2068 |
-
} break;
|
| 2069 |
-
case GGML_TYPE_BF16:
|
| 2070 |
-
{
|
| 2071 |
-
nth0 = 32;
|
| 2072 |
-
nth1 = 1;
|
| 2073 |
-
if (src1t == GGML_TYPE_F32) {
|
| 2074 |
-
if (ne11 * ne12 < 4) {
|
| 2075 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
| 2076 |
-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 2077 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
| 2078 |
-
nrows = ne11;
|
| 2079 |
-
} else {
|
| 2080 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
| 2081 |
-
nrows = 4;
|
| 2082 |
-
}
|
| 2083 |
-
} else {
|
| 2084 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
| 2085 |
-
nrows = 4;
|
| 2086 |
-
}
|
| 2087 |
-
} break;
|
| 2088 |
-
case GGML_TYPE_Q4_0:
|
| 2089 |
-
{
|
| 2090 |
-
nth0 = 8;
|
| 2091 |
-
nth1 = 8;
|
| 2092 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
| 2093 |
-
} break;
|
| 2094 |
-
case GGML_TYPE_Q4_1:
|
| 2095 |
-
{
|
| 2096 |
-
nth0 = 8;
|
| 2097 |
-
nth1 = 8;
|
| 2098 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
| 2099 |
-
} break;
|
| 2100 |
-
case GGML_TYPE_Q5_0:
|
| 2101 |
-
{
|
| 2102 |
-
nth0 = 8;
|
| 2103 |
-
nth1 = 8;
|
| 2104 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
| 2105 |
-
} break;
|
| 2106 |
-
case GGML_TYPE_Q5_1:
|
| 2107 |
-
{
|
| 2108 |
-
nth0 = 8;
|
| 2109 |
-
nth1 = 8;
|
| 2110 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
| 2111 |
-
} break;
|
| 2112 |
-
case GGML_TYPE_Q8_0:
|
| 2113 |
-
{
|
| 2114 |
-
nth0 = 8;
|
| 2115 |
-
nth1 = 8;
|
| 2116 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
| 2117 |
-
} break;
|
| 2118 |
-
case GGML_TYPE_Q2_K:
|
| 2119 |
-
{
|
| 2120 |
-
nth0 = 2;
|
| 2121 |
-
nth1 = 32;
|
| 2122 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
| 2123 |
-
} break;
|
| 2124 |
-
case GGML_TYPE_Q3_K:
|
| 2125 |
-
{
|
| 2126 |
-
nth0 = 2;
|
| 2127 |
-
nth1 = 32;
|
| 2128 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
| 2129 |
-
} break;
|
| 2130 |
-
case GGML_TYPE_Q4_K:
|
| 2131 |
-
{
|
| 2132 |
-
nth0 = 4; //1;
|
| 2133 |
-
nth1 = 8; //32;
|
| 2134 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
| 2135 |
-
} break;
|
| 2136 |
-
case GGML_TYPE_Q5_K:
|
| 2137 |
-
{
|
| 2138 |
-
nth0 = 2;
|
| 2139 |
-
nth1 = 32;
|
| 2140 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
| 2141 |
-
} break;
|
| 2142 |
-
case GGML_TYPE_Q6_K:
|
| 2143 |
-
{
|
| 2144 |
-
nth0 = 2;
|
| 2145 |
-
nth1 = 32;
|
| 2146 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
| 2147 |
-
} break;
|
| 2148 |
-
case GGML_TYPE_IQ2_XXS:
|
| 2149 |
-
{
|
| 2150 |
-
nth0 = 4;
|
| 2151 |
-
nth1 = 16;
|
| 2152 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
| 2153 |
-
} break;
|
| 2154 |
-
case GGML_TYPE_IQ2_XS:
|
| 2155 |
-
{
|
| 2156 |
-
nth0 = 4;
|
| 2157 |
-
nth1 = 16;
|
| 2158 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
| 2159 |
-
} break;
|
| 2160 |
-
case GGML_TYPE_IQ3_XXS:
|
| 2161 |
-
{
|
| 2162 |
-
nth0 = 4;
|
| 2163 |
-
nth1 = 16;
|
| 2164 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
| 2165 |
-
} break;
|
| 2166 |
-
case GGML_TYPE_IQ3_S:
|
| 2167 |
-
{
|
| 2168 |
-
nth0 = 4;
|
| 2169 |
-
nth1 = 16;
|
| 2170 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
| 2171 |
-
} break;
|
| 2172 |
-
case GGML_TYPE_IQ2_S:
|
| 2173 |
-
{
|
| 2174 |
-
nth0 = 4;
|
| 2175 |
-
nth1 = 16;
|
| 2176 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
| 2177 |
-
} break;
|
| 2178 |
-
case GGML_TYPE_IQ1_S:
|
| 2179 |
-
{
|
| 2180 |
-
nth0 = 4;
|
| 2181 |
-
nth1 = 16;
|
| 2182 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
| 2183 |
-
} break;
|
| 2184 |
-
case GGML_TYPE_IQ1_M:
|
| 2185 |
-
{
|
| 2186 |
-
nth0 = 4;
|
| 2187 |
-
nth1 = 16;
|
| 2188 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
| 2189 |
-
} break;
|
| 2190 |
-
case GGML_TYPE_IQ4_NL:
|
| 2191 |
-
{
|
| 2192 |
-
nth0 = 4;
|
| 2193 |
-
nth1 = 16;
|
| 2194 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
| 2195 |
-
} break;
|
| 2196 |
-
case GGML_TYPE_IQ4_XS:
|
| 2197 |
-
{
|
| 2198 |
-
nth0 = 4;
|
| 2199 |
-
nth1 = 16;
|
| 2200 |
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
| 2201 |
-
} break;
|
| 2202 |
-
default:
|
| 2203 |
-
{
|
| 2204 |
-
GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
| 2205 |
-
GGML_ABORT("not implemented");
|
| 2206 |
}
|
| 2207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2208 |
|
| 2209 |
-
|
| 2210 |
-
|
| 2211 |
-
|
| 2212 |
-
|
| 2213 |
-
|
| 2214 |
-
|
| 2215 |
-
|
| 2216 |
-
|
| 2217 |
-
|
| 2218 |
-
|
| 2219 |
-
|
| 2220 |
-
|
| 2221 |
-
|
| 2222 |
-
|
| 2223 |
-
|
| 2224 |
-
|
| 2225 |
-
|
| 2226 |
-
|
| 2227 |
-
|
| 2228 |
-
|
| 2229 |
|
| 2230 |
-
|
| 2231 |
-
|
| 2232 |
-
|
| 2233 |
-
|
| 2234 |
-
|
| 2235 |
|
| 2236 |
-
|
| 2237 |
-
|
| 2238 |
-
|
| 2239 |
-
|
| 2240 |
-
|
| 2241 |
-
|
| 2242 |
-
|
| 2243 |
-
|
| 2244 |
-
|
| 2245 |
-
|
| 2246 |
-
|
| 2247 |
-
|
| 2248 |
-
|
| 2249 |
-
|
| 2250 |
-
|
| 2251 |
-
|
| 2252 |
-
|
| 2253 |
-
|
| 2254 |
-
|
| 2255 |
-
|
| 2256 |
-
|
| 2257 |
-
|
| 2258 |
-
|
| 2259 |
-
|
| 2260 |
-
|
| 2261 |
-
|
| 2262 |
-
|
| 2263 |
-
|
| 2264 |
-
|
| 2265 |
-
|
| 2266 |
-
|
| 2267 |
-
|
| 2268 |
-
|
| 2269 |
-
|
| 2270 |
-
|
| 2271 |
-
|
| 2272 |
} break;
|
| 2273 |
case GGML_OP_MUL_MAT_ID:
|
| 2274 |
{
|
|
|
|
| 1959 |
}
|
| 1960 |
#endif
|
| 1961 |
|
| 1962 |
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
| 1963 |
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
| 1964 |
+
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
| 1965 |
+
!ggml_is_transposed(src0) &&
|
| 1966 |
+
!ggml_is_transposed(src1) &&
|
| 1967 |
+
src1t == GGML_TYPE_F32 &&
|
| 1968 |
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
| 1969 |
+
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
| 1970 |
+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1971 |
|
| 1972 |
+
// some Metal matrix data types require aligned pointers
|
| 1973 |
+
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
| 1974 |
+
switch (src0->type) {
|
| 1975 |
+
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
| 1976 |
+
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 1977 |
+
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
| 1978 |
+
default: break;
|
| 1979 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1980 |
|
| 1981 |
+
id<MTLComputePipelineState> pipeline = nil;
|
| 1982 |
+
|
| 1983 |
+
switch (src0->type) {
|
| 1984 |
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
| 1985 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
| 1986 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
|
| 1987 |
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
| 1988 |
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
| 1989 |
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
| 1990 |
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
| 1991 |
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
| 1992 |
+
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
| 1993 |
+
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
| 1994 |
+
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
| 1995 |
+
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
| 1996 |
+
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
| 1997 |
+
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
| 1998 |
+
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
| 1999 |
+
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
| 2000 |
+
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
| 2001 |
+
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
| 2002 |
+
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
| 2003 |
+
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
| 2004 |
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
| 2005 |
+
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
| 2006 |
+
default: GGML_ABORT("MUL MAT-MAT not implemented");
|
| 2007 |
+
}
|
| 2008 |
+
|
| 2009 |
+
ggml_metal_kargs_mul_mm args = {
|
| 2010 |
+
/*.ne00 =*/ ne00,
|
| 2011 |
+
/*.ne02 =*/ ne02,
|
| 2012 |
+
/*.nb01 =*/ nb01,
|
| 2013 |
+
/*.nb02 =*/ nb02,
|
| 2014 |
+
/*.nb03 =*/ nb03,
|
| 2015 |
+
/*.ne12 =*/ ne12,
|
| 2016 |
+
/*.nb10 =*/ nb10,
|
| 2017 |
+
/*.nb11 =*/ nb11,
|
| 2018 |
+
/*.nb12 =*/ nb12,
|
| 2019 |
+
/*.nb13 =*/ nb13,
|
| 2020 |
+
/*.ne0 =*/ ne0,
|
| 2021 |
+
/*.ne1 =*/ ne1,
|
| 2022 |
+
/*.r2 =*/ r2,
|
| 2023 |
+
/*.r3 =*/ r3,
|
| 2024 |
+
};
|
| 2025 |
|
| 2026 |
+
[encoder setComputePipelineState:pipeline];
|
| 2027 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
| 2028 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 2029 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
| 2030 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
| 2031 |
|
| 2032 |
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
| 2033 |
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
| 2034 |
+
} else {
|
| 2035 |
+
int nth0 = 32;
|
| 2036 |
+
int nth1 = 1;
|
| 2037 |
+
int nrows = 1;
|
| 2038 |
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
| 2039 |
|
| 2040 |
+
id<MTLComputePipelineState> pipeline = nil;
|
| 2041 |
|
| 2042 |
+
// use custom matrix x vector kernel
|
| 2043 |
+
switch (src0t) {
|
| 2044 |
+
case GGML_TYPE_F32:
|
| 2045 |
+
{
|
| 2046 |
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
| 2047 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
| 2048 |
+
nrows = 4;
|
| 2049 |
+
} break;
|
| 2050 |
+
case GGML_TYPE_F16:
|
| 2051 |
+
{
|
| 2052 |
+
nth0 = 32;
|
| 2053 |
+
nth1 = 1;
|
| 2054 |
+
if (src1t == GGML_TYPE_F32) {
|
| 2055 |
+
if (ne11 * ne12 < 4) {
|
| 2056 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
| 2057 |
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 2058 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
| 2059 |
+
nrows = ne11;
|
| 2060 |
+
} else {
|
| 2061 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
| 2062 |
nrows = 4;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2063 |
}
|
| 2064 |
+
} else {
|
| 2065 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
| 2066 |
+
nrows = 4;
|
| 2067 |
+
}
|
| 2068 |
+
} break;
|
| 2069 |
+
case GGML_TYPE_BF16:
|
| 2070 |
+
{
|
| 2071 |
+
nth0 = 32;
|
| 2072 |
+
nth1 = 1;
|
| 2073 |
+
if (src1t == GGML_TYPE_F32) {
|
| 2074 |
+
if (ne11 * ne12 < 4) {
|
| 2075 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
| 2076 |
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
| 2077 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
| 2078 |
+
nrows = ne11;
|
| 2079 |
+
} else {
|
| 2080 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
| 2081 |
+
nrows = 4;
|
| 2082 |
+
}
|
| 2083 |
+
} else {
|
| 2084 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
| 2085 |
+
nrows = 4;
|
| 2086 |
+
}
|
| 2087 |
+
} break;
|
| 2088 |
+
case GGML_TYPE_Q4_0:
|
| 2089 |
+
{
|
| 2090 |
+
nth0 = 8;
|
| 2091 |
+
nth1 = 8;
|
| 2092 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
| 2093 |
+
} break;
|
| 2094 |
+
case GGML_TYPE_Q4_1:
|
| 2095 |
+
{
|
| 2096 |
+
nth0 = 8;
|
| 2097 |
+
nth1 = 8;
|
| 2098 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
| 2099 |
+
} break;
|
| 2100 |
+
case GGML_TYPE_Q5_0:
|
| 2101 |
+
{
|
| 2102 |
+
nth0 = 8;
|
| 2103 |
+
nth1 = 8;
|
| 2104 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
| 2105 |
+
} break;
|
| 2106 |
+
case GGML_TYPE_Q5_1:
|
| 2107 |
+
{
|
| 2108 |
+
nth0 = 8;
|
| 2109 |
+
nth1 = 8;
|
| 2110 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
| 2111 |
+
} break;
|
| 2112 |
+
case GGML_TYPE_Q8_0:
|
| 2113 |
+
{
|
| 2114 |
+
nth0 = 8;
|
| 2115 |
+
nth1 = 8;
|
| 2116 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
| 2117 |
+
} break;
|
| 2118 |
+
case GGML_TYPE_Q2_K:
|
| 2119 |
+
{
|
| 2120 |
+
nth0 = 2;
|
| 2121 |
+
nth1 = 32;
|
| 2122 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
| 2123 |
+
} break;
|
| 2124 |
+
case GGML_TYPE_Q3_K:
|
| 2125 |
+
{
|
| 2126 |
+
nth0 = 2;
|
| 2127 |
+
nth1 = 32;
|
| 2128 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
| 2129 |
+
} break;
|
| 2130 |
+
case GGML_TYPE_Q4_K:
|
| 2131 |
+
{
|
| 2132 |
+
nth0 = 4; //1;
|
| 2133 |
+
nth1 = 8; //32;
|
| 2134 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
| 2135 |
+
} break;
|
| 2136 |
+
case GGML_TYPE_Q5_K:
|
| 2137 |
+
{
|
| 2138 |
+
nth0 = 2;
|
| 2139 |
+
nth1 = 32;
|
| 2140 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
| 2141 |
+
} break;
|
| 2142 |
+
case GGML_TYPE_Q6_K:
|
| 2143 |
+
{
|
| 2144 |
+
nth0 = 2;
|
| 2145 |
+
nth1 = 32;
|
| 2146 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
| 2147 |
+
} break;
|
| 2148 |
+
case GGML_TYPE_IQ2_XXS:
|
| 2149 |
+
{
|
| 2150 |
+
nth0 = 4;
|
| 2151 |
+
nth1 = 16;
|
| 2152 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
| 2153 |
+
} break;
|
| 2154 |
+
case GGML_TYPE_IQ2_XS:
|
| 2155 |
+
{
|
| 2156 |
+
nth0 = 4;
|
| 2157 |
+
nth1 = 16;
|
| 2158 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
| 2159 |
+
} break;
|
| 2160 |
+
case GGML_TYPE_IQ3_XXS:
|
| 2161 |
+
{
|
| 2162 |
+
nth0 = 4;
|
| 2163 |
+
nth1 = 16;
|
| 2164 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
| 2165 |
+
} break;
|
| 2166 |
+
case GGML_TYPE_IQ3_S:
|
| 2167 |
+
{
|
| 2168 |
+
nth0 = 4;
|
| 2169 |
+
nth1 = 16;
|
| 2170 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
| 2171 |
+
} break;
|
| 2172 |
+
case GGML_TYPE_IQ2_S:
|
| 2173 |
+
{
|
| 2174 |
+
nth0 = 4;
|
| 2175 |
+
nth1 = 16;
|
| 2176 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
| 2177 |
+
} break;
|
| 2178 |
+
case GGML_TYPE_IQ1_S:
|
| 2179 |
+
{
|
| 2180 |
+
nth0 = 4;
|
| 2181 |
+
nth1 = 16;
|
| 2182 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
| 2183 |
+
} break;
|
| 2184 |
+
case GGML_TYPE_IQ1_M:
|
| 2185 |
+
{
|
| 2186 |
+
nth0 = 4;
|
| 2187 |
+
nth1 = 16;
|
| 2188 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
| 2189 |
+
} break;
|
| 2190 |
+
case GGML_TYPE_IQ4_NL:
|
| 2191 |
+
{
|
| 2192 |
+
nth0 = 4;
|
| 2193 |
+
nth1 = 16;
|
| 2194 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
| 2195 |
+
} break;
|
| 2196 |
+
case GGML_TYPE_IQ4_XS:
|
| 2197 |
+
{
|
| 2198 |
+
nth0 = 4;
|
| 2199 |
+
nth1 = 16;
|
| 2200 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
| 2201 |
+
} break;
|
| 2202 |
+
default:
|
| 2203 |
+
{
|
| 2204 |
+
GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
| 2205 |
+
GGML_ABORT("not implemented");
|
| 2206 |
+
}
|
| 2207 |
+
};
|
| 2208 |
|
| 2209 |
+
ggml_metal_kargs_mul_mv args = {
|
| 2210 |
+
/*.ne00 =*/ ne00,
|
| 2211 |
+
/*.ne01 =*/ ne01,
|
| 2212 |
+
/*.ne02 =*/ ne02,
|
| 2213 |
+
/*.nb00 =*/ nb00,
|
| 2214 |
+
/*.nb01 =*/ nb01,
|
| 2215 |
+
/*.nb02 =*/ nb02,
|
| 2216 |
+
/*.nb03 =*/ nb03,
|
| 2217 |
+
/*.ne10 =*/ ne10,
|
| 2218 |
+
/*.ne11 =*/ ne11,
|
| 2219 |
+
/*.ne12 =*/ ne12,
|
| 2220 |
+
/*.nb10 =*/ nb10,
|
| 2221 |
+
/*.nb11 =*/ nb11,
|
| 2222 |
+
/*.nb12 =*/ nb12,
|
| 2223 |
+
/*.nb13 =*/ nb13,
|
| 2224 |
+
/*.ne0 =*/ ne0,
|
| 2225 |
+
/*.ne1 =*/ ne1,
|
| 2226 |
+
/*.r2 =*/ r2,
|
| 2227 |
+
/*.r3 =*/ r3,
|
| 2228 |
+
};
|
| 2229 |
|
| 2230 |
+
[encoder setComputePipelineState:pipeline];
|
| 2231 |
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
| 2232 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 2233 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
| 2234 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
| 2235 |
|
| 2236 |
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
| 2237 |
+
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
| 2238 |
+
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
| 2239 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2240 |
+
}
|
| 2241 |
+
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
| 2242 |
+
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
| 2243 |
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 2244 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2245 |
+
}
|
| 2246 |
+
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
| 2247 |
+
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
| 2248 |
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 2249 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2250 |
+
}
|
| 2251 |
+
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
| 2252 |
+
const int mem_size = 32*sizeof(float);
|
| 2253 |
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
| 2254 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2255 |
+
}
|
| 2256 |
+
else if (src0t == GGML_TYPE_Q4_K) {
|
| 2257 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2258 |
+
}
|
| 2259 |
+
else if (src0t == GGML_TYPE_Q3_K) {
|
| 2260 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2261 |
+
}
|
| 2262 |
+
else if (src0t == GGML_TYPE_Q5_K) {
|
| 2263 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2264 |
+
}
|
| 2265 |
+
else if (src0t == GGML_TYPE_Q6_K) {
|
| 2266 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2267 |
+
} else {
|
| 2268 |
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
| 2269 |
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 2270 |
+
}
|
| 2271 |
+
}
|
| 2272 |
} break;
|
| 2273 |
case GGML_OP_MUL_MAT_ID:
|
| 2274 |
{
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -5577,12 +5577,12 @@ kernel void kernel_mul_mm(
|
|
| 5577 |
const int im = tgpig.z;
|
| 5578 |
|
| 5579 |
// if this block is of 64x32 shape or smaller
|
| 5580 |
-
short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
| 5581 |
-
short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
| 5582 |
|
| 5583 |
// a thread shouldn't load data outside of the matrix
|
| 5584 |
-
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
| 5585 |
-
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
| 5586 |
|
| 5587 |
simdgroup_T8x8 ma[4];
|
| 5588 |
simdgroup_float8x8 mb[2];
|
|
@@ -5597,20 +5597,23 @@ kernel void kernel_mul_mm(
|
|
| 5597 |
const int i12 = im%args.ne12;
|
| 5598 |
const int i13 = im/args.ne12;
|
| 5599 |
|
| 5600 |
-
uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
| 5601 |
-
short offset1 = il/nl;
|
|
|
|
|
|
|
|
|
|
| 5602 |
|
| 5603 |
-
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
|
| 5604 |
device const float * y = (device const float *)(src1
|
| 5605 |
+ args.nb13*i13
|
| 5606 |
+ args.nb12*i12
|
| 5607 |
-
+ args.nb11*(r1
|
| 5608 |
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
| 5609 |
|
| 5610 |
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
| 5611 |
// load data and store to threadgroup memory
|
| 5612 |
T4x4 temp_a;
|
| 5613 |
dequantize_func(x, il, temp_a);
|
|
|
|
| 5614 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5615 |
|
| 5616 |
#pragma unroll(16)
|
|
@@ -5620,44 +5623,46 @@ kernel void kernel_mul_mm(
|
|
| 5620 |
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
| 5621 |
}
|
| 5622 |
|
| 5623 |
-
*(threadgroup float2x4 *)(sb + (tiitg
|
| 5624 |
|
| 5625 |
il = (il + 2 < nl) ? il + 2 : il % 2;
|
| 5626 |
-
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
| 5627 |
y += BLOCK_SIZE_K;
|
| 5628 |
|
| 5629 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5630 |
|
| 5631 |
// load matrices from threadgroup memory and conduct outer products
|
| 5632 |
-
threadgroup T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
| 5633 |
-
threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
| 5634 |
|
| 5635 |
#pragma unroll(4)
|
| 5636 |
-
for (short ik = 0; ik < BLOCK_SIZE_K
|
| 5637 |
#pragma unroll(4)
|
| 5638 |
for (short i = 0; i < 4; i++) {
|
| 5639 |
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
| 5640 |
}
|
|
|
|
| 5641 |
simdgroup_barrier(mem_flags::mem_none);
|
|
|
|
| 5642 |
#pragma unroll(2)
|
| 5643 |
for (short i = 0; i < 2; i++) {
|
| 5644 |
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
| 5645 |
}
|
| 5646 |
|
| 5647 |
-
lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE;
|
| 5648 |
-
lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE;
|
| 5649 |
-
|
| 5650 |
#pragma unroll(8)
|
| 5651 |
for (short i = 0; i < 8; i++){
|
| 5652 |
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
| 5653 |
}
|
|
|
|
|
|
|
|
|
|
| 5654 |
}
|
| 5655 |
}
|
| 5656 |
|
| 5657 |
if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
|
| 5658 |
device float * C = (device float *) dst +
|
| 5659 |
-
(BLOCK_SIZE_M * r0 + 32
|
| 5660 |
-
(BLOCK_SIZE_N * r1 + 16
|
| 5661 |
|
| 5662 |
for (short i = 0; i < 8; i++) {
|
| 5663 |
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
|
|
@@ -5666,7 +5671,7 @@ kernel void kernel_mul_mm(
|
|
| 5666 |
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
| 5667 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5668 |
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
| 5669 |
-
|
| 5670 |
for (short i = 0; i < 8; i++) {
|
| 5671 |
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
| 5672 |
}
|
|
|
|
| 5577 |
const int im = tgpig.z;
|
| 5578 |
|
| 5579 |
// if this block is of 64x32 shape or smaller
|
| 5580 |
+
const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
| 5581 |
+
const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
| 5582 |
|
| 5583 |
// a thread shouldn't load data outside of the matrix
|
| 5584 |
+
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
| 5585 |
+
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
| 5586 |
|
| 5587 |
simdgroup_T8x8 ma[4];
|
| 5588 |
simdgroup_float8x8 mb[2];
|
|
|
|
| 5597 |
const int i12 = im%args.ne12;
|
| 5598 |
const int i13 = im/args.ne12;
|
| 5599 |
|
| 5600 |
+
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
| 5601 |
+
const short offset1 = il/nl;
|
| 5602 |
+
|
| 5603 |
+
device const block_q * x = (device const block_q *)(src0
|
| 5604 |
+
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
| 5605 |
|
|
|
|
| 5606 |
device const float * y = (device const float *)(src1
|
| 5607 |
+ args.nb13*i13
|
| 5608 |
+ args.nb12*i12
|
| 5609 |
+
+ args.nb11*(r1*BLOCK_SIZE_N + thread_col)
|
| 5610 |
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
| 5611 |
|
| 5612 |
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
| 5613 |
// load data and store to threadgroup memory
|
| 5614 |
T4x4 temp_a;
|
| 5615 |
dequantize_func(x, il, temp_a);
|
| 5616 |
+
|
| 5617 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5618 |
|
| 5619 |
#pragma unroll(16)
|
|
|
|
| 5623 |
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
| 5624 |
}
|
| 5625 |
|
| 5626 |
+
*(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
|
| 5627 |
|
| 5628 |
il = (il + 2 < nl) ? il + 2 : il % 2;
|
| 5629 |
+
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
| 5630 |
y += BLOCK_SIZE_K;
|
| 5631 |
|
| 5632 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5633 |
|
| 5634 |
// load matrices from threadgroup memory and conduct outer products
|
| 5635 |
+
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
| 5636 |
+
threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
| 5637 |
|
| 5638 |
#pragma unroll(4)
|
| 5639 |
+
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
| 5640 |
#pragma unroll(4)
|
| 5641 |
for (short i = 0; i < 4; i++) {
|
| 5642 |
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
| 5643 |
}
|
| 5644 |
+
|
| 5645 |
simdgroup_barrier(mem_flags::mem_none);
|
| 5646 |
+
|
| 5647 |
#pragma unroll(2)
|
| 5648 |
for (short i = 0; i < 2; i++) {
|
| 5649 |
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
| 5650 |
}
|
| 5651 |
|
|
|
|
|
|
|
|
|
|
| 5652 |
#pragma unroll(8)
|
| 5653 |
for (short i = 0; i < 8; i++){
|
| 5654 |
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
| 5655 |
}
|
| 5656 |
+
|
| 5657 |
+
lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
|
| 5658 |
+
lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
|
| 5659 |
}
|
| 5660 |
}
|
| 5661 |
|
| 5662 |
if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
|
| 5663 |
device float * C = (device float *) dst +
|
| 5664 |
+
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
| 5665 |
+
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
| 5666 |
|
| 5667 |
for (short i = 0; i < 8; i++) {
|
| 5668 |
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
|
|
|
|
| 5671 |
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
| 5672 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5673 |
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
| 5674 |
+
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
| 5675 |
for (short i = 0; i < 8; i++) {
|
| 5676 |
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
| 5677 |
}
|