cont : cache shift support

Georgi Gerganov 2026-03-27 14:39:14 +02:00
parent 7711b3a36a
commit ff76c6731d
3 changed files with 56 additions and 2 deletions

View File

@@ -1545,11 +1545,44 @@ size_t llama_kv_cache::size_v_bytes() const {
    return size_v_bytes;
}
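
// multiply the first dimension of cur (the head dimension) by the n x n matrix rot -
// here rot is the normalized Hadamard matrix built by set_input_hadamard below, which
// is symmetric and involutory (H == H^T, H*H == I), so the same call both un-rotates
// and re-rotates the data around the RoPE shift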
static ggml_tensor * ggml_rotate_hadamard(
        ggml_context * ctx,
         ggml_tensor * cur,
         ggml_tensor * rot) {
    const auto n = rot->ne[0];

    ggml_tensor * res;

    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
    res = ggml_mul_mat(ctx, rot, res);
    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);

    return res;
}
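
// fill data with the n x n Sylvester Hadamard matrix scaled by 1/sqrt(n), so that the
// rows are orthonormal; each doubling step expands the current s x s block B into
//   [ B  B ]
//   [ B -B ]
// e.g. for n = 4 the result is:
//   0.5 * [ 1  1  1  1 ]
//         [ 1 -1  1 -1 ]
//         [ 1  1 -1 -1 ]
//         [ 1 -1 -1  1 ]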
static void set_input_hadamard(int n, float * data) {
    assert(ggml_is_power_of_2(n));

    data[0*n + 0] = 1.0 / sqrtf(n);

    for (int s = 1; s < n; s *= 2) {
        for (int i = 0; i < s; i++) {
            for (int j = 0; j < s; j++) {
                const float val = data[i*n + j];

                data[(i + s)*n + (j    )] =  val;
                data[(i    )*n + (j + s)] =  val;
                data[(i + s)*n + (j + s)] = -val;
            }
        }
    }
}
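
// for reference: a minimal standalone check (a hypothetical test harness, not part of
// this commit) that the matrix built above is involutory, H*H == I - this property is
// what lets build_rope_shift reuse the same rotation to undo itself; build_hadamard
// below is a self-contained copy of the construction from set_input_hadamard

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// same Sylvester doubling as set_input_hadamard above
static void build_hadamard(int n, float * data) {
    data[0] = 1.0f/std::sqrt((float) n);

    for (int s = 1; s < n; s *= 2) {
        for (int i = 0; i < s; i++) {
            for (int j = 0; j < s; j++) {
                const float val = data[i*n + j];

                data[(i + s)*n + (j    )] =  val;
                data[(i    )*n + (j + s)] =  val;
                data[(i + s)*n + (j + s)] = -val;
            }
        }
    }
}

int main() {
    const int n = 128; // e.g. a typical n_embd_head_k

    std::vector<float> h(n*n);
    build_hadamard(n, h.data());

    // check H*H == I (H is symmetric, so this is also H*H^T)
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            float dot = 0.0f;
            for (int k = 0; k < n; k++) {
                dot += h[i*n + k]*h[k*n + j];
            }
            assert(std::fabs(dot - (i == j ? 1.0f : 0.0f)) < 1e-4f);
        }
    }

    printf("H*H == I for n = %d\n", n);

    return 0;
}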
ggml_tensor * llama_kv_cache::build_rope_shift(
        const llama_cparams & cparams,
               ggml_context * ctx,
                ggml_tensor * cur,
                ggml_tensor * shift,
                ggml_tensor * rotk,
                ggml_tensor * factors,
                      float   freq_base,
                      float   freq_scale,
@@ -1575,10 +1608,14 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
        // dequantize to f32 -> un-rotate (Hadamard) -> RoPE -> re-rotate -> quantize back
        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);

        tmp = ggml_rotate_hadamard(ctx, tmp, rotk);

        tmp = ggml_rope_ext(ctx, tmp,
                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

        tmp = ggml_rotate_hadamard(ctx, tmp, rotk);

        tmp = ggml_cpy(ctx, tmp, cur);
    } else {
        // we rotate only the first n_rot dimensions
@@ -1599,6 +1636,8 @@ public:
    ggml_tensor * k_shift; // I32 [kv_size*n_stream]

    ggml_tensor * rotk = nullptr; // F32 [n_embd_head_k, n_embd_head_k] - set only when the K cache is quantized

    const llama_kv_cache * kv_self;
};
@@ -1608,6 +1647,14 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
    if (k_shift) {
        kv_self->set_input_k_shift(k_shift);
    }

    if (rotk) {
        GGML_ASSERT(ggml_backend_buffer_is_host(rotk->buffer));

        float * data = (float *) rotk->data;

        set_input_hadamard(rotk->ne[0], data);
    }
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
@@ -1619,6 +1666,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
    ggml_set_input(inp->k_shift);

    if (ggml_is_quantized(type_k())) {
        int nrot = hparams.n_embd_head_k();

        inp->rotk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
        ggml_set_input(inp->rotk);
    }
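
    // note (inferred from the two paths in build_rope_shift): the Hadamard rotation is
    // needed only for quantized K caches - the f32 path applies RoPE directly, without
    // the dequantize/rotate round-trip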
    const auto & cparams = lctx->get_cparams();

    for (const auto & layer : layers) {
@@ -1643,7 +1696,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
                ggml_row_size(layer.k->type, n_embd_k_gqa),
                ggml_row_size(layer.k->type, n_embd_nope));

-        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->rotk, rope_factors, freq_base_l, freq_scale_l, il);

        ggml_build_forward_expand(gf, cur);
    }

View File

@@ -265,6 +265,7 @@ private:
            ggml_context * ctx,
             ggml_tensor * cur,
             ggml_tensor * shift,
             ggml_tensor * rotk,
             ggml_tensor * factors,
                   float   freq_base,
                   float   freq_scale,

View File

@@ -628,7 +628,7 @@ int main(int argc, char ** argv) {
            const int n_left    = n_past - params.n_keep;
            const int n_discard = n_left/2;

-            LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+            LOG_WRN("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                    n_past, n_left, n_ctx, params.n_keep, n_discard);

            llama_memory_seq_rm (mem, 0, params.n_keep            , params.n_keep + n_discard);