From 38d762d8fc4d3417c767f03d40253e12bbc559f1 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sat, 25 Apr 2026 16:48:16 -0700 Subject: [PATCH] common: refactor common/debug to move abort_on_nan into base_callback_data Passing bool abort_on_nan as template parameter for common_debug_cb_eval is unnecessary and creates an issue with LTO. It should just be a member of the base_callback_data instead. --- common/debug.cpp | 15 ++++----------- common/debug.h | 15 ++++++++++----- examples/eval-callback/eval-callback.cpp | 2 +- tools/mtmd/debug/mtmd-debug.cpp | 2 +- tools/mtmd/mtmd-cli.cpp | 2 +- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/common/debug.cpp b/common/debug.cpp index 0df409a79d..5d74936b3a 100644 --- a/common/debug.cpp +++ b/common/debug.cpp @@ -47,8 +47,7 @@ static float common_ggml_get_float_value(const uint8_t * data, #define INDENT " " -template -void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { +void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n, bool abort_on_nan) { GGML_ASSERT(n > 0); float sum = 0; for (int64_t i3 = 0; i3 < ne[3]; i3++) { @@ -94,7 +93,7 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n LOG(INDENT "sum = %f\n", sum); } - if constexpr (abort) { + if (abort_on_nan) { if (std::isnan(sum)) { LOG("encountered NaN - aborting\n"); exit(0); @@ -112,7 +111,7 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n * @param user_data user data to pass at each call back * @return true to receive data or continue the graph, false otherwise */ -template bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { +bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { auto * cb_data = (base_callback_data *) user_data; const struct ggml_tensor * src0 = t->src[0]; @@ -154,14 +153,8 @@ template bool 
common_debug_cb_eval(struct ggml_tensor * t, b if (!ggml_is_quantized(t->type) && matches_filter) { uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); - common_debug_print_tensor(data, t->type, t->ne, t->nb, 3); + common_debug_print_tensor(data, t->type, t->ne, t->nb, 3, cb_data->abort_on_nan); } return true; } - -// Explicit template instantiations -template bool common_debug_cb_eval(ggml_tensor *, bool, void *); -template bool common_debug_cb_eval(ggml_tensor *, bool, void *); -template void common_debug_print_tensor(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t); -template void common_debug_print_tensor(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t); diff --git a/common/debug.h b/common/debug.h index e563b40d68..aa1fa17337 100644 --- a/common/debug.h +++ b/common/debug.h @@ -12,23 +12,26 @@ // ne - the tensor dimensions array // nb - the tensor strides array // n - the number of rows/columns to fully print -template void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n); +// aon - abort if NaN is encountered +void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n, bool aon = false); // Intended to use as callback for ggml_backend_sched_eval_callback // prints tensors that are processed in the computation graph // by default prints all tensors, but can be configured by creating a `base_callback_data` instance with // non-empty filter_patterns. 
See examples/debug.ccp for possible usage patterns -// The template parameter determines whether an error should be thrown whenever a NaN is encountered +// `base_callback_data` contains an `abort_on_nan` flag that determines whether the process exits whenever a NaN is encountered // in a tensor (useful for stopping debug sessions on first erroneous tensor) // The callback data will be passed as the third parameter (user_data) -template bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data); +bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data); + struct base_callback_data { std::vector data; std::vector tensor_filters; + bool abort_on_nan{false}; base_callback_data() = default; - base_callback_data(common_params & params, const std::vector & filter_patterns) { + base_callback_data(common_params & params, const std::vector & filter_patterns, bool abort_on_nan = false) { for (const auto & pattern : filter_patterns) { try { std::string anchored_pattern = "^" + pattern; @@ -37,7 +40,9 @@ struct base_callback_data { throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); } } - params.cb_eval = common_debug_cb_eval; + this->abort_on_nan = abort_on_nan; + + params.cb_eval = common_debug_cb_eval; params.cb_eval_user_data = this; } }; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 8832468451..6685e525f8 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -53,7 +53,7 @@ int main(int argc, char ** argv) { // pass the callback to the backend scheduler // it will be executed for each node during the graph computation - params.cb_eval = common_debug_cb_eval; + params.cb_eval = common_debug_cb_eval; params.cb_eval_user_data = &cb_data; params.warmup = false; diff --git a/tools/mtmd/debug/mtmd-debug.cpp b/tools/mtmd/debug/mtmd-debug.cpp index 6e32b283aa..182a9bca41 100644 --- 
a/tools/mtmd/debug/mtmd-debug.cpp +++ b/tools/mtmd/debug/mtmd-debug.cpp @@ -89,7 +89,7 @@ int main(int argc, char ** argv) { { // always enable debug callback mparams.cb_eval_user_data = &cb_data; - mparams.cb_eval = common_debug_cb_eval; + mparams.cb_eval = common_debug_cb_eval; } ctx_mtmd.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_mtmd.get()) { diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index dd72dfb17c..aa6ad73807 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -145,7 +145,7 @@ struct mtmd_cli_context { mparams.image_max_tokens = params.image_max_tokens; if (std::getenv("MTMD_DEBUG_GRAPH") != nullptr) { mparams.cb_eval_user_data = &cb_data; - mparams.cb_eval = common_debug_cb_eval; + mparams.cb_eval = common_debug_cb_eval; } ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams)); if (!ctx_vision.get()) {