#include #include #include #include #include #include #include "httplib.h" #include "async_jobs.h" #include "common/common.h" #include "common/resource_owners.hpp" #include "routes.h" #include "runtime.h" #ifdef HAVE_INDEX_HTML #include "frontend/dist/gen_index_html.h" #endif static void print_usage(const char* argv0, const std::vector& options_list) { std::cout << version_string() << "\n"; std::cout << "Usage: " << argv0 << " [options]\n\n"; std::cout << "Svr Options:\n"; options_list[0].print(); std::cout << "\nContext Options:\n"; options_list[1].print(); std::cout << "\nDefault Generation Options:\n"; options_list[2].print(); } static void parse_args(int argc, const char** argv, SDSvrParams& svr_params, SDContextParams& ctx_params, SDGenerationParams& default_gen_params) { std::vector options_vec = { svr_params.get_options(), ctx_params.get_options(), default_gen_params.get_options(), }; if (!parse_options(argc, argv, options_vec)) { print_usage(argv[0], options_vec); exit(svr_params.normal_exit ? 0 : 1); } const bool random_seed_requested = default_gen_params.seed < 0; if (!svr_params.resolve_and_validate() || !ctx_params.resolve_and_validate(IMG_GEN) || !default_gen_params.resolve_and_validate(IMG_GEN, ctx_params.lora_model_dir)) { print_usage(argv[0], options_vec); exit(1); } if (random_seed_requested) { default_gen_params.seed = -1; } } void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { SDSvrParams* svr_params = (SDSvrParams*)data; log_print(level, log, svr_params->verbose, svr_params->color); } int main(int argc, const char** argv) { if (argc > 1 && std::string(argv[1]) == "--version") { std::cout << version_string() << "\n"; return EXIT_SUCCESS; } SDSvrParams svr_params; SDContextParams ctx_params; SDGenerationParams default_gen_params; parse_args(argc, argv, svr_params, ctx_params, default_gen_params); sd_set_log_callback(sd_log_cb, (void*)&svr_params); log_verbose = svr_params.verbose; log_color = svr_params.color; LOG_DEBUG("version: %s", version_string().c_str()); LOG_DEBUG("%s", sd_get_system_info()); LOG_DEBUG("%s", svr_params.to_string().c_str()); LOG_DEBUG("%s", ctx_params.to_string().c_str()); LOG_DEBUG("%s", default_gen_params.to_string().c_str()); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params)); if (sd_ctx == nullptr) { LOG_ERROR("new_sd_ctx_t failed"); return 1; } std::mutex sd_ctx_mutex; std::vector lora_cache; std::mutex lora_mutex; AsyncJobManager async_job_manager; ServerRuntime runtime = { sd_ctx.get(), &sd_ctx_mutex, &svr_params, &ctx_params, &default_gen_params, &lora_cache, &lora_mutex, &async_job_manager, }; std::thread async_worker(async_job_worker, std::ref(runtime)); httplib::Server svr; svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) { std::string origin = req.get_header_value("Origin"); if (origin.empty()) { origin = "*"; } res.set_header("Access-Control-Allow-Origin", origin); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "*"); res.set_header("Access-Control-Allow-Headers", "*"); if (req.method == "OPTIONS") { res.status = 204; return httplib::Server::HandlerResponse::Handled; } return httplib::Server::HandlerResponse::Unhandled; }); std::string index_html; #ifdef HAVE_INDEX_HTML index_html.assign(reinterpret_cast(index_html_bytes), index_html_size); #else index_html = "Stable Diffusion Server is running"; #endif register_index_endpoints(svr, svr_params, index_html); register_openai_api_endpoints(svr, runtime); register_sdapi_endpoints(svr, runtime); register_sdcpp_api_endpoints(svr, runtime); LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port); { std::lock_guard lock(async_job_manager.mutex); async_job_manager.stop = true; } async_job_manager.cv.notify_all(); async_worker.join(); return 0; }