From ae05d34584c352bedde71148eada535199ae26ef Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 30 Aug 2023 14:52:50 +1000 Subject: [PATCH] fix(nodes): fix uploading image metadata retention was causing failure to save images --- invokeai/app/invocations/baseinvocation.py | 43 ++++++++++++++++++- invokeai/app/invocations/collections.py | 11 ++--- invokeai/app/invocations/compel.py | 11 +++-- .../controlnet_image_processors.py | 20 ++++++++- invokeai/app/invocations/cv.py | 3 +- invokeai/app/services/image_file_storage.py | 21 ++++++--- 6 files changed, 87 insertions(+), 22 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index f7a09fe63b..f56e7c7aa5 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -372,8 +372,9 @@ class UIConfigBase(BaseModel): decorators, though you may add this class to a node definition to specify the title and tags. """ - tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI") - title: Optional[str] = Field(default=None, description="The display name of the node") + tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags") + title: Optional[str] = Field(default=None, description="The node's display name") + category: Optional[str] = Field(default=None, description="The node's category") class InvocationContext: @@ -469,6 +470,8 @@ class BaseInvocation(ABC, BaseModel): schema["title"] = uiconfig.title if uiconfig and hasattr(uiconfig, "tags"): schema["tags"] = uiconfig.tags + if uiconfig and hasattr(uiconfig, "category"): + schema["category"] = uiconfig.category if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = list() schema["required"].extend(["type", "id"]) @@ -558,3 +561,39 @@ def tags(*tags: str) -> Callable[[Type[T]], Type[T]]: return cls return wrapper + + +def category(category: str) -> Callable[[Type[T]], Type[T]]: + """Adds a category to the invocation. This is used to group invocations in the UI.""" + + def wrapper(cls: Type[T]) -> Type[T]: + uiconf_name = cls.__qualname__ + ".UIConfig" + if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name: + cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict()) + cls.UIConfig.category = category + return cls + + return wrapper + + +def node( + title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None +) -> Callable[[Type[T]], Type[T]]: + """ + Adds metadata to the invocation as a decorator. + + :param Optional[str] title: Adds a title to the node. Use if the auto-generated title isn't quite right. Defaults to None. + :param Optional[list[str]] tags: Adds tags to the node. Nodes may be searched for by their tags. Defaults to None. + :param Optional[str] category: Adds a category to the node. Used to group the nodes in the UI. Defaults to None. + """ + + def wrapper(cls: Type[T]) -> Type[T]: + uiconf_name = cls.__qualname__ + ".UIConfig" + if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name: + cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict()) + cls.UIConfig.title = title + cls.UIConfig.tags = tags + cls.UIConfig.category = category + return cls + + return wrapper diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 65bce2737e..1f8568beed 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -8,11 +8,10 @@ from pydantic import validator from invokeai.app.invocations.primitives import IntegerCollectionOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed -from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title +from .baseinvocation import BaseInvocation, InputField, InvocationContext, node -@title("Integer Range") -@tags("collection", "integer", "range") +@node(title="Integer Range", tags=["collection", "integer", "range"], category="collections") class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" @@ -33,8 +32,7 @@ class RangeInvocation(BaseInvocation): return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step))) -@title("Integer Range of Size") -@tags("range", "integer", "size", "collection") +@node(title="Integer Range of Size", tags=["collection", "integer", "size", "range"], category="collections") class RangeOfSizeInvocation(BaseInvocation): """Creates a range from start to start + size with step""" @@ -49,8 +47,7 @@ class RangeOfSizeInvocation(BaseInvocation): return IntegerCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step))) -@title("Random Range") -@tags("range", "integer", "random", "collection") +@node(title="Random Range", tags=["range", "integer", "random", "collection"], category="collections") class RandomRangeInvocation(BaseInvocation): """Creates a collection of random numbers""" diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index e128792d70..9d0f848986 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -26,6 +26,7 @@ from .baseinvocation import ( InvocationContext, OutputField, UIComponent, + category, tags, title, ) @@ -44,8 +45,9 @@ class ConditioningFieldData: # PerpNeg = "perp_neg" -@title("Compel Prompt") +@title("Prompt") @tags("prompt", "compel") +@category("conditioning") class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -265,8 +267,9 @@ class SDXLPromptInvocationBase: return c, c_pooled, ec -@title("SDXL Compel Prompt") +@title("SDXL Prompt") @tags("sdxl", "compel", "prompt") +@category("conditioning") class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -324,8 +327,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): ) -@title("SDXL Refiner Compel Prompt") +@title("SDXL Refiner Prompt") @tags("sdxl", "compel", "prompt") +@category("conditioning") class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -381,6 +385,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput): @title("CLIP Skip") @tags("clipskip", "clip", "skip") +@category("conditioning") class ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index a82b1cecc9..cc6455a714 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -40,6 +40,8 @@ from .baseinvocation import ( InvocationContext, OutputField, UIType, + category, + node, tags, title, ) @@ -96,8 +98,7 @@ class ControlOutput(BaseInvocationOutput): control: ControlField = OutputField(description=FieldDescriptions.control) -@title("ControlNet") -@tags("controlnet") +@node(title="ControlNet", tags=["controlnet"], category="controlnet") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" @@ -177,6 +178,7 @@ class ImageProcessorInvocation(BaseInvocation): @title("Canny Processor") @tags("controlnet", "canny") +@category("controlnet") class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" @@ -198,6 +200,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation): @title("HED (softedge) Processor") @tags("controlnet", "hed", "softedge") +@category("controlnet") class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" @@ -225,6 +228,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation): @title("Lineart Processor") @tags("controlnet", "lineart") +@category("controlnet") class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" @@ -245,6 +249,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation): @title("Lineart Anime Processor") @tags("controlnet", "lineart", "anime") +@category("controlnet") class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" @@ -266,6 +271,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): @title("Openpose Processor") @tags("controlnet", "openpose", "pose") +@category("controlnet") class OpenposeImageProcessorInvocation(ImageProcessorInvocation): """Applies Openpose processing to image""" @@ -289,6 +295,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation): @title("Midas (Depth) Processor") @tags("controlnet", "midas", "depth") +@category("controlnet") class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" @@ -314,6 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): @title("Normal BAE Processor") @tags("controlnet", "normal", "bae") +@category("controlnet") class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" @@ -333,6 +341,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): @title("MLSD Processor") @tags("controlnet", "mlsd") +@category("controlnet") class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" @@ -358,6 +367,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): @title("PIDI Processor") @tags("controlnet", "pidi") +@category("controlnet") class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" @@ -383,6 +393,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): @title("Content Shuffle Processor") @tags("controlnet", "contentshuffle") +@category("controlnet") class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle processing to image""" @@ -411,6 +422,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 @title("Zoe (Depth) Processor") @tags("controlnet", "zoe", "depth") +@category("controlnet") class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" @@ -424,6 +436,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): @title("Mediapipe Face Processor") @tags("controlnet", "mediapipe", "face") +@category("controlnet") class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" @@ -445,6 +458,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): @title("Leres (Depth) Processor") @tags("controlnet", "leres", "depth") +@category("controlnet") class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" @@ -472,6 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): @title("Tile Resample Processor") @tags("controlnet", "tile") +@category("controlnet") class TileResamplerProcessorInvocation(ImageProcessorInvocation): """Tile resampler processor""" @@ -510,6 +525,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation): @title("Segment Anything Processor") @tags("controlnet", "segmentanything") +@category("controlnet") class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 146d480938..c03422d95e 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -8,11 +8,12 @@ from PIL import Image, ImageOps from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.models.image import ImageCategory, ResourceOrigin -from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title +from .baseinvocation import BaseInvocation, InputField, InvocationContext, category, tags, title @title("OpenCV Inpaint") @tags("opencv", "inpaint") +@category("inpaint") class CvInpaintInvocation(BaseInvocation): """Simple inpaint using opencv.""" diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index e80c6adbab..75a5888175 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -119,13 +119,20 @@ class DiskImageFileStorage(ImageFileStorageBase): pnginfo = PngImagePlugin.PngInfo() - if metadata is not None: - pnginfo.add_text("invokeai_metadata", json.dumps(metadata)) - if workflow is not None: - pnginfo.add_text("invokeai_workflow", workflow) - # For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back - for item_name, item in image.info.items(): - pnginfo.add_text(item_name, item) + if metadata is not None or workflow is not None: + if metadata is not None: + pnginfo.add_text("invokeai_metadata", json.dumps(metadata)) + if workflow is not None: + pnginfo.add_text("invokeai_workflow", workflow) + else: + # For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back + # TODO: retain non-invokeai metadata on save... + original_metadata = image.info.get("invokeai_metadata", None) + if original_metadata is not None: + pnginfo.add_text("invokeai_metadata", original_metadata) + original_workflow = image.info.get("invokeai_workflow", None) + if original_workflow is not None: + pnginfo.add_text("invokeai_workflow", original_workflow) image.save(image_path, "PNG", pnginfo=pnginfo)