Mirror of https://github.com/ggerganov/llama.cpp, synced 2026-03-02 05:09:23 +01:00
* llama : remove write/read of output ids/logits/embeddings

  This commit removes the write/read of output ids, logits and embeddings from the llama context state.

  Refs: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941

* completion : add replaying of session state

  This commit updates the session handling in the completion tool to account for logits no longer being stored in the session file. Instead, we need to replay the last token to get the logits for sampling.

* common : add common_prompt_batch_decode function

  This commit adds a new function which is responsible for decoding the prompt and optionally saving the session data.

* update save-load-state.cpp to use llama_state_load_file

  This commit updates the save-load-state example to use llama_state_load_file for loading the model state from a file. It also replays the last token after loading, since the state is now saved before the last token is processed.

* examples : set n_seq_max = 2 for ctx3

  This commit updates the save-load-state example to set the n_seq_max parameter to 2 when initializing the ctx3 context.

  The motivation for this change is that with n_parallel/n_seq_max set to 1 the context only supports one sequence, but the test later tries to use a second sequence, which results in the following error:

  ```console
  main : loaded state with 4 tokens
  main : seq 0 copied, 225760 bytes
  main : kv cache cleared
  find_slot: seq_id=1 >= n_seq_max=1
  Try using a bigger --parallel value
  state_read_meta: failed to find available cells in kv cache
  ```

  This seems to only happen for recurrent/hybrid models.
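To make the new flow concrete, here is a minimal sketch of the load-then-replay pattern described above. It assumes the `common_replay_last_token` helper added in common (its signature inferred from its use in the example below); the wrapper function `restore_session` is hypothetical and not part of the change.

```cpp
#include "common.h"
#include "llama.h"

#include <vector>

// Sketch only: since logits are no longer stored in the state/session file, the
// last prompt token has to be re-decoded after loading so the sampler has logits.
static bool restore_session(llama_context * ctx, const char * path,
                            const std::vector<llama_token> & prompt_tokens, int & n_past) {
    std::vector<llama_token> session_tokens(prompt_tokens.size());
    size_t n_token_count = 0;

    // load the saved state (kv cache, etc.) into the context
    if (!llama_state_load_file(ctx, path, session_tokens.data(), session_tokens.size(), &n_token_count)) {
        return false;
    }

    // the state was saved before the last prompt token was decoded, so replay it
    n_past = (int) n_token_count;
    if (!common_replay_last_token(ctx, prompt_tokens.back(), n_past)) {
        return false;
    }
    ++n_past;

    return true;
}
```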
236 lines · 7.0 KiB · C++
#include "arg.h"
|
|
#include "common.h"
|
|
#include "llama.h"
|
|
|
|
#include <vector>
|
|
#include <cstdio>
|
|
|
|
|
|
int main(int argc, char ** argv) {
|
|
common_params params;
|
|
|
|
params.prompt = "The quick brown fox";
|
|
params.sampling.seed = 1234;
|
|
|
|
const std::string_view state_file = "dump_state.bin";
|
|
|
|
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
|
return 1;
|
|
}
|
|
|
|
if (params.n_parallel == 1) {
|
|
// the example uses 2 sequences, so when n_parallel == 1, we need to enable unified kv cache
|
|
printf("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
|
|
params.kv_unified = true;
|
|
}
|
|
|
|
common_init();
|
|
|
|
if (params.n_predict < 0) {
|
|
params.n_predict = 16;
|
|
}
|
|
|
|
auto n_past = 0;
|
|
|
|
std::string result0;
|
|
std::string result1;
|
|
std::string result2;
|
|
|
|
// init
|
|
auto llama_init = common_init_from_params(params);
|
|
|
|
auto * model = llama_init->model();
|
|
auto * ctx = llama_init->context();
|
|
|
|
if (model == nullptr || ctx == nullptr) {
|
|
fprintf(stderr, "%s : failed to init\n", __func__);
|
|
return 1;
|
|
}
|
|
|
|
auto sparams = llama_sampler_chain_default_params();
|
|
|
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
|
|
|
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
|
|
|
|
// tokenize prompt
|
|
auto tokens = common_tokenize(ctx, params.prompt, true);
|
|
|
|
const bool save_state = true;
|
|
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
|
|
return 1;
|
|
}
|
|
|
|
// first run
|
|
printf("\nfirst run: %s", params.prompt.c_str());
|
|
|
|
llama_batch batch = llama_batch_init(1, 0, 1);
|
|
|
|
for (auto i = 0; i < params.n_predict; i++) {
|
|
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
|
auto next_token_str = common_token_to_piece(ctx, next_token);
|
|
|
|
printf("%s", next_token_str.c_str());
|
|
result0 += next_token_str;
|
|
|
|
common_batch_clear(batch);
|
|
common_batch_add(batch, next_token, n_past, {0}, true);
|
|
|
|
if (llama_decode(ctx, batch)) {
|
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
|
llama_batch_free(batch);
|
|
return 1;
|
|
}
|
|
n_past += 1;
|
|
}
|
|
|
|
printf("\n\n");
|
|
|
|
// make new context
|
|
llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params));
|
|
|
|
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
|
|
|
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed));
|
|
|
|
printf("\nsecond run: %s", params.prompt.c_str());
|
|
|
|
// load state from file
|
|
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
|
|
size_t n_token_count_out = 0;
|
|
|
|
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
|
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
|
return 1;
|
|
}
|
|
|
|
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
|
|
|
// restore state (last tokens)
|
|
n_past = n_token_count_out;
|
|
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
|
|
return 1;
|
|
}
|
|
++n_past;
|
|
|
|
// second run
|
|
for (auto i = 0; i < params.n_predict; i++) {
|
|
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
|
|
auto next_token_str = common_token_to_piece(ctx2, next_token);
|
|
|
|
printf("%s", next_token_str.c_str());
|
|
result1 += next_token_str;
|
|
|
|
common_batch_clear(batch);
|
|
common_batch_add(batch, next_token, n_past, {0}, true);
|
|
|
|
if (llama_decode(ctx2, batch)) {
|
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
|
llama_batch_free(batch);
|
|
return 1;
|
|
}
|
|
n_past += 1;
|
|
}
|
|
|
|
printf("\n\n");
|
|
|
|
if (result0 != result1) {
|
|
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
|
|
return 1;
|
|
}
|
|
|
|
// make new context
|
|
auto params_ctx3 = common_context_params_to_llama(params);
|
|
params_ctx3.n_seq_max = 2;
|
|
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
|
|
|
|
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
|
|
|
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed));
|
|
|
|
printf("\nsingle seq run: %s", params.prompt.c_str());
|
|
|
|
// load state (rng, logits, embedding and kv_cache) from file
|
|
n_token_count_out = 0;
|
|
|
|
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
|
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
|
return 1;
|
|
}
|
|
|
|
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
|
|
|
// restore state (last tokens)
|
|
n_past = n_token_count_out;
|
|
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
|
|
return 1;
|
|
}
|
|
++n_past;
|
|
|
|
// save seq 0 and load into seq 1
|
|
{
|
|
// save kv of seq 0
|
|
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
|
|
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
|
|
if (ncopy != seq_store.size()) {
|
|
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
|
|
return 1;
|
|
}
|
|
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
|
|
|
// erase whole kv
|
|
llama_memory_clear(llama_get_memory(ctx3), true);
|
|
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
|
|
|
// restore kv into seq 1
|
|
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
|
|
if (nset != seq_store.size()) {
|
|
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
|
|
return 1;
|
|
}
|
|
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
|
|
}
|
|
|
|
// third run with seq 1 instead of 0
|
|
for (auto i = 0; i < params.n_predict; i++) {
|
|
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
|
|
auto next_token_str = common_token_to_piece(ctx3, next_token);
|
|
|
|
printf("%s", next_token_str.c_str());
|
|
result2 += next_token_str;
|
|
|
|
common_batch_clear(batch);
|
|
common_batch_add(batch, next_token, n_past, {1}, true);
|
|
|
|
if (llama_decode(ctx3, batch)) {
|
|
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
|
llama_batch_free(batch);
|
|
return 1;
|
|
}
|
|
n_past += 1;
|
|
}
|
|
|
|
printf("\n");
|
|
|
|
llama_sampler_free(smpl);
|
|
llama_sampler_free(smpl2);
|
|
llama_sampler_free(smpl3);
|
|
|
|
llama_batch_free(batch);
|
|
|
|
// this one is managed by common_init_result
|
|
//llama_free(ctx);
|
|
|
|
llama_free(ctx2);
|
|
llama_free(ctx3);
|
|
|
|
if (result0 != result2) {
|
|
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
|
|
return 1;
|
|
}
|
|
|
|
fprintf(stderr, "\n%s : success\n", __func__);
|
|
|
|
return 0;
|
|
}
|
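For reference, a typical invocation of this example (assuming the standard `llama-save-load-state` build target; the model path is a placeholder) looks like:

```console
./llama-save-load-state -m path/to/model.gguf
```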