From 1cffcc02a5cfa32ca127c9988091bd3136793f42 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 10 Sep 2024 20:56:42 +1000 Subject: [PATCH] feat(nodes): add `HEDEdgeDetectionInvocation` Similar to the existing node, but without any resizing and with a revised model loading API that uses the model manager. --- invokeai/app/invocations/hed.py | 33 +++++++++++++ invokeai/backend/image_util/hed.py | 76 +++++++++++++++++++++++++++++- 2 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 invokeai/app/invocations/hed.py diff --git a/invokeai/app/invocations/hed.py b/invokeai/app/invocations/hed.py new file mode 100644 index 0000000000..5ea6e8df1f --- /dev/null +++ b/invokeai/app/invocations/hed.py @@ -0,0 +1,33 @@ +from builtins import bool + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, WithBoard, WithMetadata +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetector + + +@invocation( + "hed_edge_detection", + title="HED Edge Detection", + tags=["controlnet", "hed", "softedge"], + category="controlnet", + version="1.0.0", +) +class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): + """Geneartes an edge map using the HED (softedge) model.""" + + image: ImageField = InputField(description="The image to process") + scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.images.get_pil(self.image.image_name, "RGB") + loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model) + + with loaded_model as model: + assert isinstance(model, ControlNetHED_Apache2) + hed_processor = HEDEdgeDetector(model) + edge_map = hed_processor.run(image=image, scribble=self.scribble) + + image_dto = context.images.save(image=edge_map) + return ImageOutput.build(image_dto) diff --git a/invokeai/backend/image_util/hed.py b/invokeai/backend/image_util/hed.py index 97706df8b9..ec12c26b2e 100644 --- a/invokeai/backend/image_util/hed.py +++ b/invokeai/backend/image_util/hed.py @@ -1,6 +1,9 @@ -"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).""" +# Adapted from https://github.com/huggingface/controlnet_aux + +import pathlib import cv2 +import huggingface_hub import numpy as np import torch from einops import rearrange @@ -140,3 +143,74 @@ class HEDProcessor: detected_map[detected_map < 255] = 0 return np_to_pil(detected_map) + + +class HEDEdgeDetector: + """Simple wrapper around the HED model for detecting edges in an image.""" + + hf_repo_id = "lllyasviel/Annotators" + hf_filename = "ControlNetHED.pth" + + def __init__(self, model: ControlNetHED_Apache2): + self.model = model + + @classmethod + def get_model_url(cls) -> str: + """Get the URL to download the model from the Hugging Face Hub.""" + return huggingface_hub.hf_hub_url(cls.hf_repo_id, cls.hf_filename) + + @classmethod + def load_model(cls, model_path: pathlib.Path) -> ControlNetHED_Apache2: + """Load the model from a file.""" + model = ControlNetHED_Apache2() + model.load_state_dict(torch.load(model_path, map_location="cpu")) + model.float().eval() + return model + + def to(self, device: torch.device): + self.model.to(device) + return self + + def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) -> Image.Image: + """Processes an image and returns the detected edges. + + Args: + image: The input image. + safe: Whether to apply safe step to the detected edges. + scribble: Whether to apply non-maximum suppression and Gaussian blur to the detected edges. + + Returns: + The detected edges. + """ + + device = next(iter(self.model.parameters())).device + + np_image = pil_to_np(image) + + height, width, _channels = np_image.shape + + with torch.no_grad(): + image_hed = torch.from_numpy(np_image.copy()).float().to(device) + image_hed = rearrange(image_hed, "h w c -> 1 c h w") + edges = self.model(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + if safe: + edge = safe_step(edge) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge + + detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR) + + if scribble: + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + output = np_to_pil(detected_map) + + return output