FIX gpt_eval definition (#100)

This commit is contained in:
PAB 2023-09-01 12:59:25 +02:00 committed by GitHub
parent 21c8350c41
commit 441490b4ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 32 deletions

View File

@ -874,16 +874,18 @@ bool fine_gpt_eval(
bool gpt_eval(
const gpt_model & model,
const int n_threads,
int * n_past,
const bool merge_ctx,
const bark_sequence & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
int N = embd_inp.size();
const gpt_model & model,
bark_vocab::id * tokens,
int n_tokens,
float * logits,
int * n_past,
bool merge_ctx,
int n_threads,
size_t & mem_per_token) {
BARK_ASSERT(n_past != NULL);
int N = n_tokens;
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
@ -917,7 +919,7 @@ bool gpt_eval(
struct ggml_cgraph gf = {};
struct ggml_tensor * input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(input->data, embd_inp.data(), N*ggml_element_size(input));
memcpy(input->data, tokens, N*ggml_element_size(input));
struct ggml_tensor * tok_emb;
@ -1183,9 +1185,10 @@ bool gpt_eval(
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (logits != NULL) {
// return result just for the last token
memcpy(logits, (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
}
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
@ -1315,25 +1318,31 @@ bark_sequence bark_forward_text_encoder(
int64_t t_sample_us = 0;
int64_t t_predict_us = 0;
auto & hparams = model.hparams;
const int n_vocab = hparams.n_out_vocab;
const int64_t t_main_start_us = ggml_time_us();
float eos_p = 0;
bark_sequence input = tokens;
std::vector<float> logits;
logits.resize(n_vocab);
// dry run to estimate mem_per_token
size_t mem_per_token = 0;
{
int n_past = 0;
gpt_eval(model, n_threads, &n_past, false, { 0, 1, 2, 3 }, logits, mem_per_token);
bark_vocab::id decoy[4] = { 0, 1, 2, 3 };
gpt_eval(model, decoy, 4, nullptr, &n_past, false, n_threads, mem_per_token);
}
int n_past = 0;
for (int i = 0; i < 768; i++) {
int64_t t_predict_start_us = ggml_time_us();
gpt_eval(model, n_threads, &n_past, true, input, logits, mem_per_token);
gpt_eval(model, input.data(), input.size(), logits.data(), &n_past, true, n_threads, mem_per_token);
t_predict_us += (ggml_time_us() - t_predict_start_us);
std::vector<float> relevant_logits(logits.begin(), logits.begin() + SEMANTIC_VOCAB_SIZE);
@ -1395,14 +1404,20 @@ bark_codes bark_forward_coarse_encoder(
int n_window_steps = ceilf(static_cast<float>(n_steps) / sliding_window_size);
auto & hparams = model.hparams;
const int n_vocab = hparams.n_out_vocab;
bark_sequence input = tokens;
std::vector<float> logits;
logits.resize(n_vocab);
// dry run to estimate mem_per_token
size_t mem_per_token = 0;
{
int n_past = 0;
gpt_eval(model, n_threads, &n_past, false, { 0, 1, 2, 3 }, logits, mem_per_token);
bark_vocab::id decoy[4] = { 0, 1, 2, 3 };
gpt_eval(model, decoy, 4, nullptr, &n_past, false, n_threads, mem_per_token);
}
for (int i = 0; i < n_window_steps; i++) {
@ -1436,7 +1451,7 @@ bark_codes bark_forward_coarse_encoder(
continue;
int64_t t_predict_start_us = ggml_time_us();
gpt_eval(model, n_threads, &n_past, false, input_in, logits, mem_per_token);
gpt_eval(model, input_in.data(), input_in.size(), logits.data(), &n_past, false, n_threads, mem_per_token);
t_predict_us += (ggml_time_us() - t_predict_start_us);
input_in.clear();
@ -1455,9 +1470,6 @@ bark_codes bark_forward_coarse_encoder(
input_in.push_back(next);
out.push_back(next);
// printf("%d ", next);
// fflush(stdout);
step_ix += 1;
progress.callback((float) (i*sliding_window_size+j)/n_steps);
@ -1530,8 +1542,8 @@ bark_codes bark_forward_fine_encoder(
}
// dry run to estimate mem_per_token
bark_sequence decoy = { 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8 };
fine_gpt_eval(model, decoy.data(), decoy.size(), nullptr, n_threads, 2, mem_per_token);
bark_vocab::id decoy[16] = { 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8 };
fine_gpt_eval(model, decoy, 16, nullptr, n_threads, 2, mem_per_token);
int n_loops = std::max(0, (int) ceilf((input.size() - 1024)/512.f)) + 1;

15
bark.h
View File

@ -137,13 +137,14 @@ struct bark_model {
bool gpt_model_load(const std::string& fname, gpt_model& model);
bool gpt_eval(
const gpt_model & model,
const int n_threads,
int * n_past,
const bool merge_ctx,
const bark_sequence & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token);
const gpt_model & model,
bark_vocab::id * tokens,
int n_tokens,
float * logits,
int * n_past,
bool merge_ctx,
int n_threads,
size_t & mem_per_token);
bool fine_gpt_eval(
const gpt_model & model,

View File

@ -12,7 +12,7 @@ static const std::vector<std::string> test_data = {
"./data/semantic/test_pass_semantic_3.bin", // prompt: El Arte de Vencer se Aprende en las Derrotas
};
static const int n_threads = 4;
static const int n_threads = 4;
static const float min_eos_p = 0.2;
static const float temp = 0.0f; // deterministic sampling

View File

@ -29,17 +29,21 @@ int main() {
bark_sequence tokens;
logit_sequence gt_logits, logits;
auto & hparams = model.hparams;
int n_vocab = hparams.n_out_vocab;
logits.resize(n_vocab);
// dry run to estimate mem_per_token
size_t mem_per_token = 0;
{
int n_past = 0;
gpt_eval(model, n_threads, &n_past, false, { 0, 1, 2, 3 }, logits, mem_per_token);
bark_vocab::id decoy[4] = { 0, 1, 2, 3 };
gpt_eval(model, decoy, 4, nullptr, &n_past, false, n_threads, mem_per_token);
}
for (int i = 0; i < (int) test_args.size(); i++) {
tokens.clear();
gt_logits.clear();
logits.clear();
std::string path = std::get<0>(test_args[i]);
bool merge_ctx = std::get<1>(test_args[i]);
@ -47,7 +51,7 @@ int main() {
load_test_data(path, tokens, gt_logits);
int n_past = 0;
gpt_eval(model, n_threads, &n_past, merge_ctx, tokens, logits, mem_per_token);
gpt_eval(model, tokens.data(), tokens.size(), logits.data(), &n_past, merge_ctx, n_threads, mem_per_token);
printf("\n");
printf("%s: %s\n", __func__, path.c_str());