common: moved self-spec impl to ngram-map

This commit is contained in:
Sascha Rogmann 2026-01-25 01:16:06 +01:00
parent a1584ac80f
commit cb3a40277a
3 changed files with 165 additions and 117 deletions

View File

@@ -5,6 +5,97 @@
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <sstream>
// n-gram simple
//
/**
* Perform speculative generation using the model's own token history.
* Searches for a matching pattern in the token history and returns draft tokens.
*
* @param state Current state of this implementation
* @param tokens Token history to search in
* @param sampled Last sampled token
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled) {
// Simple implementation of self-speculative decoding without a draft model and without the n-gram map.
//
const size_t cur_len = tokens.size();
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (state.idx_last_check + state.config.check_rate > cur_len) {
llama_tokens draft_tokens;
return draft_tokens;
}
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
// vector for tokens we want to verify.
// return empty vector if there is no match.
llama_tokens draft_tokens;
// We need more than n_draft_min + n_draft_max + 1 tokens in the history.
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
return draft_tokens;
}
// pattern search
llama_tokens pattern;
pattern.reserve(n_draft_min);
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
pattern.push_back(tokens[j]);
}
pattern.push_back(sampled); // add the last token to the pattern
// We do a search in the token history.
state.idx_last_check = tokens.size();
size_t match_pos = 0; // we ignore position 0, position 0 == no match
// search backwards, but skip the current match (we are currently there)
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < pattern.size(); ++k) {
if (tokens[j + k] != pattern[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos == 0) {
return draft_tokens;
}
const size_t copy_max = std::min(
n_draft_max,
cur_len - (match_pos + n_draft_min)
);
if (copy_max < n_draft_min) {
return draft_tokens;
}
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
__func__, cur_len,
match_pos, pattern.size(), copy_max);
draft_tokens.reserve(copy_max);
for (size_t j = 0; j < copy_max; ++j) {
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
}
return draft_tokens;
}
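For orientation, here is a minimal sketch of how this helper might be driven per sampled token. The wrapper function, the persistent static state, and the concrete sizes are illustrative assumptions, not part of this commit.

// illustrative sketch only (assumed names and sizes), not part of this commit
static llama_tokens example_self_draft(const llama_tokens & ctx_tokens, llama_token id_last) {
    // lookup n-gram size, draft m-gram size, and how often to run the check
    static common_ngram_simple_state st(common_ngram_simple_config{
        /* .size_ngram = */ 3, /* .size_mgram = */ 8, /* .check_rate = */ 4 });
    // an empty result means no matching pattern was found in the history
    llama_tokens draft = common_ngram_simple_draft(st, ctx_tokens, id_last);
    if (!draft.empty()) {
        // hand the draft tokens to the target model for verification
    }
    return draft;
}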
// n-gram map
//
// maximum number of counted values of an n-gram map value.
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
@@ -262,14 +353,15 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
std::string result = "[";
std::ostringstream oss;
oss << '[';
for (size_t i = 0; i < length; ++i) {
if (i > 0) {
result += ", ";
oss << ", ";
}
result += std::to_string(inp[start + i]);
oss << inp[start + i];
}
result += "]";
return result;
oss << ']';
return oss.str();
}
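As a usage note, this helper is handy for debug logging; the snippet below is illustrative (it assumes draft_tokens holds at least four tokens) and is not part of this commit.

// e.g. prints: draft tokens: [12, 345, 6789, 42]
LOG_DBG("draft tokens: %s\n", common_tokens_to_str(draft_tokens, 0, 4).c_str());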

View File

@@ -3,12 +3,51 @@
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
//
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
//
// There are two algorithms implemented:
// 1. ngram_simple: a direct backwards search in the token history for the n-gram, drafting the m-gram that follows it.
// 2. ngram_map: the same n-gram/m-gram lookup, but performed via a map.
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
//
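To make that layout concrete, a conceptual sketch follows; the struct and field names are illustrative assumptions and not the actual definitions in this header:

// conceptual sketch only: one key n-gram is associated with a few candidate m-grams
//
//     struct ngram_map_entry_sketch {
//         std::vector<llama_token> key;                  // a key n-gram seen in the token history
//         std::vector<std::vector<llama_token>> values;  // m-grams observed to follow that key
//     };
//     std::vector<ngram_map_entry_sketch> map;           // the map: a flat vector of such entries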
#include "llama.h"
#include <string>
#include <vector>
// n-gram simple
//
// config of n-gram simple.
struct common_ngram_simple_config {
uint16_t size_ngram; // size of n-grams to lookup in self-mode
uint16_t size_mgram; // size of m-grams to draft in self-mode
uint16_t check_rate; // perform the self-draft check only once every check_rate generated tokens
};
// current state (and config) of n-gram simple.
struct common_ngram_simple_state {
common_ngram_simple_config config;
size_t idx_last_check = 0; // index of last check in context history (mutable)
common_ngram_simple_state(const common_ngram_simple_config & config)
: config(config) {}
};
// Searches for an n-gram in the token history and checks whether a draft sequence should be generated.
// state: the n-gram simple state (config and index of the last check).
// tokens: the tokens generated so far.
// sampled: the token that was just sampled.
// Returns the draft tokens; the vector is empty if no matching pattern was found.
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled);
// n-gram map
//
// maximum number of m-gram values stored for each key n-gram.
#define COMMON_NGRAM_MAX_VALUES 4
@@ -52,6 +91,7 @@ struct common_ngram_map {
size_t idx_last_check = 0; // index of last check in context history
};
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
// map: the ngram map to search in.
// inp: the tokens generated so far.

View File

@@ -59,21 +59,13 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state {
uint16_t size_ngram; // size of n-grams to lookup in self-mode
uint16_t size_mgram; // size of m-grams to draft in self-mode
const uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
size_t idx_last_check = 0; // index of last check in context history (mutable)
common_ngram_simple_state state;
common_speculative_state_ngram_simple(
enum common_speculative_type type,
uint16_t size_ngram,
uint16_t size_mgram,
uint16_t check_rate)
: common_speculative_state(type)
, size_ngram(size_ngram)
, size_mgram(size_mgram)
, check_rate(check_rate) {}
common_ngram_simple_state state)
: common_speculative_state(type), state(std::move(state)) {}
};
struct common_speculative_state_ngram_map_k : public common_speculative_state {
@@ -275,26 +267,30 @@ struct common_speculative * common_speculative_init(
uint16_t ngram_size_key = ngram_map.size_key;
uint16_t mgram_size_value = ngram_map.size_value;
uint16_t check_rate = ngram_map.check_rate;
auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type,
auto config_simple = common_ngram_simple_config{
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value,
/* .check_rate = */ check_rate
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type,
/* .state = */ common_ngram_simple_state(config_simple)
);
implementations.push_back(std::move(state));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: {
implementations.push_back(std::make_unique<common_speculative_state_ngram_map_k>(
(config.type), get_common_ngram_map(config,
params.spec_ngram_size_n, params.spec_ngram_size_m)
(config.type),
get_common_ngram_map(config, params.spec_ngram_size_n, params.spec_ngram_size_m)
));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
implementations.push_back(std::make_unique<common_speculative_state_ngram_map_k4v>(
(config.type), get_common_ngram_map(config,
params.spec_ngram_size_n, params.spec_ngram_size_m)));
(config.type),
get_common_ngram_map(config, params.spec_ngram_size_n, params.spec_ngram_size_m)
));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
@@ -475,10 +471,6 @@ llama_tokens common_speculative_use_draft_model(
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
llama_token id_last);
llama_tokens common_speculative_gen_self_draft(
common_speculative_state_ngram_simple & state,
const llama_tokens & tokens, llama_token sampled);
llama_tokens common_speculative_gen_ngram_cache(
common_speculative_state_ngram_cache & state,
const llama_tokens & tokens, llama_token sampled);
@@ -515,25 +507,31 @@ llama_tokens common_speculative_gen_draft(
// Use common_ngram_simple_draft to generate a draft from the current context.
auto * state = dynamic_cast<struct common_speculative_state_ngram_simple *>(impl.get());
if (state) {
result = common_speculative_gen_self_draft(*state, prompt_tgt_main_model, id_last);
result = common_ngram_simple_draft(state->state, prompt_tgt_main_model, id_last);
} else {
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
}
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
{
// Use common_ngram_map_draft to generate a draft from the current context.
auto state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
if (state) {
common_ngram_map_draft(state->map, prompt_tgt_main_model, id_last, result);
} else {
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
}
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
{
// Use common_ngram_map_draft to generate a draft from the current context.
auto state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
auto * state = dynamic_cast<common_speculative_state_ngram_map_k *>(impl.get());
if (state) {
common_ngram_map_draft(state->map, prompt_tgt_main_model, id_last, result);
} else {
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
}
break;
}
@@ -542,6 +540,8 @@ llama_tokens common_speculative_gen_draft(
auto * state = dynamic_cast<common_speculative_state_ngram_cache *>(impl.get());
if (state) {
result = common_speculative_gen_ngram_cache(*state, prompt_tgt_main_model, id_last);
} else {
GGML_ABORT("unexpected implementation in type %d", impl.get()->type);
}
break;
}
@@ -755,8 +755,10 @@ void common_speculative_accept(struct common_speculative * spec, const uint16_t
}
if (impl->type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K ||
impl->type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V) {
auto state = static_cast<struct common_speculative_state_ngram_map_k *>(impl);
common_ngram_map_accept(state->map, n_accepted);
auto * state = dynamic_cast<struct common_speculative_state_ngram_map_k *>(impl);
if (state) {
common_ngram_map_accept(state->map, n_accepted);
}
}
}
}
@@ -777,92 +779,6 @@ void common_speculative_print_stats(const struct common_speculative * spec) {
}
// self-speculative decoding
//
/**
* Perform speculative generation using the model's own token history.
* Searches for a matching pattern in the token history and returns draft tokens.
*
* @param state Current state of this implementation
* @param tokens Token history to search in
* @param sampled Last sampled token
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_speculative_gen_self_draft(
common_speculative_state_ngram_simple & state,
const llama_tokens & tokens, llama_token sampled) {
// Simple implementation of self-speculative decoding without draft model, without ngram-map.
//
const size_t cur_len = tokens.size();
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (state.idx_last_check + state.check_rate > cur_len) {
llama_tokens draft_tokens;
return draft_tokens;
}
size_t n_draft_min = state.size_ngram; // size of n-gram to lookup in token history
size_t n_draft_max = state.size_mgram; // the m-gram following the found n-gram is used for draft
// vector for tokens we want to verify.
// return empty vector if there is no match.
llama_tokens draft_tokens;
// We need at least n_draft_min + n_draft_max + 1 tokens.
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
return draft_tokens;
}
// pattern search
llama_tokens pattern;
pattern.reserve(n_draft_min);
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
pattern.push_back(tokens[j]);
}
pattern.push_back(sampled); // add the last token to the pattern
// We do a search in the token history.
state.idx_last_check = tokens.size();
size_t match_pos = 0; // we ignore position 0, position 0 == no match
// search backwards, but skip the current match (we are currently there)
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < pattern.size(); ++k) {
if (tokens[j + k] != pattern[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos == 0) {
return draft_tokens;
}
const size_t copy_max = std::min(
n_draft_max,
cur_len - (match_pos + n_draft_min)
);
if (copy_max < n_draft_min) {
return draft_tokens;
}
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
__func__, cur_len,
match_pos, pattern.size(), copy_max);
draft_tokens.reserve(copy_max);
for (size_t j = 0; j < copy_max; ++j) {
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
}
return draft_tokens;
}
// n-gram cache
//