stable-diffusion.cpp/examples/server/routes_sdapi.cpp
2026-04-11 17:49:00 +08:00

400 lines
16 KiB
C++

#include "routes.h"
#include <algorithm>
#include <cstring>
#include <regex>
#include <string_view>
#include <unordered_map>
#include "common/common.h"
#include "common/media_io.h"
#include "common/resource_owners.hpp"
namespace fs = std::filesystem;
// Extracts the payload of the first complete "<sd_cpp_extra_args>...</sd_cpp_extra_args>"
// pair in `text` and removes every complete pair (tags and payload) from `text`.
//
// Returns the first payload, or an empty string when no complete pair exists (in which case
// `text` is left unchanged). Payloads may span multiple lines. An opening tag without a
// closing tag after it is left in place.
//
// Plain substring searches are used instead of std::regex. The prompt comes from the HTTP
// client, and some std::regex implementations (notably libstdc++) recurse once per matched
// character, so a long prompt could exhaust the stack. In addition, ECMAScript '.' never
// matches '\n', so a regex-based version would miss multi-line payloads.
static std::string extract_and_remove_sd_cpp_extra_args(std::string& text) {
    static constexpr std::string_view kOpenTag  = "<sd_cpp_extra_args>";
    static constexpr std::string_view kCloseTag = "</sd_cpp_extra_args>";

    std::string extracted;
    std::string remaining;  // `text` with the matched pairs cut out
    bool have_match = false;
    size_t cursor   = 0;    // first character of `text` not yet copied/consumed

    while (true) {
        const size_t open = text.find(kOpenTag, cursor);
        if (open == std::string::npos) {
            break;
        }
        const size_t payload_begin = open + kOpenTag.size();
        const size_t close         = text.find(kCloseTag, payload_begin);
        if (close == std::string::npos) {
            // No closing tag after this opening tag, so none after any later one either.
            break;
        }
        if (!have_match) {
            extracted  = text.substr(payload_begin, close - payload_begin);
            have_match = true;
        }
        remaining.append(text, cursor, open - cursor);
        cursor = close + kCloseTag.size();
    }

    if (have_match) {
        remaining.append(text, cursor, std::string::npos);
        text.swap(remaining);
    }
    return extracted;
}
// Picks the model file path reported to A1111-style clients: the full model path when one
// was configured, otherwise the standalone diffusion model path, otherwise an empty path.
static fs::path resolve_display_model_path(const ServerRuntime& runtime) {
    const auto& params = *runtime.ctx_params;
    fs::path chosen;
    if (!params.model_path.empty()) {
        chosen = fs::path(params.model_path);
    } else if (!params.diffusion_model_path.empty()) {
        chosen = fs::path(params.diffusion_model_path);
    }
    return chosen;
}
static enum sample_method_t get_sdapi_sample_method(std::string name) {
enum sample_method_t result = str_to_sample_method(name.c_str());
if (result != SAMPLE_METHOD_COUNT) {
return result;
}
std::transform(name.begin(), name.end(), name.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
{"euler a", EULER_A_SAMPLE_METHOD},
{"k_euler_a", EULER_A_SAMPLE_METHOD},
{"euler", EULER_SAMPLE_METHOD},
{"k_euler", EULER_SAMPLE_METHOD},
{"heun", HEUN_SAMPLE_METHOD},
{"k_heun", HEUN_SAMPLE_METHOD},
{"dpm2", DPM2_SAMPLE_METHOD},
{"k_dpm_2", DPM2_SAMPLE_METHOD},
{"lcm", LCM_SAMPLE_METHOD},
{"ddim", DDIM_TRAILING_SAMPLE_METHOD},
{"dpm++ 2m", DPMPP2M_SAMPLE_METHOD},
{"k_dpmpp_2m", DPMPP2M_SAMPLE_METHOD},
{"res multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"k_res_multistep", RES_MULTISTEP_SAMPLE_METHOD},
{"res 2s", RES_2S_SAMPLE_METHOD},
{"k_res_2s", RES_2S_SAMPLE_METHOD},
};
auto it = hardcoded.find(name);
return it != hardcoded.end() ? it->second : SAMPLE_METHOD_COUNT;
}
// Replaces the image held by `mask_owner` with a single-channel width x height mask in
// which every pixel is 255.
//
// If either dimension is non-positive, or the allocation fails, `mask_owner` is reset to
// an empty {0, 0, 1, nullptr} image instead.
// The pixel buffer comes from malloc(). NOTE(review): this assumes SDImageOwner releases
// `data` with free(); confirm against its definition.
static void assign_solid_mask(SDImageOwner& mask_owner, int width, int height) {
    // Non-positive sizes must be rejected before computing the byte count. Negative ints
    // wrap when cast to size_t (e.g. -1 x -1 becomes a 1-byte buffer), while the header
    // would still advertise ~4G x ~4G pixels. A zero size would also mean malloc(0).
    if (width <= 0 || height <= 0) {
        mask_owner.reset({0, 0, 1, nullptr});
        return;
    }
    const size_t pixel_count = static_cast<size_t>(width) * static_cast<size_t>(height);
    auto* raw_mask           = static_cast<uint8_t*>(malloc(pixel_count));
    if (raw_mask == nullptr) {
        mask_owner.reset({0, 0, 1, nullptr});
        return;
    }
    std::memset(raw_mask, 255, pixel_count);
    mask_owner.reset({static_cast<uint32_t>(width), static_cast<uint32_t>(height), 1, raw_mask});
}
// Builds an ImgGenJobRequest from an AUTOMATIC1111-compatible /sdapi/v1 txt2img/img2img body.
//
// Starts from *runtime.default_gen_params and overlays these request fields:
//   - prompt, negative_prompt, width, height, steps, cfg_scale, seed;
//   - batch_size, clip_skip, sampler_name, scheduler, lora[] and extra_images[];
//   - when img2img is true: init_images[0], mask, inpainting_mask_invert and
//     denoising_strength.
// A "<sd_cpp_extra_args>{json}</sd_cpp_extra_args>" tag embedded in the prompt is stripped
// and applied through gen_params.from_json_str().
//
// j             - parsed request body.
// runtime       - server state (default params, LoRA path resolution).
// img2img       - true for the /img2img endpoint.
// request       - output; filled in place (may be partially filled on failure).
// error_message - output; a client-facing reason, set whenever false is returned.
// Returns true when the request was built and passed resolve_and_validate().
// Fields with an unexpected JSON type (e.g. a numeric prompt) make json::value() throw.
static bool build_sdapi_img_gen_request(const json& j,
                                        ServerRuntime& runtime,
                                        bool img2img,
                                        ImgGenJobRequest& request,
                                        std::string& error_message) {
    std::string prompt          = j.value("prompt", "");
    std::string negative_prompt = j.value("negative_prompt", "");
    // These defaults only feed the positivity check below; the size stored on gen_params is
    // read separately (and stays -1 when absent).
    int width                   = j.value("width", 512);
    int height                  = j.value("height", 512);
    int steps                   = j.value("steps", runtime.default_gen_params->sample_params.sample_steps);
    float cfg_scale             = j.value("cfg_scale", runtime.default_gen_params->sample_params.guidance.txt_cfg);
    // json::value() converts to the type of its default argument. With a plain `-1` the seed
    // would be read as int, truncating the 32-bit unsigned seeds A1111 clients send (values
    // >= 2^31; 4294967295 would even turn into -1).
    int64_t seed                = j.value("seed", static_cast<int64_t>(-1));
    int batch_size              = j.value("batch_size", 1);
    int clip_skip               = j.value("clip_skip", -1);
    std::string sampler_name    = j.value("sampler_name", "");
    std::string scheduler_name  = j.value("scheduler", "");

    if (width <= 0 || height <= 0) {
        error_message = "width and height must be positive";
        return false;
    }
    if (prompt.empty()) {
        error_message = "prompt required";
        return false;
    }

    request.gen_params                                  = *runtime.default_gen_params;
    request.gen_params.prompt                           = prompt;
    request.gen_params.negative_prompt                  = negative_prompt;
    request.gen_params.seed                             = seed;
    request.gen_params.sample_params.sample_steps       = steps;
    request.gen_params.batch_count                      = batch_size;
    request.gen_params.sample_params.guidance.txt_cfg   = cfg_scale;
    // -1 when absent; the size may then be filled in from the first decoded image via
    // set_width_and_height_if_unset() below.
    request.gen_params.width  = j.value("width", -1);
    request.gen_params.height = j.value("height", -1);

    std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(request.gen_params.prompt);
    if (!sd_cpp_extra_args_str.empty() && !request.gen_params.from_json_str(sd_cpp_extra_args_str)) {
        error_message = "invalid sd_cpp_extra_args";
        return false;
    }

    if (clip_skip > 0) {
        request.gen_params.clip_skip = clip_skip;
    }

    enum sample_method_t sample_method = get_sdapi_sample_method(sampler_name);
    if (sample_method != SAMPLE_METHOD_COUNT) {
        request.gen_params.sample_params.sample_method = sample_method;
    }

    enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
    if (scheduler != SCHEDULER_COUNT) {
        request.gen_params.sample_params.scheduler = scheduler;
    }

    if (j.contains("lora") && j["lora"].is_array()) {
        request.gen_params.lora_map.clear();
        request.gen_params.high_noise_lora_map.clear();
        for (const auto& item : j["lora"]) {
            if (!item.is_object()) {
                continue;
            }
            std::string path   = item.value("path", "");
            float multiplier   = item.value("multiplier", 1.0f);
            bool is_high_noise = item.value("is_high_noise", false);
            if (path.empty()) {
                error_message = "lora.path required";
                return false;
            }
            std::string fullpath = get_lora_full_path(runtime, path);
            if (fullpath.empty()) {
                error_message = "invalid lora path: " + path;
                return false;
            }
            // Repeated entries for the same file accumulate their multipliers.
            if (is_high_noise) {
                request.gen_params.high_noise_lora_map[fullpath] += multiplier;
            } else {
                request.gen_params.lora_map[fullpath] += multiplier;
            }
        }
    }

    // Expected size handed to decode_base64_image(): the size already fixed on gen_params, or
    // 0 while it is still unset. Re-evaluated at every decode so that later images (mask,
    // extra_images) are matched against the size set by earlier ones, e.g. by init_images[0].
    auto target_width = [&request]() {
        return request.gen_params.width_and_height_are_set() ? request.gen_params.width : 0;
    };
    auto target_height = [&request]() {
        return request.gen_params.width_and_height_are_set() ? request.gen_params.height : 0;
    };

    if (img2img) {
        if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
            const json& init_image_b64 = j["init_images"][0];
            if (!init_image_b64.is_string()) {
                error_message = "init_images[0] must be a base64 string";
                return false;
            }
            if (decode_base64_image(init_image_b64.get<std::string>(),
                                    3,
                                    target_width(),
                                    target_height(),
                                    request.gen_params.init_image)) {
                const sd_image_t& image = request.gen_params.init_image.get();
                request.gen_params.set_width_and_height_if_unset(image.width, image.height);
            }
        }
        if (j.contains("mask") && j["mask"].is_string()) {
            if (decode_base64_image(j["mask"].get<std::string>(),
                                    1,
                                    target_width(),
                                    target_height(),
                                    request.gen_params.mask_image)) {
                const sd_image_t& image = request.gen_params.mask_image.get();
                request.gen_params.set_width_and_height_if_unset(image.width, image.height);
            }
            sd_image_t& mask_image      = request.gen_params.mask_image.get();
            bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
            if (inpainting_mask_invert && mask_image.data != nullptr) {
                const size_t pixel_count = static_cast<size_t>(mask_image.width) * mask_image.height;
                for (size_t i = 0; i < pixel_count; ++i) {
                    mask_image.data[i] = 255 - mask_image.data[i];
                }
            }
        } else {
            // No mask supplied: use an all-255 mask at the resolved output size.
            const int resolved_width  = request.gen_params.get_resolved_width();
            const int resolved_height = request.gen_params.get_resolved_height();
            assign_solid_mask(request.gen_params.mask_image, resolved_width, resolved_height);
        }
        float denoising_strength = j.value("denoising_strength", -1.f);
        if (denoising_strength >= 0.f) {
            request.gen_params.strength = std::min(denoising_strength, 1.0f);
        }
    }

    if (j.contains("extra_images") && j["extra_images"].is_array()) {
        for (const auto& extra_image : j["extra_images"]) {
            if (!extra_image.is_string()) {
                continue;
            }
            SDImageOwner image_owner;
            if (decode_base64_image(extra_image.get<std::string>(),
                                    3,
                                    target_width(),
                                    target_height(),
                                    image_owner)) {
                const sd_image_t& image = image_owner.get();
                request.gen_params.set_width_and_height_if_unset(image.width, image.height);
                request.gen_params.ref_images.push_back(std::move(image_owner));
            }
        }
    }

    // Intentionally disable prompt-embedded LoRA tag parsing for server APIs.
    if (!request.gen_params.resolve_and_validate(IMG_GEN, "", true)) {
        error_message = "invalid params";
        return false;
    }
    return true;
}
// Registers the AUTOMATIC1111-compatible /sdapi/v1 routes on `svr`:
//   POST /sdapi/v1/txt2img, /sdapi/v1/img2img
//        Run one generation (serialized by runtime.sd_ctx_mutex) and return base64 PNGs.
//   GET  /sdapi/v1/loras, /samplers, /schedulers, /sd-models, /options
//        Discovery endpoints.
// The handlers capture a raw pointer to `rt`, so `rt` must outlive `svr`.
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
    ServerRuntime* runtime = &rt;

    auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) {
        try {
            if (req.body.empty()) {
                res.status = 400;
                res.set_content(R"({"error":"empty body"})", "application/json");
                return;
            }

            json j;
            ImgGenJobRequest request;
            std::string error_message;
            // Malformed JSON, or request fields of the wrong JSON type, are client errors:
            // answer 400 rather than letting the json exception reach the generic 500 handler.
            try {
                j = json::parse(req.body);
                if (!build_sdapi_img_gen_request(j, *runtime, img2img, request, error_message)) {
                    res.status = 400;
                    res.set_content(json({{"error", error_message}}).dump(), "application/json");
                    return;
                }
            } catch (const json::exception& e) {
                res.status = 400;
                json err;
                err["error"]   = "invalid_request";
                err["message"] = e.what();
                res.set_content(err.dump(), "application/json");
                return;
            }

            LOG_DEBUG("%s\n", request.gen_params.to_string().c_str());

            sd_img_gen_params_t img_gen_params = request.to_sd_img_gen_params_t();
            SDImageVec results;
            int num_results = 0;
            {
                // The shared sd_ctx handles one generation at a time.
                std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
                sd_image_t* raw_results = generate_image(runtime->sd_ctx, &img_gen_params);
                num_results             = request.gen_params.batch_count;
                results.adopt(raw_results, num_results);
            }

            if (results.empty()) {
                res.status = 500;
                res.set_content(R"({"error":"generate_image returned no results"})", "application/json");
                return;
            }

            json out;
            out["images"]     = json::array();
            out["parameters"] = j;
            out["info"]       = "";

            for (int i = 0; i < num_results; ++i) {
                if (results[i].data == nullptr) {
                    continue;
                }
                std::string params = request.gen_params.embed_image_metadata
                                         ? get_image_params(*runtime->ctx_params,
                                                            request.gen_params,
                                                            request.gen_params.seed + i)
                                         : "";
                auto image_bytes = encode_image_to_vector(EncodedImageFormat::PNG,
                                                          results[i].data,
                                                          results[i].width,
                                                          results[i].height,
                                                          results[i].channel,
                                                          params);
                if (image_bytes.empty()) {
                    LOG_ERROR("write image to mem failed");
                    continue;
                }
                out["images"].push_back(base64_encode(image_bytes));
            }

            res.set_content(out.dump(), "application/json");
            res.status = 200;
        } catch (const std::exception& e) {
            res.status = 500;
            json err;
            err["error"]   = "server_error";
            err["message"] = e.what();
            res.set_content(err.dump(), "application/json");
        }
    };

    svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
        sdapi_any2img(req, res, false);
    });

    svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
        sdapi_any2img(req, res, true);
    });

    svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) {
        refresh_lora_cache(*runtime);
        json result = json::array();
        {
            std::lock_guard<std::mutex> lock(*runtime->lora_mutex);
            for (const auto& e : *runtime->lora_cache) {
                json item;
                item["name"] = e.name;
                item["path"] = e.path;
                result.push_back(item);
            }
        }
        res.set_content(result.dump(), "application/json");
    });

    svr.Get("/sdapi/v1/samplers", [](const httplib::Request&, httplib::Response& res) {
        std::vector<std::string> sampler_names;
        sampler_names.push_back("default");
        for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
            sampler_names.push_back(sd_sample_method_name((sample_method_t)i));
        }
        json r = json::array();
        for (const auto& name : sampler_names) {
            json entry;
            entry["name"]    = name;
            entry["aliases"] = json::array({name});
            entry["options"] = json::object();
            r.push_back(entry);
        }
        res.set_content(r.dump(), "application/json");
    });

    svr.Get("/sdapi/v1/schedulers", [](const httplib::Request&, httplib::Response& res) {
        std::vector<std::string> scheduler_names;
        scheduler_names.push_back("default");
        for (int i = 0; i < SCHEDULER_COUNT; i++) {
            scheduler_names.push_back(sd_scheduler_name((scheduler_t)i));
        }
        json r = json::array();
        for (const auto& name : scheduler_names) {
            json entry;
            entry["name"]  = name;
            entry["label"] = name;
            r.push_back(entry);
        }
        res.set_content(r.dump(), "application/json");
    });

    // Reports the single loaded model. hash/sha256 are fixed placeholder values, not real
    // digests of the file.
    svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) {
        fs::path model_path = resolve_display_model_path(*runtime);
        json entry;
        entry["title"]      = model_path.stem();
        entry["model_name"] = model_path.stem();
        entry["filename"]   = model_path.filename();
        entry["hash"]       = "8888888888";
        entry["sha256"]     = "8888888888888888888888888888888888888888888888888888888888888888";
        entry["config"]     = nullptr;
        json r              = json::array();
        r.push_back(entry);
        res.set_content(r.dump(), "application/json");
    });

    svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) {
        fs::path model_path = resolve_display_model_path(*runtime);
        json r;
        r["samples_format"]      = "png";
        r["sd_model_checkpoint"] = model_path.stem();
        res.set_content(r.dump(), "application/json");
    });
}