common: refactor common/debug to move abort_on_nan into base_callback_data

Passing bool abort_on_nan as template parameter for common_debug_cb_eval is unnecessary and creates an issue with LTO.
It should just be a member of the base_callback_data instead.
This commit is contained in:
Max Krasnyansky 2026-04-25 16:48:16 -07:00
parent 0adede866d
commit 38d762d8fc
5 changed files with 17 additions and 19 deletions

View File

@ -47,8 +47,7 @@ static float common_ggml_get_float_value(const uint8_t * data,
#define INDENT " "
template <bool abort>
void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n, bool abort_on_nan) {
GGML_ASSERT(n > 0);
float sum = 0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
@ -94,7 +93,7 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n
LOG(INDENT "sum = %f\n", sum);
}
if constexpr (abort) {
if (abort_on_nan) {
if (std::isnan(sum)) {
LOG("encountered NaN - aborting\n");
exit(0);
@ -112,7 +111,7 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n
* @param user_data user data to pass at each call back
* @return true to receive data or continue the graph, false otherwise
*/
template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
auto * cb_data = (base_callback_data *) user_data;
const struct ggml_tensor * src0 = t->src[0];
@ -154,14 +153,8 @@ template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, b
if (!ggml_is_quantized(t->type) && matches_filter) {
uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
common_debug_print_tensor<abort_on_nan>(data, t->type, t->ne, t->nb, 3);
common_debug_print_tensor(data, t->type, t->ne, t->nb, 3, cb_data->abort_on_nan);
}
return true;
}
// Explicit template instantiations
template bool common_debug_cb_eval<false>(ggml_tensor *, bool, void *);
template bool common_debug_cb_eval<true>(ggml_tensor *, bool, void *);
template void common_debug_print_tensor<false>(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t);
template void common_debug_print_tensor<true>(uint8_t *, ggml_type, const int64_t *, const size_t *, int64_t);

View File

@ -12,23 +12,26 @@
// ne - the tensor dimensions array
// nb - the tensor strides array
// n - the number of rows/columns to fully print
template <bool abort_on_nan> void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n);
// aon - abort if NaN is encountered
void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n, bool aon = false);
// Intended to use as callback for ggml_backend_sched_eval_callback
// prints tensors that are processed in the computation graph
// by default prints all tensors, but can be configured by creating a `base_callback_data` instance with
// non-empty filter_patterns. See examples/debug.ccp for possible usage patterns
// The template parameter determines whether an error should be thrown whenever a NaN is encountered
// `base_callback_data` contains `abort_on_nan` flag that determines whether an error should be thrown whenever a NaN is encountered
// in a tensor (useful for stopping debug sessions on first erroneous tensor)
// The callback data will be passed as the third parameter (user_data)
template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data);
bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data);
struct base_callback_data {
std::vector<uint8_t> data;
std::vector<std::regex> tensor_filters;
bool abort_on_nan{false};
base_callback_data() = default;
base_callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
base_callback_data(common_params & params, const std::vector<std::string> & filter_patterns, bool abort_on_nan = false) {
for (const auto & pattern : filter_patterns) {
try {
std::string anchored_pattern = "^" + pattern;
@ -37,7 +40,9 @@ struct base_callback_data {
throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
}
}
params.cb_eval = common_debug_cb_eval<false>;
this->abort_on_nan = abort_on_nan;
params.cb_eval = common_debug_cb_eval;
params.cb_eval_user_data = this;
}
};

View File

@ -53,7 +53,7 @@ int main(int argc, char ** argv) {
// pass the callback to the backend scheduler
// it will be executed for each node during the graph computation
params.cb_eval = common_debug_cb_eval<false>;
params.cb_eval = common_debug_cb_eval;
params.cb_eval_user_data = &cb_data;
params.warmup = false;

View File

@ -89,7 +89,7 @@ int main(int argc, char ** argv) {
{
// always enable debug callback
mparams.cb_eval_user_data = &cb_data;
mparams.cb_eval = common_debug_cb_eval<false>;
mparams.cb_eval = common_debug_cb_eval;
}
ctx_mtmd.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_mtmd.get()) {

View File

@ -145,7 +145,7 @@ struct mtmd_cli_context {
mparams.image_max_tokens = params.image_max_tokens;
if (std::getenv("MTMD_DEBUG_GRAPH") != nullptr) {
mparams.cb_eval_user_data = &cb_data;
mparams.cb_eval = common_debug_cb_eval<false>;
mparams.cb_eval = common_debug_cb_eval;
}
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {