diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index a33ec5b7..b73a46db 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -82,6 +82,16 @@ static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float dst[i] = x[row] * y[i] / (1.0f + expf(-x[row])); } +static __global__ void fused_mul_sigmoid_f32(int ne0, const float * x, const float * y, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + int row = i / ne0; + dst[i] = y[i] / (1.0f + expf(-x[row])); +} + static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float * y, float * dst, const int k, float limit) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -280,7 +290,6 @@ static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * ds } static void fused_mul_silu_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, float limit, cudaStream_t stream) { - //printf("%s: ne0 = %d, nelem = %d limit = %g\n", __func__, ne0, k, limit); const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; if (limit < 1e-6f) { fused_mul_silu_f32<<>>(ne0, x, y, dst, k); @@ -289,6 +298,11 @@ static void fused_mul_silu_f32_cuda(int ne0, const float * x, const float * y, f } } +static void fused_mul_sigmoid_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + fused_mul_sigmoid_f32<<>>(ne0, x, y, dst, k); +} + static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; fused_mul_relu_f32<<>>(x, y, dst, k); @@ -445,9 +459,16 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_ASSERT(ggml_are_same_shape(src1, dst)); if (!ggml_are_same_shape(src0, src1)) { GGML_ASSERT(src0->ne[0] == 1 && src0->ne[1] == src1->ne[1] && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3]); - GGML_ASSERT(op == GGML_UNARY_OP_SILU); - fused_mul_silu_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data, - ggml_nelements(dst), limit, ctx.stream()); + if (op == GGML_UNARY_OP_SILU) { + fused_mul_silu_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data, + ggml_nelements(dst), limit, ctx.stream()); + } + else if (op == GGML_UNARY_OP_SIGMOID) { + fused_mul_sigmoid_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data, + ggml_nelements(dst), ctx.stream()); + } else { + GGML_ABORT("Fatal error"); + } return; } GGML_ASSERT(ggml_are_same_shape(src0, src1)); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ee89086e..b899678c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6474,9 +6474,9 @@ static struct ggml_tensor * ggml_fused_mul_unary_impl( bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); if (!ggml_are_same_shape(b, a)) { GGML_ASSERT(a->ne[0] == 1 && a->ne[1] == b->ne[1] && a->ne[2] == b->ne[2] && a->ne[3] == b->ne[3]); + GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, b) : ggml_dup_tensor(ctx, b); ggml_set_op_params_i32(result, 0, (int32_t) op); result->op = GGML_OP_FUSED_MUL_UNARY; @@ -6484,6 +6484,7 @@ static struct ggml_tensor * ggml_fused_mul_unary_impl( result->src[1] = b; return result; } + GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); //GGML_ASSERT(ggml_are_same_shape(b, a)); bool is_node = false; @@ -15183,12 +15184,12 @@ static void ggml_compute_forward_fused_mul_unary_f32( if (!ggml_are_same_shape(src0, src1)) { GGML_ASSERT(src0->ne[0] == 1 && ggml_nrows(src0) == nr); - GGML_ASSERT(op == GGML_UNARY_OP_SILU); + GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID); for (int i1 = ir0; i1 < ir1; i1++) { float * z = (float *) ((char *) dst->data + i1*( dst->nb[1])); const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1])); const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1])); - float gate = ggml_silu_f32(x[0]); + float gate = op == GGML_UNARY_OP_SILU ? ggml_silu_f32(x[0]) : 1.0f/(1.0f + expf(-x[0])); if (limit < 1e-6f) { for (int i = 0; i < nc; ++i) z[i] = gate * y[i]; } else {