ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (llama/19700)

* ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support.

* Fix to cast the src value to f32 before sin/cos computing.
This commit is contained in:
Masashi Yoshimura 2026-02-20 01:18:30 +09:00 committed by Georgi Gerganov
parent defde0e7c7
commit 8502813490
2 changed files with 34 additions and 0 deletions

View File

@ -2008,6 +2008,14 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_LOG:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQR:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SQRT:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_SIN:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_COS:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_PAD:
return ggml_webgpu_pad(ctx, src0, node);
case GGML_OP_ARGMAX:
@ -2967,6 +2975,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_LOG:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_SQR:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_SQRT:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_SIN:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_COS:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_PAD:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;

View File

@ -170,6 +170,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef TRUNC
let res = trunc(src[params.offset_src + src_idx]);
#endif
#ifdef SQR
let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
#endif
#ifdef SQRT
let res = sqrt(src[params.offset_src + src_idx]);
#endif
#ifdef SIN
let res_f32 = sin(f32(src[params.offset_src + src_idx]));
let res = TYPE(res_f32);
#endif
#ifdef COS
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
let res = TYPE(res_f32);
#endif
#ifdef INPLACE
src[params.offset_src + src_idx] = res;