From 2d86298b7fd60051fbc405e5fb4644bdf18bc8c2 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 24 Oct 2024 01:19:40 +0000
Subject: [PATCH] Add first draft of Sd3TextEncoderInvocation.

---
 invokeai/app/invocations/sd3_text_encoder.py | 206 ++++++++++++++++++
 .../diffusion/conditioning_data.py           |  25 ++-
 2 files changed, 230 insertions(+), 1 deletion(-)
 create mode 100644 invokeai/app/invocations/sd3_text_encoder.py

diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py
new file mode 100644
index 0000000000..bd989e3b9a
--- /dev/null
+++ b/invokeai/app/invocations/sd3_text_encoder.py
@@ -0,0 +1,206 @@
+from contextlib import ExitStack
+from typing import Iterator, Tuple
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
+from invokeai.app.invocations.primitives import FluxConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
+from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
+from invokeai.backend.lora.lora_patcher import LoRAPatcher
+from invokeai.backend.model_manager.config import ModelFormat
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
+
+
+@invocation(
+    "sd3_text_encoder",
+    title="SD3 Text Encoding",
+    tags=["prompt", "conditioning", "sd3"],
+    category="conditioning",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Sd3TextEncoderInvocation(BaseInvocation):
+    """Encodes and preps a prompt for an SD3 image."""
+
+    clip_l: CLIPField = InputField(
+        title="CLIP L",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    clip_g: CLIPField = InputField(
+        title="CLIP G",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+
+    # The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
+    t5_encoder: T5EncoderField | None = InputField(
+        title="T5Encoder",
+        description=FieldDescriptions.t5_encoder,
+        input=Input.Connection,
+    )
+    prompt: str = InputField(description="Text prompt to encode.")
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
+        # Note: The text encoding models are run in separate functions to ensure that all model references are locally
+        # scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
+
+        clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
+        clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
+
+        t5_max_seq_len = 256
+        t5_embeddings: torch.Tensor | None = None
+        if self.t5_encoder is not None:
+            t5_embeddings = self._t5_encode(context, t5_max_seq_len)
+
+        conditioning_data = ConditioningFieldData(
+            conditionings=[
+                SD3ConditioningInfo(
+                    clip_l_embeds=clip_l_embeddings,
+                    clip_l_pooled_embeds=clip_l_pooled_embeddings,
+                    clip_g_embeds=clip_g_embeddings,
+                    clip_g_pooled_embeds=clip_g_pooled_embeddings,
+                    t5_embeds=t5_embeddings,
+                )
+            ]
+        )
+
+        conditioning_name = context.conditioning.save(conditioning_data)
+        return FluxConditioningOutput.build(conditioning_name)
+
+    def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
+        assert self.t5_encoder is not None
+        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
+
+        prompt = [self.prompt]
+
+        with (
+            t5_text_encoder_info as t5_text_encoder,
+            t5_tokenizer_info as t5_tokenizer,
+        ):
+            assert isinstance(t5_text_encoder, T5EncoderModel)
+            assert isinstance(t5_tokenizer, T5Tokenizer)
+
+            text_inputs = t5_tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=max_seq_len,
+                truncation=True,
+                add_special_tokens=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            assert isinstance(text_input_ids, torch.Tensor)
+            assert isinstance(untruncated_ids, torch.Tensor)
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
+                context.logger.warning(
+                    "The following part of your input was truncated because `max_sequence_length` is set to"
+                    f" {max_seq_len} tokens: {removed_text}"
+                )
+
+            prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
+
+        assert isinstance(prompt_embeds, torch.Tensor)
+        return prompt_embeds
+
+    def _clip_encode(
+        self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        clip_tokenizer_info = context.models.load(clip_model.tokenizer)
+        clip_text_encoder_info = context.models.load(clip_model.text_encoder)
+
+        prompt = [self.prompt]
+
+        with (
+            clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
+            clip_tokenizer_info as clip_tokenizer,
+            ExitStack() as exit_stack,
+        ):
+            assert isinstance(clip_text_encoder, CLIPTextModel)
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+
+            clip_text_encoder_config = clip_text_encoder_info.config
+            assert clip_text_encoder_config is not None
+
+            # Apply LoRA models to the CLIP encoder.
+            # Note: We apply the LoRA after the text encoder has been moved to its target device for faster patching.
+            if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
+                # The model is non-quantized, so we can apply the LoRA weights directly into the model.
+                exit_stack.enter_context(
+                    LoRAPatcher.apply_lora_patches(
+                        model=clip_text_encoder,
+                        patches=self._clip_lora_iterator(context, clip_model),
+                        prefix=FLUX_LORA_CLIP_PREFIX,
+                        cached_weights=cached_weights,
+                    )
+                )
+            else:
+                # There are currently no supported CLIP quantized models. Add support here if needed.
+                raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
+
+            clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
+
+            batch_encoding = clip_tokenizer(
+                prompt,
+                truncation=True,
+                max_length=77,
+                return_length=False,
+                return_overflowing_tokens=False,
+                padding="max_length",
+                return_tensors="pt",
+            )
+            outputs = clip_text_encoder(
+                input_ids=batch_encoding["input_ids"].to(clip_text_encoder.device),
+                attention_mask=None,
+                output_hidden_states=False,
+            )
+            # TODO(ryand): Confirm that this is the correct output. ('last_hidden_state' is the default)
+            pooled_prompt_embeds = outputs["pooler_output"]
+
+            text_inputs = clip_tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            assert isinstance(text_input_ids, torch.Tensor)
+            assert isinstance(untruncated_ids, torch.Tensor)
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
+                context.logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {tokenizer_max_length} tokens: {removed_text}"
+                )
+            prompt_embeds = clip_text_encoder(
+                input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
+            )
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def _clip_lora_iterator(
+        self, context: InvocationContext, clip_model: CLIPField
+    ) -> Iterator[Tuple[LoRAModelRaw, float]]:
+        for lora in clip_model.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, LoRAModelRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index b7e9038cf7..184cdb9b02 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
         return self
 
 
+@dataclass
+class SD3ConditioningInfo:
+    clip_l_pooled_embeds: torch.Tensor
+    clip_l_embeds: torch.Tensor
+    clip_g_pooled_embeds: torch.Tensor
+    clip_g_embeds: torch.Tensor
+    t5_embeds: torch.Tensor | None
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
+        self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
+        self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
+        self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
+        if self.t5_embeds is not None:
+            self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
+        return self
+
+
 @dataclass
 class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
+    conditionings: (
+        List[BasicConditioningInfo]
+        | List[SDXLConditioningInfo]
+        | List[FLUXConditioningInfo]
+        | List[SD3ConditioningInfo]
+    )
 
 
 @dataclass
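
For context, a minimal sketch of how a downstream SD3 denoising node might combine the tensors stored in SD3ConditioningInfo, following the layout used by diffusers' StableDiffusion3Pipeline (CLIP-L and CLIP-G token embeddings concatenated on the feature dim, zero-padded to the T5 width, then concatenated with the T5 sequence; the pooled projection is the concatenation of the two pooled CLIP embeddings). The helper name concat_sd3_conditioning and the default t5_seq_len/t5_embed_dim values are illustrative assumptions, not code from this commit.

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo


def concat_sd3_conditioning(
    info: SD3ConditioningInfo, t5_seq_len: int = 256, t5_embed_dim: int = 4096
) -> tuple[torch.Tensor, torch.Tensor]:
    # Join the two CLIP token-embedding sequences along the feature dimension.
    clip_embeds = torch.cat([info.clip_l_embeds, info.clip_g_embeds], dim=-1)

    t5_embeds = info.t5_embeds
    if t5_embeds is None:
        # T5 was skipped (permitted above because SD3 was trained with text encoder dropout);
        # substitute zeros of the expected shape, mirroring the reference pipeline's behavior.
        t5_embeds = torch.zeros(
            clip_embeds.shape[0], t5_seq_len, t5_embed_dim, device=clip_embeds.device, dtype=clip_embeds.dtype
        )

    # Zero-pad the CLIP features up to the T5 width, then stack the two sequences along the token dimension.
    clip_embeds = torch.nn.functional.pad(clip_embeds, (0, t5_embeds.shape[-1] - clip_embeds.shape[-1]))
    prompt_embeds = torch.cat([clip_embeds, t5_embeds], dim=-2)

    # The pooled projection is the concatenation of the two pooled CLIP embeddings.
    pooled_prompt_embeds = torch.cat([info.clip_l_pooled_embeds, info.clip_g_pooled_embeds], dim=-1)
    return prompt_embeds, pooled_prompt_embeds

With the default values above, the joint sequence would have length 77 + 256 tokens and feature width 4096, matching the shapes produced by the encoder functions in this patch.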