mirror of
https://github.com/ggerganov/llama.cpp
synced 2026-04-26 05:21:53 +02:00
ggml-webgpu: updated matrix-vector multiplication (#21738)
* merged properly, but slow q3_k and q5_k with u32 indexing * Start on new mat-vec * New format float paths working * Working q4_0 * Work on remaining legacy q-types * port k-quants to new matvec * remove old shader * Remove old constants, format * remove accidental file --------- Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local> Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
parent
a678916623
commit
a6cc43c286
@ -44,18 +44,9 @@
|
||||
// Matrix-vector multiplication parameters
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
|
||||
// Must be multiple of 4 to work with vectorized paths, and must divide
|
||||
// mul_mat_vec wg size
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
|
||||
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
|
||||
|
||||
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
|
||||
// Requires at least two (and multiple of 2) k-quant blocks per tile
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4
|
||||
|
||||
// default size for legacy matrix multiplication
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
@ -78,6 +69,7 @@ struct ggml_webgpu_shader_lib_context {
|
||||
bool inplace = false;
|
||||
bool overlap = false;
|
||||
bool src_overlap = false;
|
||||
bool supports_subgroups = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
uint32_t sg_mat_m = 0;
|
||||
uint32_t sg_mat_n = 0;
|
||||
@ -575,7 +567,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
|
||||
struct ggml_webgpu_mul_mat_vec_shader_decisions {
|
||||
uint32_t wg_size;
|
||||
uint32_t tile_k;
|
||||
uint32_t outputs_per_wg;
|
||||
uint32_t vec_size;
|
||||
};
|
||||
@ -1326,7 +1317,7 @@ class ggml_webgpu_shader_lib {
|
||||
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
@ -1337,7 +1328,8 @@ class ggml_webgpu_shader_lib {
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "mul_mat_vec";
|
||||
std::string variant = "mul_mat_vec";
|
||||
const char * shader_src = wgsl_mul_mat_vec;
|
||||
|
||||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
@ -1386,25 +1378,25 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
||||
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
||||
|
||||
if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
|
||||
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
||||
defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION");
|
||||
variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce";
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
decisions->tile_k = tile_k;
|
||||
decisions->outputs_per_wg = outputs_per_wg;
|
||||
decisions->vec_size = key.vectorized ? 4 : 1;
|
||||
|
||||
|
||||
@ -181,6 +181,7 @@ struct webgpu_dispatch_desc {
|
||||
|
||||
struct webgpu_capabilities {
|
||||
wgpu::Limits limits;
|
||||
bool supports_subgroups = false;
|
||||
bool supports_subgroup_matrix = false;
|
||||
|
||||
uint32_t sg_mat_m = 0;
|
||||
@ -1164,14 +1165,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_Q6_K:
|
||||
use_fast = true;
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
|
||||
use_fast = !is_vec;
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q2_K:
|
||||
use_fast = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@ -1182,10 +1180,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
}
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
@ -1287,7 +1287,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
|
||||
// Get or create pipeline
|
||||
webgpu_pipeline gather_pipeline, main_pipeline;
|
||||
webgpu_pipeline gather_pipeline;
|
||||
webgpu_pipeline main_pipeline;
|
||||
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
@ -3040,6 +3041,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
||||
// we require f16 support
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
|
||||
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
// Accept f16 subgroup matrix configurations (square or non-square).
|
||||
@ -3072,11 +3075,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
#ifndef __EMSCRIPTEN__
|
||||
required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
|
||||
if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
||||
required_features.push_back(wgpu::FeatureName::Subgroups);
|
||||
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) {
|
||||
required_features.push_back(wgpu::FeatureName::Subgroups);
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
required_features.push_back(wgpu::FeatureName::TimestampQuery);
|
||||
#endif
|
||||
|
||||
@ -45,6 +45,13 @@ fn load_u16_at_src0(byte_offset: u32) -> u32 {
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
}
|
||||
|
||||
// Always reads the 4-byte-aligned word containing byte_offset.
|
||||
// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u.
|
||||
// this is used in k-quants for better performance
|
||||
fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 {
|
||||
return src0[(byte_offset & ~3u) / 4u];
|
||||
}
|
||||
|
||||
fn load_u32_at_src0(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 0x3u) * 8u;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user