|
|
|
|
@ -10,7 +10,7 @@
|
|
|
|
|
|
|
|
|
|
It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
|
|
|
|
|
*/
|
|
|
|
|
template <size_t n_experts>
|
|
|
|
|
template <size_t n_experts, bool normalize>
|
|
|
|
|
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
|
|
|
|
float * weights,
|
|
|
|
|
int32_t * ids,
|
|
|
|
|
@ -58,7 +58,6 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|
|
|
|
tmp = warp_reduce_sum(tmp);
|
|
|
|
|
|
|
|
|
|
const float inv_sum = 1.0f / tmp;
|
|
|
|
|
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int i = 0; i < experts_per_thread; i++) {
|
|
|
|
|
wt[i] = wt[i] * inv_sum;
|
|
|
|
|
@ -68,6 +67,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|
|
|
|
//we do the argmax reduce over n_expert_used, each time marking
|
|
|
|
|
//the expert weight as -inf to exclude from the next iteration
|
|
|
|
|
|
|
|
|
|
[[maybe_unused]] float sum_selected = 0;
|
|
|
|
|
for (int k = 0; k < n_expert_used; k++) {
|
|
|
|
|
float max_val = wt[0];
|
|
|
|
|
int max_expert = threadIdx.x;
|
|
|
|
|
@ -91,6 +91,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sum_selected += max_val;
|
|
|
|
|
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
|
|
|
|
wt[max_expert / WARP_SIZE] = -INFINITY;
|
|
|
|
|
|
|
|
|
|
@ -98,8 +99,19 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|
|
|
|
ids[k] = max_expert;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!normalize) return;
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
float norm = 1/sum_selected;
|
|
|
|
|
for (int k = threadIdx.x; k < n_expert_used; k += WARP_SIZE) {
|
|
|
|
|
weights[k] *= norm;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <bool normalize>
|
|
|
|
|
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
|
|
|
|
const float * logits,
|
|
|
|
|
float * weights,
|
|
|
|
|
@ -114,34 +126,34 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
|
|
|
|
|
|
|
|
|
switch (n_expert) {
|
|
|
|
|
case 1:
|
|
|
|
|
topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<1, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 2:
|
|
|
|
|
topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<2, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 4:
|
|
|
|
|
topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<4, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 8:
|
|
|
|
|
topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<8, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 16:
|
|
|
|
|
topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<16, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 32:
|
|
|
|
|
topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<32, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 64:
|
|
|
|
|
topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<64, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 128:
|
|
|
|
|
topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<128, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 256:
|
|
|
|
|
topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<256, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
case 512:
|
|
|
|
|
topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
topk_moe_cuda<512, normalize><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
GGML_ASSERT(false && "fatal error");
|
|
|
|
|
@ -168,9 +180,13 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|
|
|
|
|
|
|
|
|
cudaStream_t stream = ctx.stream();
|
|
|
|
|
|
|
|
|
|
const int n_expert_used = weights->ne[1];
|
|
|
|
|
|
|
|
|
|
launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
|
|
|
|
if (weights->op == GGML_OP_DIV) {
|
|
|
|
|
const int n_expert_used = weights->ne[0];
|
|
|
|
|
launch_topk_moe_cuda<true >(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
|
|
|
|
} else {
|
|
|
|
|
const int n_expert_used = weights->ne[1];
|
|
|
|
|
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
|
|
|
|
|
|