OccamRazor committed
Commit 488f19e · 1 Parent(s): c6de218

Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats (llama/10721)


* Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats

* Fix subgroup size control extension support check

Add accf32 and accf16 checks for coopmats

* Also disable coopmats on amdvlk

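For background, the Vulkan mechanism this change relies on is VK_EXT_subgroup_size_control: the device reports a minimum and maximum subgroup size, and a compute pipeline can ask the driver to launch only full subgroups (and optionally pin the subgroup size) so cooperative-matrix shaders never run with partially populated subgroups. The snippet below is an illustrative sketch only, not code from this commit; the helper name and parameters are invented, and it assumes the extension (or its Vulkan 1.3 core equivalent) is already enabled on the device.

    #include <vulkan/vulkan.hpp>

    // Illustrative helper (not from the diff): build a compute pipeline that
    // requires full subgroups and, when requested, a fixed subgroup size.
    static vk::Pipeline create_full_subgroup_compute_pipeline(
            vk::Device device, vk::ShaderModule shader_module, vk::PipelineLayout layout,
            uint32_t required_subgroup_size /* 0 = let the driver choose */) {
        vk::PipelineShaderStageCreateInfo stage_info(
            vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT, // launch only full subgroups
            vk::ShaderStageFlagBits::eCompute, shader_module, "main");

        // Pin the subgroup size; the value must lie within the min/max limits
        // reported by VkPhysicalDeviceSubgroupSizeControlPropertiesEXT.
        vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroup_size_info;
        subgroup_size_info.requiredSubgroupSize = required_subgroup_size;
        if (required_subgroup_size > 0) {
            stage_info.pNext = &subgroup_size_info;
        }

        vk::ComputePipelineCreateInfo pipeline_info(vk::PipelineCreateFlags{}, stage_info, layout);
        return device.createComputePipeline(vk::PipelineCache{}, pipeline_info).value;
    }

The diff below wires the same two pieces (the stage create flag and the required-subgroup-size pNext chain) into ggml_vk_create_pipeline_func, gated on the device actually reporting VK_EXT_subgroup_size_control.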
Files changed (1)
  1. ggml/src/ggml-vulkan/ggml-vulkan.cpp +140 -44
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -163,7 +163,11 @@ struct vk_device_struct {
     uint32_t shader_core_count;
     bool uma;
     bool float_controls_rte_fp16;
-    bool coopmat2;
+
+    bool subgroup_size_control;
+    uint32_t subgroup_min_size;
+    uint32_t subgroup_max_size;
+    bool subgroup_require_full_support;
 
     bool coopmat_support;
     bool coopmat_acc_f32_support;
@@ -171,6 +175,7 @@ struct vk_device_struct {
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;
+    bool coopmat2;
 
     size_t idx;
 
@@ -749,8 +754,12 @@ static uint32_t compile_count = 0;
 static std::mutex compile_count_mutex;
 static std::condition_variable compile_count_cond;
 
-static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align, bool disable_robustness) {
-    VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
+static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
+                                         uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
+                                         uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
+    VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
+                 ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
+                 ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
     GGML_ASSERT(parameter_count > 0);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
 
@@ -809,14 +818,28 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         specialization_constants.data()
     );
 
+    vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
+
+    if (device->subgroup_require_full_support && require_full_subgroups) {
+        pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
+    }
+
     vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
-        vk::PipelineShaderStageCreateFlags(),
+        pipeline_shader_stage_create_flags,
         vk::ShaderStageFlagBits::eCompute,
         pipeline->shader_module,
         entrypoint.c_str(),
         &specialization_info);
+
+    vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
+    pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
+    if (device->subgroup_size_control && required_subgroup_size > 0) {
+        GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
+        pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
+    }
+
     vk::ComputePipelineCreateInfo compute_pipeline_create_info(
-        vk::PipelineCreateFlags(),
+        vk::PipelineCreateFlags{},
         pipeline_shader_create_info,
         pipeline->layout);
 
@@ -1496,7 +1519,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
 
     std::vector<std::future<void>> compiles;
-    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
+                                              uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
+                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
         {
             // wait until fewer than N compiles are in progress
             uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -1506,7 +1531,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
            }
            compile_count++;
        }
-        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
+        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
+            parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
 
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -1612,40 +1638,59 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
         if (device->mul_mat ## ID ## _l) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _m) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _s) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _l) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
         if (device->mul_mat ## ID ## _m) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
         if (device->mul_mat ## ID ## _s) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
 
         // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-        CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-        CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->coopmat_acc_f16_support) { \
+            CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        } \
+        if (device->coopmat_acc_f32_support) { \
+            CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        } \
 
         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        if (device->coopmat_acc_f16_support) {
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        } else {
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        }
 
         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
         if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
@@ -1653,19 +1698,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        if (device->coopmat_acc_f16_support) {
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        } else {
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        }
     }
+#undef CREATE_MM2
 #undef CREATE_MM
     } else if (device->fp16) {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1683,6 +1744,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
 
+        // Create 2 variants, {f16,f32} accumulator
+#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+
         CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -1720,6 +1786,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     }
+#undef CREATE_MM2
 #undef CREATE_MM
     } else {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1774,7 +1841,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     }
-#undef CREATE_MM2
 #undef CREATE_MM
     }
 
@@ -1998,6 +2064,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
             amd_shader_core_properties2 = true;
         } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
             pipeline_robustness = true;
+        } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+            device->subgroup_size_control = true;
         } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT")) {
             device->coopmat_support = true;
@@ -2018,6 +2086,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
     vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
     vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
     vk::PhysicalDeviceVulkan12Properties vk12_props;
+    vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
     props2.pNext = &props3;
     props3.pNext = &subgroup_props;
     subgroup_props.pNext = &driver_props;
@@ -2037,6 +2107,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
         last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
         last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
     }
+    if (device->subgroup_size_control) {
+        last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
+        last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
+    }
 
 #if defined(VK_NV_cooperative_matrix2)
     vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
@@ -2075,7 +2149,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-    if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
+    if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
         // Intel drivers don't support coopmat properly yet
         // Only RADV supports coopmat properly on AMD
         device->coopmat_support = false;
@@ -2131,6 +2205,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_EXT_pipeline_robustness");
     }
 
+    VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
+    subgroup_size_control_features.pNext = nullptr;
+    subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
+    subgroup_size_control_features.computeFullSubgroups = false;
+    subgroup_size_control_features.subgroupSizeControl = false;
+
+    if (device->subgroup_size_control) {
+        last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
+        last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
+    }
+
     VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
     coopmat_features.pNext = nullptr;
     coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2158,6 +2243,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
 
+    device->subgroup_size_control = device->subgroup_size_control &&
+        (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
+        subgroup_size_control_features.subgroupSizeControl;
+
+    if (device->subgroup_size_control) {
+        device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
+        device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
+        device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
+        device_extensions.push_back("VK_EXT_subgroup_size_control");
+    }
+
     device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
 
     if (coopmat2_support) {
@@ -2307,7 +2403,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
             }
         }
 
-        if (device->coopmat_m == 0) {
+        if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
            // No suitable matmul mode found
            GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
            device->coopmat_support = false;
@@ -2440,7 +2536,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         }
     }
 
-    if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
+    if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
         // Intel drivers don't support coopmat properly yet
         // Only RADV supports coopmat properly on AMD
         coopmat_support = false;
@@ -2727,7 +2823,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
         return ctx->device->pipeline_matmul_f32_f16;
     }
-    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
+    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_f16_f32.f16acc;
         }
@@ -2802,7 +2898,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
         return ctx->device->pipeline_matmul_id_f32;
     }
-    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
+    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
         }