cmdr2 commited on
Commit
f959b90
·
1 Parent(s): 67e8c32

cuda/vulkan: specify fp32-only support for some operations in supports_op (ggml/1129)

Browse files

* cuda: restrict SILU_BACK to fp32, since fp16 exceeds the desired test threshold

* vulkan: specify fp32-only support for certain ops (that are now tested for fp16 as well)

* f32 sigmoid in vulkan supports op

* Revert "f32 sigmoid in vulkan supports op"

This reverts commit c6f04b3c19bf4504c2776149c6d8cd84e0b48acb.

ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3153
  return false;
3154
  } break;
3155
  case GGML_OP_SILU_BACK:
3156
- return ggml_is_contiguous(op->src[0]);
3157
  break;
3158
  case GGML_OP_NORM:
3159
  case GGML_OP_RMS_NORM:
 
3153
  return false;
3154
  } break;
3155
  case GGML_OP_SILU_BACK:
3156
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
3157
  break;
3158
  case GGML_OP_NORM:
3159
  case GGML_OP_RMS_NORM:
ggml/src/ggml-vulkan/ggml-vulkan.cpp CHANGED
@@ -8371,7 +8371,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8371
  case GGML_UNARY_OP_SILU:
8372
  case GGML_UNARY_OP_RELU:
8373
  case GGML_UNARY_OP_TANH:
8374
- return ggml_is_contiguous(op->src[0]);
8375
  default:
8376
  return false;
8377
  }
@@ -8571,17 +8571,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8571
  case GGML_OP_RMS_NORM:
8572
  return ggml_is_contiguous(op->src[0]);
8573
  case GGML_OP_ADD:
8574
- case GGML_OP_ACC:
8575
  case GGML_OP_SUB:
8576
  case GGML_OP_MUL:
8577
  case GGML_OP_DIV:
8578
- case GGML_OP_CONCAT:
8579
- case GGML_OP_UPSCALE:
8580
- case GGML_OP_SCALE:
8581
  case GGML_OP_SQR:
8582
  case GGML_OP_SIN:
8583
  case GGML_OP_COS:
8584
  case GGML_OP_CLAMP:
 
 
 
 
 
8585
  case GGML_OP_PAD:
8586
  case GGML_OP_DIAG_MASK_INF:
8587
  case GGML_OP_SOFT_MAX:
 
8371
  case GGML_UNARY_OP_SILU:
8372
  case GGML_UNARY_OP_RELU:
8373
  case GGML_UNARY_OP_TANH:
8374
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
8375
  default:
8376
  return false;
8377
  }
 
8571
  case GGML_OP_RMS_NORM:
8572
  return ggml_is_contiguous(op->src[0]);
8573
  case GGML_OP_ADD:
 
8574
  case GGML_OP_SUB:
8575
  case GGML_OP_MUL:
8576
  case GGML_OP_DIV:
 
 
 
8577
  case GGML_OP_SQR:
8578
  case GGML_OP_SIN:
8579
  case GGML_OP_COS:
8580
  case GGML_OP_CLAMP:
8581
+ return op->src[0]->type == GGML_TYPE_F32;
8582
+ case GGML_OP_ACC:
8583
+ case GGML_OP_CONCAT:
8584
+ case GGML_OP_UPSCALE:
8585
+ case GGML_OP_SCALE:
8586
  case GGML_OP_PAD:
8587
  case GGML_OP_DIAG_MASK_INF:
8588
  case GGML_OP_SOFT_MAX: