diff --git a/common/debug.cpp b/common/debug.cpp index 0c238e6a5d..102c6924dc 100644 --- a/common/debug.cpp +++ b/common/debug.cpp @@ -1,9 +1,38 @@ #include "debug.h" +#include "common.h" #include "log.h" #include +#include #include +#include + +struct common_debug_cb_user_data::impl { + std::vector data; + std::vector tensor_filters; + bool abort_on_nan{false}; +}; + +common_debug_cb_user_data::common_debug_cb_user_data() : pimpl(std::make_unique()) {} +common_debug_cb_user_data::~common_debug_cb_user_data() = default; + +common_debug_cb_user_data::common_debug_cb_user_data(common_params & params, const std::vector & filter_patterns, bool abort_on_nan) + : pimpl(std::make_unique()) +{ + for (const auto & pattern : filter_patterns) { + try { + std::string anchored_pattern = "^" + pattern; + pimpl->tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); + } catch (const std::regex_error & e) { + throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); + } + } + pimpl->abort_on_nan = abort_on_nan; + + params.cb_eval = common_debug_cb_eval; + params.cb_eval_user_data = this; +} static std::string common_ggml_ne_string(const ggml_tensor * t) { std::string str; @@ -113,6 +142,7 @@ static void common_debug_print_tensor(uint8_t * data, ggml_type type, const int6 */ bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { auto * cb_data = (common_debug_cb_user_data *) user_data; + auto * pimpl = cb_data->pimpl.get(); const struct ggml_tensor * src0 = t->src[0]; const struct ggml_tensor * src1 = t->src[1]; @@ -121,10 +151,10 @@ bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { return true; // Always retrieve data } - bool matches_filter = cb_data->tensor_filters.empty(); + bool matches_filter = pimpl->tensor_filters.empty(); if (!matches_filter) { - for (const auto & filter : cb_data->tensor_filters) { + for (const auto & filter : pimpl->tensor_filters) { if (std::regex_search(t->name, filter)) { matches_filter = true; break; @@ -147,13 +177,13 @@ bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { if (!is_host) { auto n_bytes = ggml_nbytes(t); - cb_data->data.resize(n_bytes); - ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + pimpl->data.resize(n_bytes); + ggml_backend_tensor_get(t, pimpl->data.data(), 0, n_bytes); } if (!ggml_is_quantized(t->type) && matches_filter) { - uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); - common_debug_print_tensor(data, t->type, t->ne, t->nb, 3, cb_data->abort_on_nan); + uint8_t * data = is_host ? (uint8_t *) t->data : pimpl->data.data(); + common_debug_print_tensor(data, t->type, t->ne, t->nb, 3, pimpl->abort_on_nan); } return true; diff --git a/common/debug.h b/common/debug.h index 4017ccdb8a..8b8f8c7aa9 100644 --- a/common/debug.h +++ b/common/debug.h @@ -1,39 +1,31 @@ #pragma once -#include "common.h" + +#include #include #include -#include // common debug functions and structs +struct common_params; + // Intended to use as callback for ggml_backend_sched_eval_callback // prints tensors that are processed in the computation graph // by default prints all tensors, but can be configured by creating a `common_debug_cb_user_data` instance with -// non-empty filter_patterns. See examples/debug.ccp for possible usage patterns +// non-empty filter_patterns. See examples/debug.cpp for possible usage patterns // `common_debug_cb_user_data` contains `abort_on_nan` flag that determines whether an error should be thrown whenever a NaN is encountered // in a tensor (useful for stopping debug sessions on first erroneous tensor) // The callback data will be passed as the third parameter (user_data) bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data); struct common_debug_cb_user_data { - std::vector data; - std::vector tensor_filters; - bool abort_on_nan{false}; + struct impl; + std::unique_ptr pimpl; - common_debug_cb_user_data() = default; + common_debug_cb_user_data(); + ~common_debug_cb_user_data(); - common_debug_cb_user_data(common_params & params, const std::vector & filter_patterns, bool abort_on_nan = false) { - for (const auto & pattern : filter_patterns) { - try { - std::string anchored_pattern = "^" + pattern; - tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); - } catch (const std::regex_error & e) { - throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); - } - } - this->abort_on_nan = abort_on_nan; + common_debug_cb_user_data(const common_debug_cb_user_data &) = delete; + common_debug_cb_user_data & operator=(const common_debug_cb_user_data &) = delete; - params.cb_eval = common_debug_cb_eval; - params.cb_eval_user_data = this; - } + common_debug_cb_user_data(common_params & params, const std::vector & filter_patterns, bool abort_on_nan = false); };