Update the model loading logic for several of the large FLUX-related models to ensure that the model is initialized on the meta device prior to loading the state dict into it. This helps to keep peak memory down.
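The pattern applied in the checkpoint-loader diffs below is the same in each case: construct the module inside accelerate.init_empty_weights() so that its parameters are created on the meta device (no real storage is allocated for the throwaway random init), then attach the checkpoint tensors with load_state_dict(sd, assign=True) so they become the parameters directly instead of being copied into freshly allocated ones. (The T5 encoder path gets the equivalent behavior from transformers' from_pretrained via low_cpu_mem_usage=True.) A minimal, self-contained sketch of the pattern, where TinyModel is a placeholder and not InvokeAI code:

import accelerate
import torch


class TinyModel(torch.nn.Module):
    """Placeholder stand-in for the real AutoEncoder / Flux modules."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(1024, 1024)


# Pretend this state dict was read from disk (e.g. with safetensors' load_file).
checkpoint_sd = TinyModel().state_dict()

# 1. Build the module on the meta device: no host memory is spent on the
#    randomly-initialized weights that would be overwritten anyway.
with accelerate.init_empty_weights():
    model = TinyModel()
assert model.proj.weight.is_meta

# 2. assign=True makes the loaded tensors *become* the parameters rather than
#    being copied into newly allocated storage, so peak memory stays close to
#    a single copy of the weights.
model.load_state_dict(checkpoint_sd, assign=True)
assert not model.proj.weight.is_meta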

Ryan Dick 2025-01-16 02:30:28 +00:00
parent b57aa06d9e
commit b2bb359d47


@@ -80,19 +80,19 @@ class FluxVAELoader(ModelLoader):
             raise ValueError("Only VAECheckpointConfig models are currently supported here.")
         model_path = Path(config.path)
         with SilenceWarnings():
-            model = AutoEncoder(ae_params[config.config_path])
-            sd = load_file(model_path)
-            model.load_state_dict(sd, assign=True)
-            # VAE is broken in float16, which mps defaults to
-            if self._torch_dtype == torch.float16:
-                try:
-                    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
-                except TypeError:
-                    vae_dtype = torch.float32
-            else:
-                vae_dtype = self._torch_dtype
-            model.to(vae_dtype)
+            with accelerate.init_empty_weights():
+                model = AutoEncoder(ae_params[config.config_path])
+            sd = load_file(model_path)
+            model.load_state_dict(sd, assign=True)
+            # VAE is broken in float16, which mps defaults to
+            if self._torch_dtype == torch.float16:
+                try:
+                    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
+                except TypeError:
+                    vae_dtype = torch.float32
+            else:
+                vae_dtype = self._torch_dtype
+            model.to(vae_dtype)
         return model
@@ -183,7 +183,9 @@ class T5EncoderCheckpointModel(ModelLoader):
             case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
+                return T5EncoderModel.from_pretrained(
+                    Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
+                )
         raise ValueError(
             f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -217,17 +219,18 @@ class FluxCheckpointModel(ModelLoader):
         assert isinstance(config, MainCheckpointConfig)
         model_path = Path(config.path)
         with SilenceWarnings():
-            model = Flux(params[config.config_path])
-            sd = load_file(model_path)
-            if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
-                sd = convert_bundle_to_flux_transformer_checkpoint(sd)
-            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
-            self._ram_cache.make_room(new_sd_size)
-            for k in sd.keys():
-                # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
-                sd[k] = sd[k].to(torch.bfloat16)
-            model.load_state_dict(sd, assign=True)
+            with accelerate.init_empty_weights():
+                model = Flux(params[config.config_path])
+            sd = load_file(model_path)
+            if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+                sd = convert_bundle_to_flux_transformer_checkpoint(sd)
+            new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()])
+            self._ram_cache.make_room(new_sd_size)
+            for k in sd.keys():
+                # We need to cast to bfloat16 due to it being the only currently supported dtype for inference
+                sd[k] = sd[k].to(torch.bfloat16)
+            model.load_state_dict(sd, assign=True)
         return model
@@ -258,11 +261,11 @@ class FluxGGUFCheckpointModel(ModelLoader):
         assert isinstance(config, MainGGUFCheckpointConfig)
         model_path = Path(config.path)
         with SilenceWarnings():
-            model = Flux(params[config.config_path])
-            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
-            sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
+            with accelerate.init_empty_weights():
+                model = Flux(params[config.config_path])
+            # HACK(ryand): We shouldn't be hard-coding the compute_dtype here.
+            sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
             # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight.
             # We override the shape here to fix the issue.