From 460d555a3dd89885bdaabb9a54515f6e412a3fb2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 24 May 2023 21:35:46 +1000 Subject: [PATCH] feat(nodes): add image mul, channel, convert nodes also make img node names consistent --- invokeai/app/invocations/image.py | 137 +++++++++++++++++++++++++++--- 1 file changed, 125 insertions(+), 12 deletions(-) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 8f789853ac..b25d3735c2 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -4,7 +4,7 @@ import io from typing import Literal, Optional, Union import numpy -from PIL import Image, ImageFilter, ImageOps +from PIL import Image, ImageFilter, ImageOps, ImageChops from pydantic import BaseModel, Field from ..models.image import ImageCategory, ImageField, ImageType @@ -112,11 +112,11 @@ class ShowImageInvocation(BaseInvocation): ) -class CropImageInvocation(BaseInvocation, PILInvocationConfig): +class ImageCropInvocation(BaseInvocation, PILInvocationConfig): """Crops an image to a specified box. The box can be outside of the image.""" # fmt: off - type: Literal["crop"] = "crop" + type: Literal["img_crop"] = "img_crop" # Inputs image: Union[ImageField, None] = Field(default=None, description="The image to crop") @@ -154,11 +154,11 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig): ) -class PasteImageInvocation(BaseInvocation, PILInvocationConfig): +class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): """Pastes an image into another image.""" # fmt: off - type: Literal["paste"] = "paste" + type: Literal["img_paste"] = "img_paste" # Inputs base_image: Union[ImageField, None] = Field(default=None, description="The base image") @@ -238,7 +238,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_mask, image_type=ImageType.INTERMEDIATE, - image_category=ImageCategory.GENERAL, + image_category=ImageCategory.MASK, node_id=self.id, session_id=context.graph_execution_state_id, ) @@ -252,11 +252,124 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): ) -class BlurInvocation(BaseInvocation, PILInvocationConfig): +class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): + """Multiplies two images together using `PIL.ImageChops.multiply()`.""" + + # fmt: off + type: Literal["img_mul"] = "img_mul" + + # Inputs + image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply") + image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply") + # fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image1 = context.services.images.get_pil_image( + self.image1.image_type, self.image1.image_name + ) + image2 = context.services.images.get_pil_image( + self.image2.image_type, self.image2.image_name + ) + + multiply_image = ImageChops.multiply(image1, image2) + + image_dto = context.services.images.create( + image=multiply_image, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + ) + + return ImageOutput( + image=ImageField( + image_type=image_dto.image_type, image_name=image_dto.image_name + ), + width=image_dto.width, + height=image_dto.height, + ) + + +IMAGE_CHANNELS = Literal["A", "R", "G", "B"] + + +class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): + """Gets a channel from an image.""" + + # fmt: off + type: Literal["img_chan"] = "img_chan" + + # Inputs + image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from") + channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get") + # fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image( + self.image.image_type, self.image.image_name + ) + + channel_image = image.getchannel(self.channel) + + image_dto = context.services.images.create( + image=channel_image, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + ) + + return ImageOutput( + image=ImageField( + image_type=image_dto.image_type, image_name=image_dto.image_name + ), + width=image_dto.width, + height=image_dto.height, + ) + + +IMAGE_MODES = Literal['L', 'RGB', 'RGBA', 'CMYK', 'YCbCr', 'LAB', 'HSV', 'I', 'F'] + +class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): + """Converts an image to a different mode.""" + + # fmt: off + type: Literal["img_conv"] = "img_conv" + + # Inputs + image: Union[ImageField, None] = Field(default=None, description="The image to convert") + mode: IMAGE_MODES = Field(default="L", description="The mode to convert to") + # fmt: on + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.services.images.get_pil_image( + self.image.image_type, self.image.image_name + ) + + converted_image = image.convert(self.mode) + + image_dto = context.services.images.create( + image=converted_image, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + ) + + return ImageOutput( + image=ImageField( + image_type=image_dto.image_type, image_name=image_dto.image_name + ), + width=image_dto.width, + height=image_dto.height, + ) + + +class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): """Blurs an image""" # fmt: off - type: Literal["blur"] = "blur" + type: Literal["img_blur"] = "img_blur" # Inputs image: Union[ImageField, None] = Field(default=None, description="The image to blur") @@ -294,11 +407,11 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig): ) -class LerpInvocation(BaseInvocation, PILInvocationConfig): +class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): """Linear interpolation of all pixels of an image""" # fmt: off - type: Literal["lerp"] = "lerp" + type: Literal["img_lerp"] = "img_lerp" # Inputs image: Union[ImageField, None] = Field(default=None, description="The image to lerp") @@ -334,11 +447,11 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig): ) -class InverseLerpInvocation(BaseInvocation, PILInvocationConfig): +class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): """Inverse linear interpolation of all pixels of an image""" # fmt: off - type: Literal["ilerp"] = "ilerp" + type: Literal["img_ilerp"] = "img_ilerp" # Inputs image: Union[ImageField, None] = Field(default=None, description="The image to lerp")