diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index a03dbce887..73394b74ee 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -8,8 +8,24 @@
 #include <clocale>
 #include <cstdio>
 #include <cstring>
+#include <cstdint>
 #include <string>
 #include <vector>
+#include <cinttypes>
+
+struct spec_checkpoint {
+    int64_t n_tokens = 0;
+
+    std::vector<uint8_t> data;
+
+    size_t size() const {
+        return data.size();
+    }
+
+    bool empty() const {
+        return data.empty();
+    }
+};
 
 int main(int argc, char ** argv) {
     std::setlocale(LC_NUMERIC, "C");
@@ -46,6 +62,14 @@
     model_tgt = llama_init_tgt->model();
     ctx_tgt   = llama_init_tgt->context();
 
+    // check if the context supports partial sequence removal
+    const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt);
+    const bool use_ckpt   = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL);
+
+    if (use_ckpt) {
+        LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n");
+    }
+
     const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
 
     // load the draft model
@@ -119,7 +143,7 @@
     const auto t_enc_start = ggml_time_us();
 
     // target model sampling context
-    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+    common_sampler_ptr smpl(common_sampler_init(model_tgt, params.sampling));
 
     // eval the prompt
     llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
@@ -142,21 +166,61 @@
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 
+    size_t n_draft = 0;
+
+    llama_tokens draft;
+    spec_checkpoint spec_ckpt;
+
     const auto t_enc_end = ggml_time_us();
 
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
-        // optionally, generate draft tokens that can be appended to the target batch
+        // generate or reuse draft tokens
         //
         // this is the most important part of the speculation. the more probable tokens that are provided here
         // the better the performance will be. in theory, this computation can be performed asynchronously and even
         // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
         // from a cache or lookup tables.
         //
-        llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
+        if (draft.empty()) {
+            // generate a new draft
+            draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
 
-        //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
+            if ((int) draft.size() > params_spec.n_max) {
+                LOG_WRN("draft size %zu exceeds max %d, truncating\n", draft.size(), params_spec.n_max);
+                draft.resize(params_spec.n_max);
+            }
+
+            if ((int) draft.size() < params_spec.n_min) {
+                LOG_DBG("ignoring small draft: %zu < %d\n", draft.size(), params_spec.n_min);
+                draft.clear();
+            }
+
+            // save the original draft size
+            n_draft = draft.size();
+
+            // save a checkpoint of the target context before evaluating the draft
+            // this allows us to restore the state if partial draft acceptance occurs
+            if (!draft.empty() && use_ckpt) {
+                const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                spec_ckpt.data.resize(ckpt_size);
+
+                const size_t n = llama_state_seq_get_data_ext(ctx_tgt, spec_ckpt.data.data(), ckpt_size, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                GGML_ASSERT(n == ckpt_size);
+
+                spec_ckpt.n_tokens = (int64_t) prompt_tgt.size();
+                LOG_DBG("created speculative checkpoint (n_tokens = %" PRId64 ", size = %.3f MiB)\n",
+                        spec_ckpt.n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
+            }
+        } else {
+            // we have a previous (partial) draft to reuse from checkpoint restoration
+            if (use_ckpt) {
+                GGML_ASSERT(!spec_ckpt.empty());
+            }
+        }
+
+        GGML_ASSERT(n_draft > 0);
 
         // always have a token to evaluate from before - id_last
         common_batch_clear(batch_tgt);
@@ -178,6 +242,12 @@
             llama_decode(ctx_tgt, batch_tgt);
         }
 
+        // only save the sampler state if we use checkpoints
+        common_sampler_ptr smpl_save;
+        if (use_ckpt) {
+            smpl_save.reset(common_sampler_clone(smpl.get()));
+        }
+
         // sample from the full target batch and return the accepted tokens based on the target sampler
         //
         // for each token to be accepted, the sampler would have to sample that same token
@@ -185,14 +255,38 @@
         // available logits from the batch and sample the next token until we run out of logits or the sampler
         // disagrees with the draft
         //
-        const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+        auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft);
 
         //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
 
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
+        // check for partial draft acceptance:
+        // if the context doesn't support partial sequence removal, restore the checkpoint
+        // and make the accepted tokens the new partial draft for the next iteration
+        if (use_ckpt && ids.size() - 1 < draft.size()) {
+            LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size());
+
+            draft = std::move(ids);
+
+            const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+            GGML_ASSERT(n == spec_ckpt.size());
+
+            llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1);
+
+            prompt_tgt.resize(spec_ckpt.n_tokens);
+            smpl = std::move(smpl_save);
+
+            n_past = (int) prompt_tgt.size();
+
+            continue;
+        }
+
+        common_speculative_accept(spec, ids.size() - 1);
+
+        // full acceptance: consume the draft and commit accepted tokens
         n_past    += ids.size() - 1;
-        n_drafted += draft.size(); // note: we ignore the discarded small drafts
+        n_drafted += n_draft;      // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
         n_predict += ids.size();
 
@@ -222,6 +316,9 @@
 
         LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
 
+        // clear the draft since it has been consumed
+        draft.clear();
+
         {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
 
@@ -254,11 +351,10 @@
 
     LOG_INF("\n");
     LOG_INF("target:\n\n");
-    common_perf_print(ctx_tgt, smpl);
+    common_perf_print(ctx_tgt, smpl.get());
 
     llama_batch_free(batch_tgt);
 
-    common_sampler_free(smpl);
     common_speculative_free(spec);
 
     llama_backend_free();
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 53f61b5a9b..b8c05cd80e 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2961,7 +2961,13 @@ private:
 
             // verify and try to accept the draft
             {
-                common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get()));
+                const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
+
+                // only save the sampler state if we use checkpoints
+                common_sampler_ptr smpl_save;
+                if (use_ckpt) {
+                    smpl_save.reset(common_sampler_clone(slot.smpl.get()));
+                }
 
                 GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1);
                 auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft);
@@ -2973,7 +2979,7 @@
 
                 // check for partial draft acceptance
                 if (accepted.size() < slot.spec_draft.size() + 1) {
-                    if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
+                    if (use_ckpt) {
                         // partial acceptance is not supported by the context -> truncate the draft and restore the state
                         slot.spec_draft = std::move(accepted);
 
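Note for reviewers: the checkpoint logic added in both files reduces to the same save/restore pattern. Below is a minimal standalone sketch of that pattern built from the calls this patch uses; the seq_checkpoint, checkpoint_save, and checkpoint_restore names are illustrative only (not part of the patch), it assumes sequence id 0, and error handling is omitted:

    // sketch: snapshot and roll back the state of sequence 0 of a llama_context
    #include <cstdint>
    #include <vector>

    #include "llama.h"

    struct seq_checkpoint {
        int64_t n_tokens = 0;      // logical sequence length at snapshot time
        std::vector<uint8_t> data; // serialized sequence state
    };

    // take a snapshot before evaluating a draft on the target context
    static seq_checkpoint checkpoint_save(llama_context * ctx, int64_t n_tokens) {
        seq_checkpoint ckpt;
        ckpt.n_tokens = n_tokens;

        ckpt.data.resize(llama_state_seq_get_size_ext(ctx, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY));
        llama_state_seq_get_data_ext(ctx, ckpt.data.data(), ckpt.data.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

        return ckpt;
    }

    // restore the snapshot, then drop everything the draft evaluation added past it
    static void checkpoint_restore(llama_context * ctx, const seq_checkpoint & ckpt) {
        llama_state_seq_set_data_ext(ctx, ckpt.data.data(), ckpt.data.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
        llama_memory_seq_rm(llama_get_memory(ctx), 0, ckpt.n_tokens, -1);
    }

On partial acceptance the decode loop restores the checkpoint, turns the accepted prefix into the next draft, and re-decodes only those tokens, so a context that supports only full sequence removal never has to perform a partial llama_memory_seq_rm.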