diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index 4d6556dd..796d9c7b 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -143,19 +143,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const int chunks_per_row = D2 / h2_per_chunk;
 
+        int k0_start = 0;
 #pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) {
-            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+        for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) {
             const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
 
             if (k0_start == k0_stop) {
                 continue;
             }
 
+            const int stride_i = WARP_SIZE / stride_k;
#pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k;
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -168,6 +168,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
                     cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
                 }
             }
+            k0_start = k0_stop;
         }
     } else {
         static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
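The patch replaces the initializer-list loop over stride_k values with a halving loop and carries k0_start forward from the previous iteration's k0_stop instead of recomputing it with a ternary. The host-side sketch below (not part of the patch) checks that both formulations produce the same [k0_start, k0_stop) chunk ranges per granularity level; WARP_SIZE = 32 and the example chunks_per_row values are assumptions chosen only for illustration.

// Host-side sketch, not part of the patch: verifies that carrying k0_start
// forward from the previous k0_stop is equivalent to the old ternary-based
// computation. WARP_SIZE and the chunks_per_row values are illustrative.
#include <cstdio>
#include <utility>
#include <vector>

constexpr int WARP_SIZE = 32;

int main() {
    for (int chunks_per_row : {8, 20, 40, 72}) {  // hypothetical D2 / h2_per_chunk values
        std::vector<std::pair<int, int>> old_ranges, new_ranges;

        // Old loop: k0_start recomputed per stride with a ternary.
        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4, WARP_SIZE/8, WARP_SIZE/16}) {
            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
            const int k0_stop  = chunks_per_row - chunks_per_row % (1*stride_k);
            if (k0_start == k0_stop) {
                continue;
            }
            old_ranges.push_back({k0_start, k0_stop});
        }

        // New loop: stride_k halved each iteration, k0_start carried forward.
        int k0_start = 0;
        for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) {
            const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
            if (k0_start == k0_stop) {
                continue;
            }
            new_ranges.push_back({k0_start, k0_stop});
            k0_start = k0_stop;
        }

        printf("chunks_per_row=%2d: %s\n", chunks_per_row,
               old_ranges == new_ranges ? "ranges match" : "ranges DIFFER");
    }
    return 0;
}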