mirror of
https://github.com/PABannier/bark.cpp
synced 2026-03-10 00:51:01 +01:00
72 lines
2.4 KiB
C++
72 lines
2.4 KiB
C++
#include "bark.h"
|
|
|
|
#include <cstdio>
|
|
#include <string>
|
|
#include <map>
|
|
#include <vector>
|
|
|
|
// Reference table for the tokenizer test: each key is an input prompt and the
// mapped vector holds the token ids that prompt is expected to produce.
static const std::map<std::string, std::vector<bark_vocab::id>> & k_tests()
{
    using expected_ids = std::vector<bark_vocab::id>;

    // Function-local static: constructed on the first call, then the same
    // object is returned by reference on every later call.
    static const std::map<std::string, expected_ids> cases = {
        { "Hello world!",  { 31178, 11356, 106 } },
        { "Hello world",   { 31178, 11356 } },
        { " Hello world!", { 31178, 11356, 106 } },
        // { "this is an audio generated by bark", { 10531, 10124, 10151, 23685, 48918, 10155, 18121, 10174, }, },
    };

    return cases;
}
|
|
|
|
int main(int argc, char **argv) {
|
|
if (argc < 2) {
|
|
fprintf(stderr, "Usage: %s <model-file>\n", argv[0]);
|
|
return 1;
|
|
}
|
|
|
|
const std::string fname = argv[1];
|
|
|
|
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
|
|
|
|
bark_vocab vocab;
|
|
int max_ctx_size = 256;
|
|
|
|
if(!bark_vocab_load(fname, vocab, 119547)) {
|
|
fprintf(stderr, "%s: invalid vocab file '%s'\n", __func__, fname.c_str());
|
|
return 1;
|
|
}
|
|
|
|
for (const auto & test_kv : k_tests()) {
|
|
std::vector<bark_vocab::id> res(test_kv.first.size());
|
|
int n_tokens;
|
|
bert_tokenize(vocab, test_kv.first.c_str(), res.data(), &n_tokens, max_ctx_size);
|
|
res.resize(n_tokens);
|
|
|
|
bool correct = res.size() == test_kv.second.size();
|
|
|
|
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
|
if (res[i] != test_kv.second[i]) {
|
|
correct = false;
|
|
}
|
|
}
|
|
|
|
if (!correct) {
|
|
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
|
fprintf(stderr, "%s : expected tokens: ", __func__);
|
|
for (const auto & t : test_kv.second) {
|
|
fprintf(stderr, "%6d, ", t);
|
|
}
|
|
fprintf(stderr, "\n");
|
|
fprintf(stderr, "%s : got tokens: ", __func__);
|
|
for (const auto & t : res) {
|
|
fprintf(stderr, "%6d, ", t);
|
|
}
|
|
fprintf(stderr, "\n");
|
|
|
|
return 3;
|
|
}
|
|
}
|
|
|
|
fprintf(stderr, "%s : tests passed successfully.\n", __func__);
|
|
|
|
return 0;
|
|
}
|