stable-diffusion.cpp/examples/server/async_jobs.cpp
2026-04-11 17:49:00 +08:00

276 lines
9.1 KiB
C++

// Extracted from main.cpp during server refactor.
#include "async_jobs.h"
#include <algorithm>
#include <exception>
#include <iomanip>
#include <sstream>
#include "common/log.h"
#include "common/media_io.h"
#include "common/resource_owners.hpp"
// Returns the wire name for a job kind.
// VidGen maps to "vid_gen"; ImgGen and any other value map to "img_gen".
const char* async_job_kind_name(AsyncJobKind kind) {
    if (kind == AsyncJobKind::VidGen) {
        return "vid_gen";
    }
    // ImgGen and the fallback case share the same name.
    return "img_gen";
}
// Returns the wire name for a job status.
// Failed and any value not listed below map to "failed".
const char* async_job_status_name(AsyncJobStatus status) {
    if (status == AsyncJobStatus::Queued) {
        return "queued";
    }
    if (status == AsyncJobStatus::Generating) {
        return "generating";
    }
    if (status == AsyncJobStatus::Completed) {
        return "completed";
    }
    if (status == AsyncJobStatus::Cancelled) {
        return "cancelled";
    }
    // AsyncJobStatus::Failed and the fallback case.
    return "failed";
}
// Drops stale entries from the manager, based on the current unix timestamp.
//
// Pass 1: erases every expired_jobs entry whose stored deadline is <= now.
// Pass 2: erases every finished job (completed_at != 0) whose age has reached
//         its TTL (completed_ttl_seconds for Completed jobs,
//         failed_ttl_seconds for every other finished status) and records its
//         id in expired_jobs with a deadline of now + max(ttl, 60).
//
// This function takes no lock itself; async_job_worker calls it while holding
// manager.mutex.
void purge_expired_jobs(AsyncJobManager& manager) {
    const int64_t now = unix_timestamp_now();

    auto tombstone = manager.expired_jobs.begin();
    while (tombstone != manager.expired_jobs.end()) {
        if (tombstone->second > now) {
            ++tombstone;
        } else {
            tombstone = manager.expired_jobs.erase(tombstone);
        }
    }

    auto entry = manager.jobs.begin();
    while (entry != manager.jobs.end()) {
        const auto& job = entry->second;

        // Jobs that have not finished yet are never purged.
        if (job->completed_at == 0) {
            ++entry;
            continue;
        }

        int64_t ttl_seconds = manager.failed_ttl_seconds;
        if (job->status == AsyncJobStatus::Completed) {
            ttl_seconds = manager.completed_ttl_seconds;
        }

        const auto age = now - job->completed_at;
        if (age < ttl_seconds) {
            ++entry;
            continue;
        }

        // Record the id before erasing: `job` refers into the map entry.
        manager.expired_jobs[job->id] = now + std::max<int64_t>(ttl_seconds, 60);
        entry = manager.jobs.erase(entry);
    }
}
// Counts the jobs in manager.jobs whose status is Queued or Generating.
// Takes no lock itself (presumably the caller holds manager.mutex; confirm).
size_t count_pending_jobs(const AsyncJobManager& manager) {
    size_t active = 0;
    for (auto it = manager.jobs.begin(); it != manager.jobs.end(); ++it) {
        switch (it->second->status) {
            case AsyncJobStatus::Queued:
            case AsyncJobStatus::Generating:
                ++active;
                break;
            default:
                break;
        }
    }
    return active;
}
// Builds a new job id of the form "job_<timestamp>_<counter>" and
// post-increments manager.next_id.
//
// std::hex stays active on the stream once set, so both the timestamp and
// the counter are formatted with it. The counter is padded with '0' to a
// minimum width of 8.
std::string make_async_job_id(AsyncJobManager& manager) {
    std::ostringstream id_stream;
    id_stream << "job_";
    id_stream << std::hex << unix_timestamp_now();
    id_stream << "_";
    id_stream << std::setw(8) << std::setfill('0') << manager.next_id++;
    return id_stream.str();
}
// Cancels a job that is still waiting in manager.queue.
//
// Every occurrence of job.id is removed from the queue. If the id was not
// queued, returns false and leaves `job` untouched. Otherwise the job is
// marked Cancelled, completed_at is stamped with the current time, any result
// images are cleared, the error fields are filled in, and true is returned.
//
// Takes no lock itself (presumably the caller holds manager.mutex; confirm).
bool cancel_queued_job(AsyncJobManager& manager, AsyncGenerationJob& job) {
    auto& pending = manager.queue;
    const auto size_before = pending.size();

    pending.erase(std::remove(pending.begin(), pending.end(), job.id), pending.end());

    // Unchanged size means the id was not in the queue.
    if (pending.size() == size_before) {
        return false;
    }

    job.status        = AsyncJobStatus::Cancelled;
    job.completed_at  = unix_timestamp_now();
    job.error_code    = "cancelled";
    job.error_message = "job cancelled by client";
    job.result_images_b64.clear();
    return true;
}
// Serialises a job into its JSON status document.
//
// Fields written, in this order: id, kind, status, created, started,
// completed, queue_position, result, error.
//   - started / completed are null while the matching timestamp is 0.
//   - queue_position is the 1-based index of the job in manager.queue when
//     the job is Queued and found there; otherwise it is 0.
//   - Completed: result carries output_format and the base64 images, error is
//     null.
//   - Failed / Cancelled: result is null, error carries code and message. An
//     empty error_code falls back to "cancelled" or "generation_failed".
//   - Any other status: result and error are both null.
json make_async_job_json(const AsyncJobManager& manager, const AsyncGenerationJob& job) {
    json out;
    out["id"]      = job.id;
    out["kind"]    = async_job_kind_name(job.kind);
    out["status"]  = async_job_status_name(job.status);
    out["created"] = job.created_at;

    if (job.started_at == 0) {
        out["started"] = nullptr;
    } else {
        out["started"] = job.started_at;
    }

    if (job.completed_at == 0) {
        out["completed"] = nullptr;
    } else {
        out["completed"] = job.completed_at;
    }

    out["queue_position"] = 0;
    if (job.status == AsyncJobStatus::Queued) {
        size_t rank = 0;
        for (auto it = manager.queue.begin(); it != manager.queue.end(); ++it) {
            ++rank;
            if (*it == job.id) {
                out["queue_position"] = rank;
                break;
            }
        }
    }

    const bool is_completed = job.status == AsyncJobStatus::Completed;
    const bool is_cancelled = job.status == AsyncJobStatus::Cancelled;
    const bool is_failed    = job.status == AsyncJobStatus::Failed;

    if (is_completed) {
        json images = json::array();
        for (size_t i = 0; i < job.result_images_b64.size(); ++i) {
            json image_entry;
            image_entry["index"]    = i;
            image_entry["b64_json"] = job.result_images_b64[i];
            images.push_back(image_entry);
        }

        json payload;
        payload["output_format"] = job.img_gen.output_format;
        payload["images"]        = images;

        out["result"] = payload;
        out["error"]  = nullptr;
        return out;
    }

    if (is_failed || is_cancelled) {
        std::string code = job.error_code;
        if (code.empty()) {
            code = is_cancelled ? "cancelled" : "generation_failed";
        }

        json error;
        error["code"]    = code;
        error["message"] = job.error_message;

        out["result"] = nullptr;
        out["error"]  = error;
        return out;
    }

    // Queued or Generating: nothing to report yet.
    out["result"] = nullptr;
    out["error"]  = nullptr;
    return out;
}
// Runs one image-generation job and appends the encoded results to
// output_images.
//
// generate_image is called while holding *runtime.sd_ctx_mutex. The number of
// results is taken from params.batch_count after the call. Each result with
// non-null data is encoded (PNG unless output_format is "jpeg" or "webp") and
// passed through base64_encode. Results with null data, or whose encoding
// comes back empty, are skipped.
//
// Returns true if at least one image ended up in output_images. On failure
// returns false and sets error_message. output_images is appended to, never
// cleared here.
bool execute_img_gen_job(ServerRuntime& runtime,
                         AsyncGenerationJob& job,
                         std::vector<std::string>& output_images,
                         std::string& error_message) {
    sd_img_gen_params_t params = job.img_gen.to_sd_img_gen_params_t();

    SDImageVec generated;
    int image_count = 0;
    {
        std::lock_guard<std::mutex> ctx_guard(*runtime.sd_ctx_mutex);
        sd_image_t* raw_images = generate_image(runtime.sd_ctx, &params);
        image_count            = params.batch_count;
        generated.adopt(raw_images, image_count);
    }

    const bool have_images = !generated.empty() && image_count > 0;
    if (!have_images) {
        error_message = "generate_image returned no results";
        return false;
    }

    const auto& requested_format = job.img_gen.output_format;
    const EncodedImageFormat encoded_format =
        requested_format == "jpeg"   ? EncodedImageFormat::JPEG
        : requested_format == "webp" ? EncodedImageFormat::WEBP
                                     : EncodedImageFormat::PNG;

    for (int i = 0; i < image_count; ++i) {
        auto&& frame = generated[i];
        if (frame.data == nullptr) {
            continue;
        }

        // Metadata stays empty unless embedding was requested; the i-th image
        // is described with seed + i.
        std::string metadata;
        if (job.img_gen.gen_params.embed_image_metadata) {
            metadata = get_image_params(*runtime.ctx_params,
                                        job.img_gen.gen_params,
                                        job.img_gen.gen_params.seed + i);
        }

        auto encoded = encode_image_to_vector(encoded_format,
                                              frame.data,
                                              frame.width,
                                              frame.height,
                                              frame.channel,
                                              metadata,
                                              job.img_gen.output_compression);
        if (!encoded.empty()) {
            output_images.push_back(base64_encode(encoded));
        }
    }

    if (output_images.empty()) {
        error_message = "generate_image returned empty encoded outputs";
        return false;
    }
    return true;
}
// Worker loop that drains manager.queue one job at a time.
//
// Each iteration:
//   1. Waits on manager.cv (under manager.mutex) until stop is set or the
//      queue is non-empty. The loop exits only when stop is set AND the queue
//      is empty, so jobs already queued are still processed after stop.
//   2. Purges expired jobs, pops the front id, and marks the matching job as
//      Generating with started_at stamped. Ids with no entry in manager.jobs
//      are skipped.
//   3. Runs the job with manager.mutex released.
//   4. Re-takes manager.mutex and records the outcome (Completed with images,
//      or Failed with an error message), then purges again.
//
// Exceptions thrown while running a job are caught here and turned into a
// Failed job. Without that, an exception would unwind out of this loop, leave
// the job stuck in Generating, and terminate the process if this function is
// the thread's entry point.
void async_job_worker(ServerRuntime& runtime) {
    AsyncJobManager& manager = *runtime.async_job_manager;
    while (true) {
        std::shared_ptr<AsyncGenerationJob> job;
        {
            std::unique_lock<std::mutex> lock(manager.mutex);
            manager.cv.wait(lock, [&]() { return manager.stop || !manager.queue.empty(); });
            if (manager.stop && manager.queue.empty()) {
                break;
            }
            purge_expired_jobs(manager);
            if (manager.queue.empty()) {
                continue;
            }
            const std::string job_id = manager.queue.front();
            manager.queue.pop_front();
            auto it = manager.jobs.find(job_id);
            if (it == manager.jobs.end()) {
                // Queue entry without a registered job: nothing to run.
                continue;
            }
            job             = it->second;
            job->status     = AsyncJobStatus::Generating;
            job->started_at = unix_timestamp_now();
        }

        std::vector<std::string> output_images;
        std::string error_message;
        bool ok = false;
        try {
            if (job->kind == AsyncJobKind::ImgGen) {
                ok = execute_img_gen_job(runtime, *job, output_images, error_message);
            } else {
                error_message = "unsupported job kind";
            }
        } catch (const std::exception& e) {
            // Keep the worker alive and report the failure on the job.
            ok            = false;
            error_message = e.what();
        } catch (...) {
            ok            = false;
            error_message = "unknown exception during generation";
        }

        {
            std::lock_guard<std::mutex> lock(manager.mutex);
            auto it = manager.jobs.find(job->id);
            if (it == manager.jobs.end()) {
                // The job is no longer registered; drop the result.
                continue;
            }
            job->completed_at = unix_timestamp_now();
            if (ok) {
                job->status            = AsyncJobStatus::Completed;
                job->result_images_b64 = std::move(output_images);
                job->error_code.clear();
                job->error_message.clear();
            } else {
                job->status        = AsyncJobStatus::Failed;
                job->error_code    = "generation_failed";
                job->error_message = error_message.empty() ? "unknown generation error" : error_message;
                job->result_images_b64.clear();
            }
            purge_expired_jobs(manager);
        }
    }
}