stable-diffusion.cpp/src/sample-cache.cpp

362 lines
16 KiB
C++

#include "sample-cache.h"
namespace sd_sample {
static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
float reuse_threshold = params.reuse_threshold;
if (reuse_threshold == INFINITY) {
if (params.mode == SD_CACHE_EASYCACHE) {
reuse_threshold = 0.2f;
} else if (params.mode == SD_CACHE_UCACHE) {
reuse_threshold = 1.0f;
}
}
return std::max(0.0f, reuse_threshold);
}
bool SampleCacheRuntime::easycache_enabled() const {
return mode == SampleCacheMode::EASYCACHE;
}
bool SampleCacheRuntime::ucache_enabled() const {
return mode == SampleCacheMode::UCACHE;
}
bool SampleCacheRuntime::cachedit_enabled() const {
return mode == SampleCacheMode::CACHEDIT;
}
static bool has_valid_cache_percent_range(const sd_cache_params_t& cache_params) {
if (cache_params.mode != SD_CACHE_EASYCACHE && cache_params.mode != SD_CACHE_UCACHE) {
return true;
}
return cache_params.start_percent >= 0.0f &&
cache_params.start_percent < 1.0f &&
cache_params.end_percent > 0.0f &&
cache_params.end_percent <= 1.0f &&
cache_params.start_percent < cache_params.end_percent;
}
static void init_easycache_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
Denoiser* denoiser) {
if (!sd_version_is_dit(version)) {
LOG_WARN("EasyCache requested but not supported for this model type");
return;
}
EasyCacheConfig config;
config.enabled = true;
config.reuse_threshold = get_cache_reuse_threshold(cache_params);
config.start_percent = cache_params.start_percent;
config.end_percent = cache_params.end_percent;
runtime.easycache.init(config, denoiser);
if (!runtime.easycache.enabled()) {
LOG_WARN("EasyCache requested but could not be initialized for this run");
return;
}
runtime.mode = SampleCacheMode::EASYCACHE;
LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
config.reuse_threshold,
config.start_percent,
config.end_percent);
}
static void init_ucache_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
Denoiser* denoiser,
const std::vector<float>& sigmas) {
if (!sd_version_is_unet(version)) {
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
return;
}
UCacheConfig config;
config.enabled = true;
config.reuse_threshold = get_cache_reuse_threshold(cache_params);
config.start_percent = cache_params.start_percent;
config.end_percent = cache_params.end_percent;
config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params.error_decay_rate));
config.use_relative_threshold = cache_params.use_relative_threshold;
config.reset_error_on_compute = cache_params.reset_error_on_compute;
runtime.ucache.init(config, denoiser);
if (!runtime.ucache.enabled()) {
LOG_WARN("UCache requested but could not be initialized for this run");
return;
}
runtime.ucache.set_sigmas(sigmas);
runtime.mode = SampleCacheMode::UCACHE;
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s",
config.reuse_threshold,
config.start_percent,
config.end_percent,
config.error_decay_rate,
config.use_relative_threshold ? "true" : "false",
config.reset_error_on_compute ? "true" : "false");
}
static void init_cachedit_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
const std::vector<float>& sigmas) {
if (!sd_version_is_dit(version)) {
LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)");
return;
}
DBCacheConfig dbcfg;
dbcfg.enabled = (cache_params.mode == SD_CACHE_DBCACHE || cache_params.mode == SD_CACHE_CACHE_DIT);
dbcfg.Fn_compute_blocks = cache_params.Fn_compute_blocks;
dbcfg.Bn_compute_blocks = cache_params.Bn_compute_blocks;
dbcfg.residual_diff_threshold = cache_params.residual_diff_threshold;
dbcfg.max_warmup_steps = cache_params.max_warmup_steps;
dbcfg.max_cached_steps = cache_params.max_cached_steps;
dbcfg.max_continuous_cached_steps = cache_params.max_continuous_cached_steps;
if (cache_params.scm_mask != nullptr && strlen(cache_params.scm_mask) > 0) {
dbcfg.steps_computation_mask = parse_scm_mask(cache_params.scm_mask);
}
dbcfg.scm_policy_dynamic = cache_params.scm_policy_dynamic;
TaylorSeerConfig tcfg;
tcfg.enabled = (cache_params.mode == SD_CACHE_TAYLORSEER || cache_params.mode == SD_CACHE_CACHE_DIT);
tcfg.n_derivatives = cache_params.taylorseer_n_derivatives;
tcfg.skip_interval_steps = cache_params.taylorseer_skip_interval;
runtime.cachedit.init(dbcfg, tcfg);
if (!runtime.cachedit.enabled()) {
LOG_WARN("CacheDIT requested but could not be initialized for this run");
return;
}
runtime.cachedit.set_sigmas(sigmas);
runtime.mode = SampleCacheMode::CACHEDIT;
LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d",
cache_params.mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : (cache_params.mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"),
dbcfg.Fn_compute_blocks,
dbcfg.Bn_compute_blocks,
dbcfg.residual_diff_threshold,
dbcfg.max_warmup_steps);
}
static void init_spectrum_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
const std::vector<float>& sigmas) {
if (!sd_version_is_unet(version) && !sd_version_is_dit(version)) {
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
return;
}
SpectrumConfig config;
config.w = cache_params.spectrum_w;
config.m = cache_params.spectrum_m;
config.lam = cache_params.spectrum_lam;
config.window_size = cache_params.spectrum_window_size;
config.flex_window = cache_params.spectrum_flex_window;
config.warmup_steps = cache_params.spectrum_warmup_steps;
config.stop_percent = cache_params.spectrum_stop_percent;
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
runtime.spectrum.init(config, total_steps);
runtime.spectrum_enabled = true;
LOG_INFO("Spectrum enabled - w: %.2f, m: %d, lam: %.2f, window: %d, flex: %.2f, warmup: %d, stop: %.0f%%",
config.w, config.m, config.lam,
config.window_size, config.flex_window,
config.warmup_steps, config.stop_percent * 100.0f);
}
SampleCacheRuntime init_sample_cache_runtime(SDVersion version,
const sd_cache_params_t* cache_params,
Denoiser* denoiser,
const std::vector<float>& sigmas) {
SampleCacheRuntime runtime;
if (cache_params == nullptr || cache_params->mode == SD_CACHE_DISABLED) {
return runtime;
}
if (!has_valid_cache_percent_range(*cache_params)) {
LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
cache_params->start_percent,
cache_params->end_percent);
return runtime;
}
switch (cache_params->mode) {
case SD_CACHE_EASYCACHE:
init_easycache_runtime(runtime, version, *cache_params, denoiser);
break;
case SD_CACHE_UCACHE:
init_ucache_runtime(runtime, version, *cache_params, denoiser, sigmas);
break;
case SD_CACHE_DBCACHE:
case SD_CACHE_TAYLORSEER:
case SD_CACHE_CACHE_DIT:
init_cachedit_runtime(runtime, version, *cache_params, sigmas);
break;
case SD_CACHE_SPECTRUM:
init_spectrum_runtime(runtime, version, *cache_params, sigmas);
break;
default:
break;
}
return runtime;
}
SampleStepCacheDispatcher::SampleStepCacheDispatcher(SampleCacheRuntime& runtime, int step, float sigma)
: runtime(runtime), step(step), sigma(sigma), step_index(step > 0 ? (step - 1) : -1) {
if (step_index < 0) {
return;
}
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
runtime.easycache.begin_step(step_index, sigma);
break;
case SampleCacheMode::UCACHE:
runtime.ucache.begin_step(step_index, sigma);
break;
case SampleCacheMode::CACHEDIT:
runtime.cachedit.begin_step(step_index, sigma);
break;
case SampleCacheMode::NONE:
break;
}
}
bool SampleStepCacheDispatcher::before_condition(const void* condition,
const sd::Tensor<float>& input,
sd::Tensor<float>* output) {
if (step_index < 0 || condition == nullptr || output == nullptr) {
return false;
}
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
return runtime.easycache.before_condition(condition, input, output, sigma, step_index);
case SampleCacheMode::UCACHE:
return runtime.ucache.before_condition(condition, input, output, sigma, step_index);
case SampleCacheMode::CACHEDIT:
return runtime.cachedit.before_condition(condition, input, output, sigma, step_index);
case SampleCacheMode::NONE:
return false;
}
return false;
}
void SampleStepCacheDispatcher::after_condition(const void* condition,
const sd::Tensor<float>& input,
const sd::Tensor<float>& output) {
if (step_index < 0 || condition == nullptr) {
return;
}
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
runtime.easycache.after_condition(condition, input, output);
break;
case SampleCacheMode::UCACHE:
runtime.ucache.after_condition(condition, input, output);
break;
case SampleCacheMode::CACHEDIT:
runtime.cachedit.after_condition(condition, input, output);
break;
case SampleCacheMode::NONE:
break;
}
}
bool SampleStepCacheDispatcher::is_step_skipped() const {
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
return runtime.easycache.is_step_skipped();
case SampleCacheMode::UCACHE:
return runtime.ucache.is_step_skipped();
case SampleCacheMode::CACHEDIT:
return runtime.cachedit.is_step_skipped();
case SampleCacheMode::NONE:
return false;
}
return false;
}
void log_sample_cache_summary(const SampleCacheRuntime& runtime, size_t total_steps) {
if (runtime.easycache_enabled()) {
if (runtime.easycache.total_steps_skipped > 0 && total_steps > 0) {
if (runtime.easycache.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.easycache.total_steps_skipped);
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
runtime.easycache.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("EasyCache skipped %d/%zu steps",
runtime.easycache.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("EasyCache completed without skipping steps");
}
}
if (runtime.ucache_enabled()) {
if (runtime.ucache.total_steps_skipped > 0 && total_steps > 0) {
if (runtime.ucache.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.ucache.total_steps_skipped);
LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
runtime.ucache.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("UCache skipped %d/%zu steps",
runtime.ucache.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("UCache completed without skipping steps");
}
}
if (runtime.cachedit_enabled()) {
if (runtime.cachedit.total_steps_skipped > 0 && total_steps > 0) {
if (runtime.cachedit.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.cachedit.total_steps_skipped);
LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup)",
runtime.cachedit.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("CacheDIT skipped %d/%zu steps",
runtime.cachedit.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("CacheDIT completed without skipping steps");
}
}
if (runtime.spectrum_enabled && runtime.spectrum.total_steps_skipped > 0 && total_steps > 0) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.spectrum.total_steps_skipped);
LOG_INFO("Spectrum skipped %d/%zu steps (%.2fx estimated speedup)",
runtime.spectrum.total_steps_skipped,
total_steps,
speedup);
}
}
} // namespace sd_sample