mirror of
https://github.com/ggerganov/ggml
synced 2026-03-01 20:50:26 +01:00
ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (llama/19700)
* ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. * Fix to cast the src value to f32 before sin/cos computing.
This commit is contained in:
parent
defde0e7c7
commit
8502813490
@ -2008,6 +2008,14 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_LOG:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQR:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQRT:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SIN:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_COS:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_PAD:
|
||||
return ggml_webgpu_pad(ctx, src0, node);
|
||||
case GGML_OP_ARGMAX:
|
||||
@ -2967,6 +2975,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_LOG:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_SQR:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_SQRT:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_SIN:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_COS:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_PAD:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
||||
@ -170,6 +170,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
#ifdef TRUNC
|
||||
let res = trunc(src[params.offset_src + src_idx]);
|
||||
#endif
|
||||
#ifdef SQR
|
||||
let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
|
||||
#endif
|
||||
#ifdef SQRT
|
||||
let res = sqrt(src[params.offset_src + src_idx]);
|
||||
#endif
|
||||
#ifdef SIN
|
||||
let res_f32 = sin(f32(src[params.offset_src + src_idx]));
|
||||
let res = TYPE(res_f32);
|
||||
#endif
|
||||
#ifdef COS
|
||||
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
|
||||
let res = TYPE(res_f32);
|
||||
#endif
|
||||
|
||||
#ifdef INPLACE
|
||||
src[params.offset_src + src_idx] = res;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user