From 225088ea76687c005e282d3a48d73e9c0c8c5091 Mon Sep 17 00:00:00 2001
From: Akarshan Biswas
Date: Wed, 22 Apr 2026 18:02:56 +0530
Subject: [PATCH] sycl: Improve mul_mat_id memory efficiency and add BF16 fast
 path (#22119)

* sycl: size mul_mat_id staging buffers by routed rows

Previously src1_contiguous/dst_contiguous in ggml_sycl_mul_mat_id were
sized to ggml_nelements(src1/dst), which over-allocates when ne12 > 1 and
can fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero for MoE
models (notably with --cpu-moe). Size them by the actual number of routed
rows (ids->ne[1] * n_ids) instead.

* sycl: add bf16 mul_mat fast path via DNNL

When src0 is BF16 (commonly the case for lm_head / output.weight), the
existing f16 path is skipped because bf16 isn't covered, and the f32
fallback dequantizes the entire src0 slab to f32 in a single pool alloc
(row_diff*ne00 floats). For large-vocab models this can reach several GB
and fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero.

Add a bf16xbf16 -> f32 DNNL matmul fast path that uses the bf16 storage
in place and only materializes a small src1 bf16 conversion buffer. bf16
matmul accumulates in f32, so it's correct even when the op requests
GGML_PREC_F32 (as lm_head does).

- gemm.hpp: map bfloat16 to dnnl::memory::data_type::bf16.
- convert.{hpp,cpp}: expose ggml_get_to_bf16_sycl for f32/f16/bf16 -> bf16.
- ggml-sycl.cpp: take the bf16 path early in ggml_sycl_op_mul_mat_sycl
  when DNNL and GGML_SYCL_HAS_BF16 are both available.
---
 ggml/src/ggml-sycl/common.hpp    |  7 +++++++
 ggml/src/ggml-sycl/convert.cpp   | 23 ++++++++++++++++-------
 ggml/src/ggml-sycl/convert.hpp   |  9 +++++++++
 ggml/src/ggml-sycl/gemm.hpp      |  3 +++
 ggml/src/ggml-sycl/ggml-sycl.cpp | 30 ++++++++++++++++++++++++++++--
 ggml/src/ggml-sycl/set_rows.cpp  |  8 +++++++-
 6 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index fd84c91785..0101b27640 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -28,6 +28,13 @@
 
 namespace syclexp = sycl::ext::oneapi::experimental;
 
+#if defined(__INTEL_LLVM_COMPILER) && __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
+    #include <sycl/ext/oneapi/bfloat16.hpp>
+    #ifndef GGML_SYCL_HAS_BF16
+        #define GGML_SYCL_HAS_BF16
+    #endif
+#endif
+
 #if GGML_SYCL_DNNL
 #include "dnnl.hpp"
 #include "dnnl_sycl.hpp"
diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp
index f3c521b45f..67b9c06f3e 100644
--- a/ggml/src/ggml-sycl/convert.cpp
+++ b/ggml/src/ggml-sycl/convert.cpp
@@ -2,13 +2,6 @@
 #include "dequantize.hpp"
 #include "presets.hpp"
 
-#if defined(__INTEL_LLVM_COMPILER)
-    #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
-        #include <sycl/ext/oneapi/bfloat16.hpp>
-        #define GGML_SYCL_HAS_BF16
-    #endif
-#endif
-
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
                              const sycl::nd_item<3> &item_ct1) {
@@ -767,6 +760,22 @@
 
 }
 
+#ifdef GGML_SYCL_HAS_BF16
+to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_sycl<float>;
+        case GGML_TYPE_F16:
+            return convert_unary_sycl<sycl::half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
+        default:
+            GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
+            return nullptr;
+    }
+}
+#endif
+
 to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp
index 6e621f2154..8de79d10ff 100644
--- a/ggml/src/ggml-sycl/convert.hpp
+++ b/ggml/src/ggml-sycl/convert.hpp
@@ -23,6 +23,11 @@ typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
 to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst);
 to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst);
 
+#ifdef GGML_SYCL_HAS_BF16
+typedef to_t_sycl_t<sycl::ext::oneapi::bfloat16> to_bf16_sycl_t;
+to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst);
+#endif
+
 // Nc = Non-contiguous
 template <typename T>
 using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
@@ -35,15 +40,19 @@
 template <typename dst_t, typename src_t>
 inline dst_t ggml_sycl_cast(src_t x) {
     if constexpr (std::is_same_v<dst_t, src_t>) {
         return x;
+#ifdef GGML_SYCL_HAS_BF16
     } else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) {
         return sycl::ext::oneapi::bfloat16(float(x));
     } else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
         return static_cast<dst_t>(x);
+#endif
    } else if constexpr (std::is_same_v<src_t, sycl::half2> && std::is_same_v<dst_t, sycl::float2>) {
        return x.template convert<float>();
+#ifdef GGML_SYCL_HAS_BF16
    } else if constexpr (std::is_same_v<src_t, sycl::half2> && std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
        return {x.x, x.y};
+#endif
    } else if constexpr(std::is_same_v<dst_t, int32_t>) {
        return int32_t(x);
    } else {
diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
index dcf6c7aeeb..c202da110b 100644
--- a/ggml/src/ggml-sycl/gemm.hpp
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -29,6 +29,9 @@ public:
     static constexpr dt to_dt() {
         if constexpr (std::is_same_v<T, float>) return dt::f32;
         else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+#ifdef GGML_SYCL_HAS_BF16
+        else if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) return dt::bf16;
+#endif
         else static_assert(0);
     }
 
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index c02a41ad86..3829da8790 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2176,6 +2176,31 @@ inline void ggml_sycl_op_mul_mat_sycl(
 #else
     bool use_fp16 = false;
 #endif
+
+#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16)
+    // Fast path for bf16 src0
+    if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) &&
+        row_diff == src0->ne[1]) {
+        using bf16_t = sycl::ext::oneapi::bfloat16;
+        ggml_sycl_pool_alloc<bf16_t> src1_as_bf16(ctx.pool(), src1_ncols*ne10);
+        if (src1->type != GGML_TYPE_BF16) {
+            const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst);
+            GGML_ASSERT(to_bf16_sycl != nullptr);
+            to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream);
+        } else {
+            stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t));
+        }
+        DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10,
+                                  src0_dd_i, DnnlGemmWrapper::to_dt<bf16_t>(),
+                                  src1_as_bf16.get(), DnnlGemmWrapper::to_dt<bf16_t>(),
+                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        GGML_UNUSED(dst);
+        GGML_UNUSED(src1_ddq_i);
+        GGML_UNUSED(src1_padded_row_size);
+        return;
+    }
+#endif
+
     if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) &&
         row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
@@ -3848,8 +3873,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
             }
         }
     } else {
-        ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-        ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+        const int64_t n_routed_rows = ids->ne[1] * n_ids;
+        ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10);
+        ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0);
 
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();
diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp
index a641c10091..8fb4194352 100644
--- a/ggml/src/ggml-sycl/set_rows.cpp
+++ b/ggml/src/ggml-sycl/set_rows.cpp
@@ -4,7 +4,11 @@ namespace utils {
 
 template <typename T>
 static constexpr bool is_arithmetic_v() {
-    return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
+    return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half>
+#ifdef GGML_SYCL_HAS_BF16
+        || std::is_same_v<T, sycl::ext::oneapi::bfloat16>
+#endif
+        ;
 }
 
 }
@@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
                 stream
             );
             break;
+#ifdef GGML_SYCL_HAS_BF16
         case GGML_TYPE_BF16:
            set_rows_sycl<sycl::ext::oneapi::bfloat16>(
                src0_d, src1_d, (char *)dst->data,
@@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
                 stream
             );
             break;
+#endif
         case GGML_TYPE_Q8_0:
             set_rows_sycl_q(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
             break;