ggml-webgpu: updated matrix-vector multiplication (#21738)

* Merged properly, but slow q3_k and q5_k with u32 indexing

* Start on new mat-vec

* New format float paths working

* Working q4_0

* Work on remaining legacy q-types

* Port k-quants to new matvec

* Remove old shader

* Remove old constants, format

* Remove accidental file

---------

Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
neha-ha 2026-04-20 07:37:17 -07:00 committed by GitHub
parent a678916623
commit a6cc43c286
4 changed files with 816 additions and 411 deletions


@@ -44,18 +44,9 @@
 // Matrix-vector multiplication parameters
 #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
 // Must be multiple of 4 to work with vectorized paths, and must divide
 // mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
-#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
-// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
-#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
-// Requires at least two (and multiple of 2) k-quant blocks per tile
-#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
+#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
+#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
 // default size for legacy matrix multiplication
 #define WEBGPU_MUL_MAT_WG_SIZE 256
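
With the per-type TILE_K constants gone and every OUTPUTS_PER_WG lowered to 4, each output row is now reduced by wg_size / outputs_per_wg = 64 cooperating threads. A minimal sketch of the implied geometry, assuming these constant values; the dispatch math below is illustrative, not the backend's actual code:

```cpp
#include <cstdint>
#include <cstdio>

// Constants copied from the diff; names shortened for the sketch.
constexpr uint32_t kWgSize       = 256; // WEBGPU_MUL_MAT_VEC_WG_SIZE
constexpr uint32_t kOutputsPerWg = 4;   // *_OUTPUTS_PER_WG after this commit

int main() {
    constexpr uint32_t threads_per_output = kWgSize / kOutputsPerWg; // 64 threads reduce one row
    const uint32_t rows       = 4096;                                // hypothetical matrix height
    const uint32_t workgroups = (rows + kOutputsPerWg - 1) / kOutputsPerWg;
    std::printf("threads per output: %u, workgroups for %u rows: %u\n",
                threads_per_output, rows, workgroups);
    return 0;
}
```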
@@ -78,6 +69,7 @@ struct ggml_webgpu_shader_lib_context {
     bool inplace = false;
     bool overlap = false;
     bool src_overlap = false;
+    bool supports_subgroups = false;
     bool supports_subgroup_matrix = false;
     uint32_t sg_mat_m = 0;
     uint32_t sg_mat_n = 0;
@@ -575,7 +567,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
 struct ggml_webgpu_mul_mat_vec_shader_decisions {
     uint32_t wg_size;
-    uint32_t tile_k;
     uint32_t outputs_per_wg;
     uint32_t vec_size;
 };
@@ -1326,7 +1317,7 @@ class ggml_webgpu_shader_lib {
     ggml_webgpu_mul_mat_vec_pipeline_key key = {};
     key.src0_type  = context.src0->type;
     key.src1_type  = context.src1->type;
-    key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
+    key.vectorized = (context.src0->ne[0] % 4 == 0 &&
                       (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
                          1 :
                          0;
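
The vectorization key no longer checks the output row count (dst->ne[0]); only the reduction dimension of src0 must be divisible by 4. A standalone sketch of the relaxed predicate, with stand-in types in place of ggml's tensors:

```cpp
#include <cstdint>

// Stand-ins for ggml_type and ggml_tensor; only the shape of the check matters.
enum class Type { F32, F16, Q4_0 };

struct TensorLite {
    Type    type;
    int64_t ne0; // innermost (reduction) dimension
};

// After this change only src0's reduction dimension must be a multiple of 4;
// dst->ne[0] no longer has to be.
static bool use_vectorized_path(const TensorLite & src0) {
    const bool is_float = (src0.type == Type::F32 || src0.type == Type::F16);
    return is_float && (src0.ne0 % 4 == 0); // loads/accumulates proceed as vec4
}
```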
@@ -1337,7 +1328,8 @@
     }
     std::vector<std::string> defines;
-    std::string variant = "mul_mat_vec";
+    std::string  variant    = "mul_mat_vec";
+    const char * shader_src = wgsl_mul_mat_vec;
     // src0 type (matrix row)
     switch (context.src0->type) {
@@ -1386,25 +1378,25 @@
     defines.push_back(key.vectorized ? "VEC" : "SCALAR");
     uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;
-    uint32_t tile_k         = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
     uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
     if (key.src0_type >= GGML_TYPE_Q2_K) {
-        tile_k         = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
         outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
     } else if (key.src0_type >= GGML_TYPE_Q4_0) {
-        tile_k         = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
         outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
     }
     defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-    defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
     defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
+    defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
+    variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
+    if (key.vectorized) {
+        variant += "_vectorized";
+    }
-    auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
+    auto processed = preprocessor.preprocess(shader_src, defines);
     auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
     decisions->wg_size        = wg_size;
-    decisions->tile_k         = tile_k;
     decisions->outputs_per_wg = outputs_per_wg;
     decisions->vec_size       = key.vectorized ? 4 : 1;
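
Pipeline variants are now named by reduction strategy as well as vectorization. A short sketch reproducing the naming scheme visible above (the helper function is ours; the real code builds the string inline):

```cpp
#include <string>

// Composes the shader variant name the same way the diff does.
static std::string mul_mat_vec_variant(bool supports_subgroups, bool vectorized) {
    std::string variant = "mul_mat_vec";
    variant += supports_subgroups ? "_sg_reduce" : "_wg_reduce";
    if (vectorized) {
        variant += "_vectorized";
    }
    return variant; // e.g. "mul_mat_vec_sg_reduce_vectorized"
}
```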


@@ -181,6 +181,7 @@ struct webgpu_dispatch_desc {
 struct webgpu_capabilities {
     wgpu::Limits limits;
+    bool supports_subgroups = false;
     bool supports_subgroup_matrix = false;
     uint32_t sg_mat_m = 0;
@@ -1164,14 +1165,11 @@
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q6_K:
-            use_fast = true;
-            break;
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
-            // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
-            use_fast = !is_vec;
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q2_K:
+            use_fast = true;
             break;
         default:
             break;
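
The reordered switch means the k-quant types no longer fall back to the slow path for matrix-vector products, so the decision no longer depends on is_vec for them. A condensed sketch of the resulting predicate, with a stand-in enum in place of ggml_type:

```cpp
// Stand-in for ggml_type; the point is that Q2_K..Q5_K now join the fast path
// unconditionally, matching the new mat-vec kernels added in this commit.
enum class QType { Q8_0, Q8_1, Q6_K, Q4_K, Q5_K, Q3_K, Q2_K, Other };

static bool use_fast_path(QType t) {
    switch (t) {
        case QType::Q8_0:
        case QType::Q8_1:
        case QType::Q6_K:
        case QType::Q4_K:
        case QType::Q5_K:
        case QType::Q3_K:
        case QType::Q2_K:
            return true;
        default:
            return false;
    }
}
```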
@@ -1182,10 +1180,12 @@
     }
     ggml_webgpu_shader_lib_context shader_lib_ctx = {};
-    shader_lib_ctx.src0 = src0;
-    shader_lib_ctx.src1 = src1;
-    shader_lib_ctx.dst  = dst;
+    shader_lib_ctx.src0                     = src0;
+    shader_lib_ctx.src1                     = src1;
+    shader_lib_ctx.dst                      = dst;
     shader_lib_ctx.max_wg_size              = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+    shader_lib_ctx.supports_subgroups       = ctx->global_ctx->capabilities.supports_subgroups;
     shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
     shader_lib_ctx.sg_mat_m                 = ctx->global_ctx->capabilities.sg_mat_m;
     shader_lib_ctx.sg_mat_n                 = ctx->global_ctx->capabilities.sg_mat_n;
@@ -1287,7 +1287,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
     shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
     // Get or create pipeline
-    webgpu_pipeline gather_pipeline, main_pipeline;
+    webgpu_pipeline gather_pipeline;
+    webgpu_pipeline main_pipeline;
     std::vector<webgpu_dispatch_desc> dispatches;
@@ -3040,6 +3041,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
     // we require f16 support
     GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
+    ctx->webgpu_global_ctx->capabilities.supports_subgroups =
+        ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
 #ifndef __EMSCRIPTEN__
     // Accept f16 subgroup matrix configurations (square or non-square).
@@ -3072,11 +3075,14 @@
 #ifndef __EMSCRIPTEN__
     required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
     if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
-        required_features.push_back(wgpu::FeatureName::Subgroups);
         required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
     }
 #endif
+    if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) {
+        required_features.push_back(wgpu::FeatureName::Subgroups);
+    }
 #ifdef GGML_WEBGPU_GPU_PROFILE
     required_features.push_back(wgpu::FeatureName::TimestampQuery);
 #endif
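
Requesting Subgroups now keys off the adapter capability rather than the subgroup-matrix branch. A condensed sketch of the new flow, assuming a stand-in enum in place of wgpu::FeatureName:

```cpp
#include <vector>

// Stand-in for wgpu::FeatureName; shows only the reordered request logic.
enum class Feature { Subgroups, ChromiumExperimentalSubgroupMatrix };

static std::vector<Feature> build_required_features(bool subgroups, bool subgroup_matrix) {
    std::vector<Feature> feats;
    if (subgroup_matrix) {
        feats.push_back(Feature::ChromiumExperimentalSubgroupMatrix);
    }
    if (subgroups) { // presumably true whenever subgroup_matrix is
        feats.push_back(Feature::Subgroups);
    }
    return feats;
}
```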


@@ -45,6 +45,13 @@ fn load_u16_at_src0(byte_offset: u32) -> u32 {
     return (word >> shift) & 0xFFFFu;
 }
+
+// Always reads the 4-byte-aligned word containing byte_offset.
+// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u.
+// This is used in the k-quant paths for better performance.
+fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 {
+    return src0[(byte_offset & ~3u) / 4u];
+}
 fn load_u32_at_src0(byte_offset: u32) -> u32 {
     let word_idx = byte_offset / 4u;
     let shift = (byte_offset & 0x3u) * 8u;
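
The new helper fetches the 4-byte-aligned word containing byte_offset and leaves the 16-bit extraction to the caller, avoiding a data-dependent shift on every load. A C++ rendering of the same trick (function names are ours, not the shader's):

```cpp
#include <cstdint>
#include <vector>

// Same expression as the WGSL helper: mask off the low two bits of the byte
// offset, then index the u32 buffer by word.
static uint32_t load_u32_aligned(const std::vector<uint32_t> & src0, uint32_t byte_offset) {
    return src0[(byte_offset & ~3u) / 4u];
}

// Bit 1 of byte_offset selects the low or high 16-bit half of the aligned word.
static uint32_t extract_u16(uint32_t word, uint32_t byte_offset) {
    return (byte_offset & 2u) ? (word >> 16u) : (word & 0xFFFFu);
}
```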

File diff suppressed because it is too large.