feat: add flux2 small decoder support (#1402)

This commit is contained in:
leejet 2026-04-08 23:13:25 +08:00 committed by GitHub
parent dd753729cc
commit e8323cabb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 6 deletions

View File

@ -8,6 +8,8 @@
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
@ -31,6 +33,8 @@
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
- Download FLUX.2-small-decoder (full_encoder_small_decoder.safetensors) as an alternative VAE option
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-small-decoder/tree/main
- Download Qwen3 4B
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main

View File

@ -501,11 +501,36 @@ protected:
bool double_z = true;
} dd_config;
// Join a tensor name onto an optional prefix with a "." separator.
// An empty prefix yields the bare name unchanged.
static std::string get_tensor_name(const std::string& prefix, const std::string& name) {
    if (prefix.empty()) {
        return name;
    }
    return prefix + "." + name;
}
// Infer the decoder's base channel count ("ch") from the checkpoint's
// decoder.conv_in weight tensor, when present.
//
// conv_in projects the latent up to ch * ch_mult.back() output channels,
// so dividing the weight's output-channel dimension (ne[3]) by the last
// channel multiplier recovers the base ch. On any failure (tensor missing,
// too few dims, empty ch_mult, non-divisible channel count) decoder_ch is
// left untouched so the caller's configured default stays in effect.
//
// @param tensor_storage_map  name -> tensor storage metadata of the checkpoint
// @param prefix              tensor name prefix of this VAE (may be empty)
// @param decoder_ch          in/out: overwritten only on successful inference
void detect_decoder_ch(const String2TensorStorage& tensor_storage_map,
                       const std::string& prefix,
                       int& decoder_ch) {
    const std::string conv_in_name = get_tensor_name(prefix, "decoder.conv_in.weight");
    auto conv_in_iter = tensor_storage_map.find(conv_in_name);
    if (conv_in_iter == tensor_storage_map.end() || conv_in_iter->second.n_dims < 4 || conv_in_iter->second.ne[3] <= 0) {
        return;
    }
    // Guard before back(): calling back() on an empty vector is undefined behavior.
    if (dd_config.ch_mult.empty()) {
        LOG_WARN("vae decoder: ch_mult is empty, cannot infer ch from %s", conv_in_name.c_str());
        return;
    }
    int last_ch_mult = dd_config.ch_mult.back();
    int64_t conv_in_out_channels = conv_in_iter->second.ne[3];
    if (last_ch_mult > 0 && conv_in_out_channels % last_ch_mult == 0) {
        decoder_ch = static_cast<int>(conv_in_out_channels / last_ch_mult);
        LOG_INFO("vae decoder: ch = %d", decoder_ch);
    } else {
        LOG_WARN("vae decoder: failed to infer ch from %s (%" PRId64 " / %d)",
                 conv_in_name.c_str(),
                 conv_in_out_channels,
                 last_ch_mult);
    }
}
public:
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false)
AutoEncoderKLModel(SDVersion version = VERSION_SD1,
bool decode_only = true,
bool use_linear_projection = false,
bool use_video_decoder = false,
const String2TensorStorage& tensor_storage_map = {},
const std::string& prefix = "")
: version(version), decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (sd_version_is_dit(version)) {
if (sd_version_is_flux2(version)) {
@ -519,7 +544,9 @@ public:
if (use_video_decoder) {
use_quant = false;
}
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(dd_config.ch,
int decoder_ch = dd_config.ch;
detect_decoder_ch(tensor_storage_map, prefix, decoder_ch);
blocks["decoder"] = std::shared_ptr<GGMLBlock>(new Decoder(decoder_ch,
dd_config.out_ch,
dd_config.ch_mult,
dd_config.num_res_blocks,
@ -662,7 +689,7 @@ struct AutoEncoderKL : public VAE {
break;
}
}
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder);
ae = AutoEncoderKLModel(version, decode_only, use_linear_projection, use_video_decoder, tensor_storage_map, prefix);
ae.init(params_ctx, tensor_storage_map, prefix);
}