mirror of
https://github.com/leejet/stable-diffusion.cpp
synced 2026-04-27 09:11:56 +02:00
190 lines
6.7 KiB
C++
190 lines
6.7 KiB
C++
#include "bpe_tokenizer.h"
|
|
|
|
#include <algorithm>
|
|
#include <sstream>
|
|
|
|
#include "tokenize_util.h"
|
|
#include "util.h"
|
|
|
|
std::vector<std::pair<int, std::u32string>> BPETokenizer::bytes_to_unicode() {
|
|
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
|
|
std::set<int> byte_set;
|
|
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
|
|
byte_set.insert(b);
|
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
|
}
|
|
for (int b = 161; b <= 172; ++b) {
|
|
byte_set.insert(b);
|
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
|
}
|
|
for (int b = 174; b <= 255; ++b) {
|
|
byte_set.insert(b);
|
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
|
|
}
|
|
int n = 0;
|
|
for (int b = 0; b < 256; ++b) {
|
|
if (byte_set.find(b) == byte_set.end()) {
|
|
byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
|
|
++n;
|
|
}
|
|
}
|
|
return byte_unicode_pairs;
|
|
}
|
|
|
|
std::vector<std::string> BPETokenizer::token_split(const std::string& text) const {
|
|
return ::token_split(text);
|
|
}
|
|
|
|
std::vector<std::u32string> BPETokenizer::split_utf32(const std::string& text, char32_t delimiter) {
|
|
std::vector<std::u32string> result;
|
|
size_t start = 0;
|
|
size_t pos = 0;
|
|
std::u32string utf32_text = utf8_to_utf32(text);
|
|
while ((pos = utf32_text.find(delimiter, start)) != std::u32string::npos) {
|
|
result.push_back(utf32_text.substr(start, pos - start));
|
|
start = pos + 1;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
|
|
std::set<std::pair<std::u32string, std::u32string>> pairs;
|
|
if (subwords.empty()) {
|
|
return pairs;
|
|
}
|
|
|
|
std::u32string prev_subword = subwords[0];
|
|
for (int i = 1; i < static_cast<int>(subwords.size()); i++) {
|
|
std::u32string subword = subwords[i];
|
|
std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
|
|
pairs.insert(pair);
|
|
prev_subword = subword;
|
|
}
|
|
return pairs;
|
|
}
|
|
|
|
std::vector<std::u32string> BPETokenizer::bpe(const std::u32string& token) const {
|
|
std::vector<std::u32string> word;
|
|
|
|
for (int i = 0; i < static_cast<int>(token.size()) - 1; i++) {
|
|
word.emplace_back(1, token[i]);
|
|
}
|
|
word.push_back(token.substr(token.size() - 1) + utf8_to_utf32(end_of_word_suffix));
|
|
|
|
std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
|
|
|
|
if (pairs.empty()) {
|
|
return {token + utf8_to_utf32(end_of_word_suffix)};
|
|
}
|
|
|
|
while (true) {
|
|
auto min_pair_iter = std::min_element(pairs.begin(),
|
|
pairs.end(),
|
|
[&](const std::pair<std::u32string, std::u32string>& a,
|
|
const std::pair<std::u32string, std::u32string>& b) {
|
|
if (bpe_ranks.find(a) == bpe_ranks.end()) {
|
|
return false;
|
|
} else if (bpe_ranks.find(b) == bpe_ranks.end()) {
|
|
return true;
|
|
}
|
|
return bpe_ranks.at(a) < bpe_ranks.at(b);
|
|
});
|
|
|
|
const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
|
|
|
|
if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
|
|
break;
|
|
}
|
|
|
|
std::u32string first = bigram.first;
|
|
std::u32string second = bigram.second;
|
|
std::vector<std::u32string> new_word;
|
|
int32_t i = 0;
|
|
|
|
while (i < static_cast<int32_t>(word.size())) {
|
|
auto it = std::find(word.begin() + i, word.end(), first);
|
|
if (it == word.end()) {
|
|
new_word.insert(new_word.end(), word.begin() + i, word.end());
|
|
break;
|
|
}
|
|
new_word.insert(new_word.end(), word.begin() + i, it);
|
|
i = static_cast<int32_t>(std::distance(word.begin(), it));
|
|
|
|
if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
|
|
new_word.push_back(first + second);
|
|
i += 2;
|
|
} else {
|
|
new_word.push_back(word[i]);
|
|
i += 1;
|
|
}
|
|
}
|
|
|
|
word = new_word;
|
|
|
|
if (word.size() == 1) {
|
|
break;
|
|
}
|
|
pairs = get_pairs(word);
|
|
}
|
|
|
|
return word;
|
|
}
|
|
|
|
std::vector<int> BPETokenizer::encode(const std::string& text, on_new_token_cb_t on_new_token_cb) {
|
|
std::string normalized_text = normalize(text);
|
|
std::vector<int32_t> bpe_tokens;
|
|
std::vector<std::string> token_strs;
|
|
|
|
auto splited_texts = split_with_special_tokens(normalized_text, special_tokens);
|
|
|
|
for (auto& splited_text : splited_texts) {
|
|
if (is_special_token(splited_text)) {
|
|
if (on_new_token_cb != nullptr) {
|
|
bool skip = on_new_token_cb(splited_text, bpe_tokens);
|
|
if (skip) {
|
|
token_strs.push_back(splited_text);
|
|
continue;
|
|
}
|
|
}
|
|
bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]);
|
|
token_strs.push_back(splited_text);
|
|
continue;
|
|
}
|
|
auto tokens = token_split(splited_text);
|
|
for (auto& token : tokens) {
|
|
if (on_new_token_cb != nullptr) {
|
|
bool skip = on_new_token_cb(token, bpe_tokens);
|
|
if (skip) {
|
|
token_strs.push_back(splited_text);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
std::string token_str = token;
|
|
std::u32string utf32_token;
|
|
for (int i = 0; i < static_cast<int>(token_str.length()); i++) {
|
|
unsigned char b = token_str[i];
|
|
utf32_token += byte_encoder[b];
|
|
}
|
|
auto bpe_strs = bpe(utf32_token);
|
|
for (auto bpe_str : bpe_strs) {
|
|
bpe_tokens.push_back(encoder[bpe_str]);
|
|
token_strs.push_back(utf32_to_utf8(bpe_str));
|
|
}
|
|
}
|
|
}
|
|
|
|
std::stringstream ss;
|
|
ss << "[";
|
|
for (auto token : token_strs) {
|
|
ss << "\"" << token << "\", ";
|
|
}
|
|
ss << "]";
|
|
LOG_DEBUG("split prompt \"%s\" to tokens %s", text.c_str(), ss.str().c_str());
|
|
return bpe_tokens;
|
|
}
|
|
|
|
std::string BPETokenizer::decode_token(int token_id) const {
|
|
return utf32_to_utf8(decoder.at(token_id));
|
|
}
|