fix(vae): Fix dtype mismatch in FP32 VAE decode mode
The previous mixed-precision optimization for FP32 mode only converted some VAE decoder layers (post_quant_conv, conv_in, mid_block) to the latents dtype while leaving others (up_blocks, conv_norm_out) in float32. This caused "expected scalar type Half but found Float" errors after recent diffusers updates.

Simplify FP32 mode to consistently use float32 for both the VAE and the latents, removing the incomplete mixed-precision logic. This trades some VRAM usage for stability and correctness.

Also removes the now-unused attention processor imports.
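The failure mode is easy to reproduce outside the VAE. A minimal sketch, assuming only that some layers are converted to half precision while others stay in float32 (the two nn.Linear stages are hypothetical stand-ins for the decoder submodules, not InvokeAI code):

import torch
import torch.nn as nn

# Stand-ins for the decoder: stage0 plays the role of the layers the old
# code converted (post_quant_conv, conv_in, mid_block); stage1 plays the
# role of the layers it left in float32 (up_blocks, conv_norm_out).
stage0 = nn.Linear(4, 4)
stage1 = nn.Linear(4, 4)

stage0.half()  # partial conversion, as in the old mixed-precision path

latents = torch.randn(1, 4, dtype=torch.float16)
hidden = stage0(latents)  # fine: half weights, half input
stage1(hidden)            # RuntimeError: dtype mismatch between the float16
                          # activations and the float32 weights (the exact
                          # message wording varies by PyTorch version)

As soon as a half-precision activation reaches a float32 layer, the forward pass raises, which is exactly what the partially converted decoder did once diffusers stopped tolerating the mix.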
parent
4ce0ef5260
commit
f417c269d1
@@ -2,12 +2,6 @@ from contextlib import nullcontext
 
 import torch
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 
@@ -77,26 +71,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
         latents = latents.to(TorchDevice.choose_torch_device())
         if self.fp32:
+            # FP32 mode: convert everything to float32 for maximum precision
             vae.to(dtype=torch.float32)
-
-            use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
-                vae.decoder.mid_block.attentions[0].processor,
-                (
-                    AttnProcessor2_0,
-                    XFormersAttnProcessor,
-                    LoRAXFormersAttnProcessor,
-                    LoRAAttnProcessor2_0,
-                ),
-            )
-            # if xformers or torch_2_0 is used attention block does not need
-            # to be in float32 which can save lots of memory
-            if use_torch_2_0_or_xformers:
-                vae.post_quant_conv.to(latents.dtype)
-                vae.decoder.conv_in.to(latents.dtype)
-                vae.decoder.mid_block.to(latents.dtype)
-            else:
-                latents = latents.float()
-
+            latents = latents.float()
         else:
             vae.to(dtype=torch.float16)
             latents = latents.half()
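After the change, the node's dtype handling collapses to one consistent choice per branch. A minimal sketch of the resulting decode path (the decode_latents helper is hypothetical and only illustrates the post-fix logic; vae.decode(..., return_dict=False) is the standard diffusers call):

import torch

def decode_latents(vae, latents, fp32: bool):
    # Hypothetical helper mirroring the post-fix logic: the VAE weights and
    # the latents always share a dtype, so no layer sees a mismatched input.
    if fp32:
        # FP32 mode: everything in float32 -- more VRAM, maximum precision.
        vae.to(dtype=torch.float32)
        latents = latents.float()
    else:
        vae.to(dtype=torch.float16)
        latents = latents.half()
    with torch.no_grad():
        return vae.decode(latents, return_dict=False)[0]

Keeping the whole model in one dtype forgoes the old memory saving from running the attention blocks in half precision, but it removes the only place where a Half/Float boundary could appear inside the decoder.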