fix(vae): Fix dtype mismatch in FP32 VAE decode mode

The previous mixed-precision optimization for FP32 mode only converted
some VAE decoder layers (post_quant_conv, conv_in, mid_block) to the
latents dtype while leaving others (up_blocks, conv_norm_out) in float32.
This caused "expected scalar type Half but found Float" errors after
recent diffusers updates.
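
A minimal standalone repro of the failure mode (illustrative only, not
InvokeAI code): pushing a half-precision tensor through a layer whose
weights were left in float32 raises the dtype-mismatch RuntimeError.

import torch

conv_fp32 = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)   # weights stay in float32
latents_fp16 = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # latents arrive in half
try:
    conv_fp32(latents_fp16)  # input dtype != weight dtype
except RuntimeError as err:
    print(err)  # dtype-mismatch error of the kind described above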

Simplify FP32 mode to consistently use float32 for both VAE and latents,
removing the incomplete mixed-precision logic. This trades some VRAM
usage for stability and correctness.
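
The cost can be estimated with a rough sketch (the helper below is
hypothetical): keeping the whole decoder in float32 roughly doubles its
parameter memory relative to float16, and decode-time activations grow
accordingly.

import torch

def module_param_bytes(module: torch.nn.Module) -> int:
    # Sum of parameter storage only; decode-time activations add to this.
    return sum(p.numel() * p.element_size() for p in module.parameters())

# e.g. compare module_param_bytes(vae.decoder) before and after vae.to(torch.float32)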

Also removes now-unused attention processor imports.
commit f417c269d1
parent 4ce0ef5260
Author: Alexander Eichhorn
Date:   2025-12-16 15:58:48 +01:00

@@ -2,12 +2,6 @@ from contextlib import nullcontext
 import torch
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
@@ -77,26 +71,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(TorchDevice.choose_torch_device())
             if self.fp32:
+                # FP32 mode: convert everything to float32 for maximum precision
                 vae.to(dtype=torch.float32)
-                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
-                    vae.decoder.mid_block.attentions[0].processor,
-                    (
-                        AttnProcessor2_0,
-                        XFormersAttnProcessor,
-                        LoRAXFormersAttnProcessor,
-                        LoRAAttnProcessor2_0,
-                    ),
-                )
-                # if xformers or torch_2_0 is used attention block does not need
-                # to be in float32 which can save lots of memory
-                if use_torch_2_0_or_xformers:
-                    vae.post_quant_conv.to(latents.dtype)
-                    vae.decoder.conv_in.to(latents.dtype)
-                    vae.decoder.mid_block.to(latents.dtype)
-                else:
-                    latents = latents.float()
+                latents = latents.float()
             else:
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
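
For reference, a minimal standalone sketch of how the decode path behaves
after this change (the model id and surrounding loading code are
illustrative, not the invocation's actual logic):

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
latents = torch.randn(1, 4, 64, 64)

fp32 = True
if fp32:
    # One dtype everywhere: no per-layer casting of post_quant_conv / conv_in / mid_block.
    vae.to(dtype=torch.float32)
    latents = latents.float()
else:
    vae.to(dtype=torch.float16)
    latents = latents.half()

with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]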