It works now, but the performance gain is very minor

This commit is contained in:
Kawrakow 2026-02-06 12:55:11 +00:00
parent c1b8ef5b09
commit 927593d424
2 changed files with 29 additions and 7 deletions

View File

@ -82,6 +82,16 @@ static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float
dst[i] = x[row] * y[i] / (1.0f + expf(-x[row]));
}
// Fused multiply-by-sigmoid-gate kernel: dst[i] = y[i] * sigmoid(x[i / ne0]).
// x holds one gate value per row of ne0 contiguous elements; y and dst are
// flat arrays of k elements. Launch with one thread per output element;
// threads past k exit immediately, so any grid that covers k is valid.
static __global__ void fused_mul_sigmoid_f32(int ne0, const float * x, const float * y, float * dst, const int k) {
    const int idx = blockDim.x*blockIdx.x + threadIdx.x;
    if (idx >= k) {
        return;
    }
    // all ne0 elements of a row share the same gate input x[gate_row]
    const int gate_row = idx / ne0;
    // y / (1 + e^-x) kept verbatim (not y * sigmoid) to preserve rounding
    dst[idx] = y[idx] / (1.0f + expf(-x[gate_row]));
}
static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float * y, float * dst, const int k, float limit) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -280,7 +290,6 @@ static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * ds
}
static void fused_mul_silu_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, float limit, cudaStream_t stream) {
//printf("%s: ne0 = %d, nelem = %d limit = %g\n", __func__, ne0, k, limit);
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
if (limit < 1e-6f) {
fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k);
@ -289,6 +298,11 @@ static void fused_mul_silu_f32_cuda(int ne0, const float * x, const float * y, f
}
}
// Host-side launcher for fused_mul_sigmoid_f32: one thread per element,
// CUDA_SILU_BLOCK_SIZE threads per block, ceil(k / block) blocks so the
// whole k-element range is covered (zero blocks when k == 0).
static void fused_mul_sigmoid_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
    const int nblocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
    fused_mul_sigmoid_f32<<<nblocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k);
}
// Host-side launcher for the fused mul+relu kernel over k elements.
// NOTE(review): the grid size is computed from CUDA_RELU_BLOCK_SIZE but the
// launch uses CUDA_SILU_BLOCK_SIZE threads per block. If the two constants
// differ and CUDA_SILU_BLOCK_SIZE < CUDA_RELU_BLOCK_SIZE, the tail of the
// range is never processed — confirm the constants are equal or use the same
// constant in both places.
static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
fused_mul_relu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
@ -445,9 +459,16 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor *
GGML_ASSERT(ggml_are_same_shape(src1, dst));
if (!ggml_are_same_shape(src0, src1)) {
GGML_ASSERT(src0->ne[0] == 1 && src0->ne[1] == src1->ne[1] && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3]);
GGML_ASSERT(op == GGML_UNARY_OP_SILU);
fused_mul_silu_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data,
ggml_nelements(dst), limit, ctx.stream());
if (op == GGML_UNARY_OP_SILU) {
fused_mul_silu_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data,
ggml_nelements(dst), limit, ctx.stream());
}
else if (op == GGML_UNARY_OP_SIGMOID) {
fused_mul_sigmoid_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data,
ggml_nelements(dst), ctx.stream());
} else {
GGML_ABORT("Fatal error");
}
return;
}
GGML_ASSERT(ggml_are_same_shape(src0, src1));

View File

@ -6474,9 +6474,9 @@ static struct ggml_tensor * ggml_fused_mul_unary_impl(
bool inplace) {
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
if (!ggml_are_same_shape(b, a)) {
GGML_ASSERT(a->ne[0] == 1 && a->ne[1] == b->ne[1] && a->ne[2] == b->ne[2] && a->ne[3] == b->ne[3]);
GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID);
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, b) : ggml_dup_tensor(ctx, b);
ggml_set_op_params_i32(result, 0, (int32_t) op);
result->op = GGML_OP_FUSED_MUL_UNARY;
@ -6484,6 +6484,7 @@ static struct ggml_tensor * ggml_fused_mul_unary_impl(
result->src[1] = b;
return result;
}
GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
//GGML_ASSERT(ggml_are_same_shape(b, a));
bool is_node = false;
@ -15183,12 +15184,12 @@ static void ggml_compute_forward_fused_mul_unary_f32(
if (!ggml_are_same_shape(src0, src1)) {
GGML_ASSERT(src0->ne[0] == 1 && ggml_nrows(src0) == nr);
GGML_ASSERT(op == GGML_UNARY_OP_SILU);
GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID);
for (int i1 = ir0; i1 < ir1; i1++) {
float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1]));
float gate = ggml_silu_f32(x[0]);
float gate = op == GGML_UNARY_OP_SILU ? ggml_silu_f32(x[0]) : 1.0f/(1.0f + expf(-x[0]));
if (limit < 1e-6f) {
for (int i = 0; i < nc; ++i) z[i] = gate * y[i];
} else {