llama.cpp/examples/save-load-state/save-load-state.cpp
Daniel Bevenius 2b6dfe824d
llama : remove write/read of output ids/logits/embeddings (#18862)
* llama : remove write/read of output ids/logits/embeddings

This commit removes the write/read of output ids, logits and
embeddings from the llama context state.

Refs: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941

* completion : add replaying of session state

This commit updates the session handling in the completion tool to handle
the fact that logits are no longer stored in the session file. Instead, we
need to replay the last token to get the logits for sampling.

* common : add common_prompt_batch_decode function

This commit adds a new function which is responsible for decoding a prompt
and optionally saving the session data.

* update save-load-state.cpp to use llama_state_load_file

This commit updates the save-load-state example to use the new
llama_state_load_file function for loading the model state from a file.
It also replays the last token after loading, since the state is now
saved before the last token is processed.

* examples : set n_seq_max = 2 for ctx3

This commit updates the save-load-state example to set the n_seq_max
parameter to 2 when initializing the ctx3 context.

The motivation for this change is that with n_parallel/n_seq_max set to 1
the context only supports one sequence, but the test later tries to
use a second sequence, which results in the following error:
```console
main : loaded state with 4 tokens
main : seq 0 copied, 225760 bytes
main : kv cache cleared
find_slot: seq_id=1 >= n_seq_max=1 Try using a bigger --parallel value
state_read_meta: failed to find available cells in kv cache
```
This seems to only happen for recurrent/hybrid models.
2026-02-23 07:04:30 +01:00

236 lines
7.0 KiB
C++

#include "arg.h"
#include "common.h"
#include "llama.h"

#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <vector>
int main(int argc, char ** argv) {
common_params params;
params.prompt = "The quick brown fox";
params.sampling.seed = 1234;
const std::string_view state_file = "dump_state.bin";
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
if (params.n_parallel == 1) {
// the example uses 2 sequences, so when n_parallel == 1, we need to enable unified kv cache
printf("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
params.kv_unified = true;
}
common_init();
if (params.n_predict < 0) {
params.n_predict = 16;
}
auto n_past = 0;
std::string result0;
std::string result1;
std::string result2;
// init
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
}
auto sparams = llama_sampler_chain_default_params();
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
// tokenize prompt
auto tokens = common_tokenize(ctx, params.prompt, true);
const bool save_state = true;
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
return 1;
}
// first run
printf("\nfirst run: %s", params.prompt.c_str());
llama_batch batch = llama_batch_init(1, 0, 1);
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl, ctx, -1);
auto next_token_str = common_token_to_piece(ctx, next_token);
printf("%s", next_token_str.c_str());
result0 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
if (llama_decode(ctx, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
return 1;
}
n_past += 1;
}
printf("\n\n");
// make new context
llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed));
printf("\nsecond run: %s", params.prompt.c_str());
// load state from file
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
size_t n_token_count_out = 0;
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
// restore state (last tokens)
n_past = n_token_count_out;
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
return 1;
}
++n_past;
// second run
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
auto next_token_str = common_token_to_piece(ctx2, next_token);
printf("%s", next_token_str.c_str());
result1 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
if (llama_decode(ctx2, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
return 1;
}
n_past += 1;
}
printf("\n\n");
if (result0 != result1) {
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
return 1;
}
// make new context
auto params_ctx3 = common_context_params_to_llama(params);
params_ctx3.n_seq_max = 2;
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed));
printf("\nsingle seq run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file
n_token_count_out = 0;
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
// restore state (last tokens)
n_past = n_token_count_out;
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
return 1;
}
++n_past;
// save seq 0 and load into seq 1
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
return 1;
}
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
// erase whole kv
llama_memory_clear(llama_get_memory(ctx3), true);
fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
return 1;
}
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
}
// third run with seq 1 instead of 0
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
auto next_token_str = common_token_to_piece(ctx3, next_token);
printf("%s", next_token_str.c_str());
result2 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {1}, true);
if (llama_decode(ctx3, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
return 1;
}
n_past += 1;
}
printf("\n");
llama_sampler_free(smpl);
llama_sampler_free(smpl2);
llama_sampler_free(smpl3);
llama_batch_free(batch);
// this one is managed by common_init_result
//llama_free(ctx);
llama_free(ctx2);
llama_free(ctx3);
if (result0 != result2) {
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
return 1;
}
fprintf(stderr, "\n%s : success\n", __func__);
return 0;
}