From ff76c6731d9a72b7ea1dbc034020c70330a37ff8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 27 Mar 2026 14:39:14 +0200
Subject: [PATCH] cont : cache shift support

---
 src/llama-kv-cache.cpp          | 55 ++++++++++++++++++++++++++++++++-
 src/llama-kv-cache.h            |  1 +
 tools/completion/completion.cpp |  2 +-
 3 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 5936a91ddb..632a92a304 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1545,11 +1545,44 @@ size_t llama_kv_cache::size_v_bytes() const {
     return size_v_bytes;
 }
 
+static ggml_tensor * ggml_rotate_hadamard(
+        ggml_context * ctx,
+         ggml_tensor * cur,
+         ggml_tensor * rot) {
+    const auto n = rot->ne[0];
+
+    ggml_tensor * res;
+    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+    res = ggml_mul_mat(ctx, rot, res);
+    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+
+    return res;
+}
+
+static void set_input_hadamard(int n, float * data) {
+    assert(ggml_is_power_of_2(n));
+
+    data[0*n + 0] = 1.0 / sqrtf(n);
+
+    for (int s = 1; s < n; s *= 2) {
+        for (int i = 0; i < s; i++) {
+            for (int j = 0; j < s; j++) {
+                const float val = data[i*n + j];
+
+                data[(i + s)*n + (j    )] =  val;
+                data[(i    )*n + (j + s)] =  val;
+                data[(i + s)*n + (j + s)] = -val;
+            }
+        }
+    }
+}
+
 ggml_tensor * llama_kv_cache::build_rope_shift(
         const llama_cparams & cparams,
                ggml_context * ctx,
                 ggml_tensor * cur,
                 ggml_tensor * shift,
+                ggml_tensor * rotk,
                 ggml_tensor * factors,
                       float   freq_base,
                       float   freq_scale,
@@ -1575,10 +1608,14 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
         // dequantize to f32 -> RoPE -> quantize back
         tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
 
+        tmp = ggml_rotate_hadamard(ctx, tmp, rotk);
+
         tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base,
                 freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
+        tmp = ggml_rotate_hadamard(ctx, tmp, rotk);
+
         tmp = ggml_cpy(ctx, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
@@ -1599,6 +1636,8 @@ public:
 
     ggml_tensor * k_shift; // I32 [kv_size*n_stream]
 
+    ggml_tensor * rotk = nullptr;
+
     const llama_kv_cache * kv_self;
 };
 
@@ -1608,6 +1647,14 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     if (k_shift) {
        kv_self->set_input_k_shift(k_shift);
     }
+
+    if (rotk) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(rotk->buffer));
+
+        float * data = (float *) rotk->data;
+
+        set_input_hadamard(rotk->ne[0], data);
+    }
 }
 
 ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
@@ -1619,6 +1666,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
     ggml_set_input(inp->k_shift);
 
+    if (ggml_is_quantized(type_k())) {
+        int nrot = hparams.n_embd_head_k();
+        inp->rotk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
+        ggml_set_input(inp->rotk);
+    }
+
     const auto & cparams = lctx->get_cparams();
 
     for (const auto & layer : layers) {
@@ -1643,7 +1696,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 ggml_row_size(layer.k->type, n_embd_nope));
 
-        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->rotk, rope_factors, freq_base_l, freq_scale_l, il);
 
         ggml_build_forward_expand(gf, cur);
     }
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 477701bb84..2d98050770 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -265,6 +265,7 @@ private:
             ggml_context * ctx,
             ggml_tensor * cur,
             ggml_tensor * shift,
+            ggml_tensor * rotk,
             ggml_tensor * factors,
                    float   freq_base,
                    float   freq_scale,
diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp
index 58d598fcc0..1151919104 100644
--- a/tools/completion/completion.cpp
+++ b/tools/completion/completion.cpp
@@ -628,7 +628,7 @@ int main(int argc, char ** argv) {
             const int n_left    = n_past - params.n_keep;
             const int n_discard = n_left/2;
 
-            LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+            LOG_WRN("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                     n_past, n_left, n_ctx, params.n_keep, n_discard);
 
            llama_memory_seq_rm (mem, 0, params.n_keep , params.n_keep + n_discard);
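
Note on the Hadamard helpers: set_input_hadamard() fills rotk with a normalized Walsh-Hadamard matrix built by the Sylvester doubling recurrence, and ggml_rotate_hadamard() applies it across the head dimension. Because the normalized matrix is symmetric and orthogonal (H*H == I), the same helper both undoes and re-applies the rotation, which is presumably why build_rope_shift() calls it once before and once after the RoPE when the quantized K cache has to be de-rotated, shifted, and rotated back. The following is a minimal standalone sketch (plain C++, no ggml dependency, not part of the patch) that builds the same matrix and checks the round trip; the vector contents and the choice n = 8 are arbitrary.

// sketch: same construction as set_input_hadamard(), plus a H*(H*x) == x check
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// build the n x n normalized Hadamard matrix via Sylvester doubling
static void build_hadamard(int n, std::vector<float> & h) {
    assert(n > 0 && (n & (n - 1)) == 0); // n must be a power of two
    h.assign((size_t) n*n, 0.0f);

    h[0] = 1.0f/std::sqrt((float) n); // scale so that H*H == I

    for (int s = 1; s < n; s *= 2) {
        for (int i = 0; i < s; i++) {
            for (int j = 0; j < s; j++) {
                const float val = h[i*n + j];

                h[(i + s)*n + (j    )] =  val;
                h[(i    )*n + (j + s)] =  val;
                h[(i + s)*n + (j + s)] = -val;
            }
        }
    }
}

int main() {
    const int n = 8; // stand-in for the head dimension, must be a power of two

    std::vector<float> h;
    build_hadamard(n, h);

    // rotate an arbitrary vector twice: y = H*x, z = H*y, expect z == x
    std::vector<float> x(n), y(n, 0.0f), z(n, 0.0f);
    for (int i = 0; i < n; i++) {
        x[i] = 0.5f*i - 1.0f;
    }

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            y[i] += h[i*n + j]*x[j];
        }
    }

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            z[i] += h[i*n + j]*y[j];
        }
    }

    for (int i = 0; i < n; i++) {
        printf("x[%d] = %8.4f  H*H*x[%d] = %8.4f\n", i, x[i], i, z[i]);
        assert(std::fabs(x[i] - z[i]) < 1e-5f);
    }

    return 0;
}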