mirror of
https://github.com/ggerganov/llama.cpp
synced 2026-04-29 10:41:41 +02:00
cont : cache shift support
This commit is contained in:
parent
7711b3a36a
commit
ff76c6731d
@ -1545,11 +1545,44 @@ size_t llama_kv_cache::size_v_bytes() const {
|
||||
return size_v_bytes;
|
||||
}
|
||||
|
||||
static ggml_tensor * ggml_rotate_hadamard(
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * rot) {
|
||||
const auto n = rot->ne[0];
|
||||
|
||||
ggml_tensor * res;
|
||||
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
|
||||
res = ggml_mul_mat(ctx, rot, res);
|
||||
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static void set_input_hadamard(int n, float * data) {
|
||||
assert(ggml_is_power_of_2(n));
|
||||
|
||||
data[0*n + 0] = 1.0 / sqrtf(n);
|
||||
|
||||
for (int s = 1; s < n; s *= 2) {
|
||||
for (int i = 0; i < s; i++) {
|
||||
for (int j = 0; j < s; j++) {
|
||||
const float val = data[i*n + j];
|
||||
|
||||
data[(i + s)*n + (j )] = val;
|
||||
data[(i )*n + (j + s)] = val;
|
||||
data[(i + s)*n + (j + s)] = -val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * rotk,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
@ -1575,10 +1608,14 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
// dequantize to f32 -> RoPE -> quantize back
|
||||
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
|
||||
tmp = ggml_rotate_hadamard(ctx, tmp, rotk);
|
||||
|
||||
tmp = ggml_rope_ext(ctx, tmp,
|
||||
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
||||
|
||||
tmp = ggml_rotate_hadamard(ctx, tmp, rotk);
|
||||
|
||||
tmp = ggml_cpy(ctx, tmp, cur);
|
||||
} else {
|
||||
// we rotate only the first n_rot dimensions
|
||||
@ -1599,6 +1636,8 @@ public:
|
||||
|
||||
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
|
||||
|
||||
ggml_tensor * rotk = nullptr;
|
||||
|
||||
const llama_kv_cache * kv_self;
|
||||
};
|
||||
|
||||
@ -1608,6 +1647,14 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
||||
if (k_shift) {
|
||||
kv_self->set_input_k_shift(k_shift);
|
||||
}
|
||||
|
||||
if (rotk) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(rotk->buffer));
|
||||
|
||||
float * data = (float *) rotk->data;
|
||||
|
||||
set_input_hadamard(rotk->ne[0], data);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
|
||||
@ -1619,6 +1666,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
ggml_set_input(inp->k_shift);
|
||||
|
||||
if (ggml_is_quantized(type_k())) {
|
||||
int nrot = hparams.n_embd_head_k();
|
||||
inp->rotk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
|
||||
ggml_set_input(inp->rotk);
|
||||
}
|
||||
|
||||
const auto & cparams = lctx->get_cparams();
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
@ -1643,7 +1696,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->rotk, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
@ -265,6 +265,7 @@ private:
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * rotk,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
|
||||
@ -628,7 +628,7 @@ int main(int argc, char ** argv) {
|
||||
const int n_left = n_past - params.n_keep;
|
||||
const int n_discard = n_left/2;
|
||||
|
||||
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
LOG_WRN("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_memory_seq_rm (mem, 0, params.n_keep , params.n_keep + n_discard);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user