From cdcfda164d784a435c19ac88dd0227b696057f39 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sun, 4 Jun 2023 15:30:54 +0200
Subject: [PATCH] enable long prompts, upgrade compel to enable .and()
 (concatenating prompts)

---
 invokeai/app/invocations/compel.py         | 54 +++++++++++++++-------
 invokeai/app/invocations/latent.py         | 10 ++++
 invokeai/backend/prompting/conditioning.py |  4 +-
 pyproject.toml                             |  2 +-
 4 files changed, 52 insertions(+), 18 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 076ce81021..58dc661baf 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
 
 from invokeai.app.invocations.util.choose_model import choose_model
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
+from ...backend.prompting.conditioning import try_parse_legacy_blend
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
 
@@ -13,7 +14,7 @@ from compel.prompt_parser import (
     Blend,
     CrossAttentionControlSubstitute,
     FlattenedPrompt,
-    Fragment,
+    Fragment, Conjunction,
 )
 
 
@@ -93,25 +94,22 @@ class CompelInvocation(BaseInvocation):
             text_encoder=text_encoder,
             textual_inversion_manager=pipeline.textual_inversion_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=True, # TODO:
+            truncate_long_prompts=False,
         )
 
-        # TODO: support legacy blend?
-
-        conjunction = Compel.parse_prompt_string(prompt_str)
-        prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]
+        legacy_blend = try_parse_legacy_blend(prompt_str, skip_normalize=False)
+        if legacy_blend is not None:
+            conjunction = legacy_blend
+        else:
+            conjunction = Compel.parse_prompt_string(prompt_str)
 
         if context.services.configuration.log_tokenization:
-            log_tokenization_for_prompt_object(prompt, tokenizer)
+            log_tokenization_for_conjunction(conjunction, tokenizer)
 
-        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
-
-        # TODO: long prompt support
-        #if not self.truncate_long_prompts:
-        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+        c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
 
         ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
+            tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
             cross_attention_control_args=options.get("cross_attention_control", None),
         )
 
@@ -128,14 +126,22 @@ class CompelInvocation(BaseInvocation):
 
 
 def get_max_token_count(
-    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
+    tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
         return max(
             [
-                get_max_token_count(tokenizer, c, truncate_if_too_long)
-                for c in blend.prompts
+                get_max_token_count(tokenizer, p, truncate_if_too_long)
+                for p in blend.prompts
+            ]
+        )
+    elif type(prompt) is Conjunction:
+        conjunction: Conjunction = prompt
+        return sum(
+            [
+                get_max_token_count(tokenizer, p, truncate_if_too_long)
+                for p in conjunction.prompts
             ]
         )
     else:
@@ -170,6 +176,22 @@ def get_tokens_for_prompt_object(
     return tokens
 
 
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer, display_label_prefix=None
+):
+    display_label_prefix = display_label_prefix or ""
+    for i, p in enumerate(c.prompts):
+        if len(c.prompts) > 1:
+            this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
+        else:
+            this_display_label_prefix = display_label_prefix
+        log_tokenization_for_prompt_object(
+            p,
+            tokenizer,
+            display_label_prefix=this_display_label_prefix
+        )
+
+
 def log_tokenization_for_prompt_object(
     p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
 ):
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 4dc1f6456c..ba65e214c3 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -4,6 +4,7 @@ import random
 
 import einops
 from typing import Literal, Optional, Union, List
 
+from compel import Compel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 from pydantic import BaseModel, Field, validator
@@ -233,6 +234,15 @@ class TextToLatentsInvocation(BaseInvocation):
         c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
 
+        compel = Compel(
+            tokenizer=model.tokenizer,
+            text_encoder=model.text_encoder,
+            textual_inversion_manager=model.textual_inversion_manager,
+            dtype_for_device_getter=torch_dtype,
+            truncate_long_prompts=False,
+        )
+        [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+
         conditioning_data = ConditioningData(
             uc,
             c,
diff --git a/invokeai/backend/prompting/conditioning.py b/invokeai/backend/prompting/conditioning.py
index 2e62853872..46201a5284 100644
--- a/invokeai/backend/prompting/conditioning.py
+++ b/invokeai/backend/prompting/conditioning.py
@@ -38,7 +38,7 @@ def get_uc_and_c_and_ec(prompt_string,
         dtype_for_device_getter=torch_dtype,
         truncate_long_prompts=False,
     )
-    
+
     config = get_invokeai_config()
 
     # get rid of any newline characters
@@ -282,6 +282,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
         (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
         for match in re.finditer(prompt_parser, text)
     ]
+    if len(parsed_prompts) == 0:
+        return []
     if skip_normalize:
         return parsed_prompts
     weight_sum = sum(map(lambda x: x[1], parsed_prompts))
diff --git a/pyproject.toml b/pyproject.toml
index 38aa71bd0e..38f4b7673f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ dependencies = [
   "albumentations",
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-  "compel~=1.1.5",
+  "compel>=1.2.1",
   "controlnet-aux>=0.0.4",
   "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
   "datasets",
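
Usage sketch (reviewer note, not part of the patch): the snippet below exercises
what this change enables, calling the same compel>=1.2.1 entry points the patched
invocations use. The checkpoint id and the prompts are placeholders, and the
textual inversion manager / dtype getter are omitted for brevity.

    from compel import Compel
    from diffusers import StableDiffusionPipeline

    # any SD 1.x pipeline supplies the tokenizer and text encoder compel needs
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    compel = Compel(
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
        truncate_long_prompts=False,  # chunk prompts past 77 tokens instead of truncating
    )

    # .and() concatenates the conditioning of each sub-prompt rather than blending it
    conjunction = Compel.parse_prompt_string(
        '("a fortress on a cliff", "dramatic storm lighting, high detail").and()'
    )
    c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
    uc = compel.build_conditioning_tensor("")  # empty negative prompt

    # with truncation off, c and uc may differ in length; pad them to match,
    # as TextToLatentsInvocation now does before sampling
    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])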