mirror of
https://github.com/leejet/stable-diffusion.cpp
synced 2026-04-21 22:32:14 +02:00
153 lines
4.7 KiB
C++
153 lines
4.7 KiB
C++
#include <cstdlib>
|
|
#include <iostream>
|
|
#include <mutex>
|
|
#include <string>
|
|
#include <thread>
|
|
#include <vector>
|
|
|
|
#include "httplib.h"
|
|
|
|
#include "async_jobs.h"
|
|
#include "common/common.h"
|
|
#include "common/resource_owners.hpp"
|
|
#include "routes.h"
|
|
#include "runtime.h"
|
|
|
|
#ifdef HAVE_INDEX_HTML
|
|
#include "frontend/dist/gen_index_html.h"
|
|
#endif
|
|
|
|
static void print_usage(const char* argv0, const std::vector<ArgOptions>& options_list) {
|
|
std::cout << version_string() << "\n";
|
|
std::cout << "Usage: " << argv0 << " [options]\n\n";
|
|
std::cout << "Svr Options:\n";
|
|
options_list[0].print();
|
|
std::cout << "\nContext Options:\n";
|
|
options_list[1].print();
|
|
std::cout << "\nDefault Generation Options:\n";
|
|
options_list[2].print();
|
|
}
|
|
|
|
static void parse_args(int argc,
|
|
const char** argv,
|
|
SDSvrParams& svr_params,
|
|
SDContextParams& ctx_params,
|
|
SDGenerationParams& default_gen_params) {
|
|
std::vector<ArgOptions> options_vec = {
|
|
svr_params.get_options(),
|
|
ctx_params.get_options(),
|
|
default_gen_params.get_options(),
|
|
};
|
|
|
|
if (!parse_options(argc, argv, options_vec)) {
|
|
print_usage(argv[0], options_vec);
|
|
exit(svr_params.normal_exit ? 0 : 1);
|
|
}
|
|
|
|
const bool random_seed_requested = default_gen_params.seed < 0;
|
|
|
|
if (!svr_params.resolve_and_validate() ||
|
|
!ctx_params.resolve_and_validate(IMG_GEN) ||
|
|
!default_gen_params.resolve_and_validate(IMG_GEN, ctx_params.lora_model_dir)) {
|
|
print_usage(argv[0], options_vec);
|
|
exit(1);
|
|
}
|
|
|
|
if (random_seed_requested) {
|
|
default_gen_params.seed = -1;
|
|
}
|
|
}
|
|
|
|
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
|
SDSvrParams* svr_params = (SDSvrParams*)data;
|
|
log_print(level, log, svr_params->verbose, svr_params->color);
|
|
}
|
|
|
|
int main(int argc, const char** argv) {
|
|
if (argc > 1 && std::string(argv[1]) == "--version") {
|
|
std::cout << version_string() << "\n";
|
|
return EXIT_SUCCESS;
|
|
}
|
|
SDSvrParams svr_params;
|
|
SDContextParams ctx_params;
|
|
SDGenerationParams default_gen_params;
|
|
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
|
|
|
|
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
|
|
log_verbose = svr_params.verbose;
|
|
log_color = svr_params.color;
|
|
|
|
LOG_DEBUG("version: %s", version_string().c_str());
|
|
LOG_DEBUG("%s", sd_get_system_info());
|
|
LOG_DEBUG("%s", svr_params.to_string().c_str());
|
|
LOG_DEBUG("%s", ctx_params.to_string().c_str());
|
|
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
|
|
|
|
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
|
|
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
|
|
|
|
if (sd_ctx == nullptr) {
|
|
LOG_ERROR("new_sd_ctx_t failed");
|
|
return 1;
|
|
}
|
|
|
|
std::mutex sd_ctx_mutex;
|
|
|
|
std::vector<LoraEntry> lora_cache;
|
|
std::mutex lora_mutex;
|
|
AsyncJobManager async_job_manager;
|
|
ServerRuntime runtime = {
|
|
sd_ctx.get(),
|
|
&sd_ctx_mutex,
|
|
&svr_params,
|
|
&ctx_params,
|
|
&default_gen_params,
|
|
&lora_cache,
|
|
&lora_mutex,
|
|
&async_job_manager,
|
|
};
|
|
|
|
std::thread async_worker(async_job_worker, std::ref(runtime));
|
|
|
|
httplib::Server svr;
|
|
|
|
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
|
std::string origin = req.get_header_value("Origin");
|
|
if (origin.empty()) {
|
|
origin = "*";
|
|
}
|
|
res.set_header("Access-Control-Allow-Origin", origin);
|
|
res.set_header("Access-Control-Allow-Credentials", "true");
|
|
res.set_header("Access-Control-Allow-Methods", "*");
|
|
res.set_header("Access-Control-Allow-Headers", "*");
|
|
|
|
if (req.method == "OPTIONS") {
|
|
res.status = 204;
|
|
return httplib::Server::HandlerResponse::Handled;
|
|
}
|
|
return httplib::Server::HandlerResponse::Unhandled;
|
|
});
|
|
|
|
std::string index_html;
|
|
#ifdef HAVE_INDEX_HTML
|
|
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
|
|
#else
|
|
index_html = "Stable Diffusion Server is running";
|
|
#endif
|
|
register_index_endpoints(svr, svr_params, index_html);
|
|
register_openai_api_endpoints(svr, runtime);
|
|
register_sdapi_endpoints(svr, runtime);
|
|
register_sdcpp_api_endpoints(svr, runtime);
|
|
|
|
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
|
|
svr.listen(svr_params.listen_ip, svr_params.listen_port);
|
|
|
|
{
|
|
std::lock_guard<std::mutex> lock(async_job_manager.mutex);
|
|
async_job_manager.stop = true;
|
|
}
|
|
async_job_manager.cv.notify_all();
|
|
async_worker.join();
|
|
return 0;
|
|
}
|