diff --git a/invokeai/app/invocations/segment_anything.py b/invokeai/app/invocations/segment_anything.py
index f1a3ff1090..97400b4dab 100644
--- a/invokeai/app/invocations/segment_anything.py
+++ b/invokeai/app/invocations/segment_anything.py
@@ -1,7 +1,8 @@
 from enum import Enum
 from pathlib import Path
-from typing import Literal, TypedDict
+from typing import Literal

+import numpy as np
 import torch
 from PIL import Image
 from pydantic import BaseModel, Field, model_validator
@@ -13,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
 from invokeai.app.invocations.primitives import MaskOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline

 SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
@@ -35,55 +37,21 @@ class SAMPoint(BaseModel):
     label: SAMPointLabel = Field(..., description="The label of the point")


-class SAMObjectIdentifierInputData(TypedDict):
-    points: list[tuple[int, int]]
-    labels: list[int]
-    bounding_box: tuple[int, int, int, int] | None
+class SAMPointsField(BaseModel):
+    points: list[SAMPoint] = Field(..., description="The points of the object")

-
-class SAMObjectInputKwargs(TypedDict):
-    points: list[list[tuple[int, int]]]
-    labels: list[list[int]]
-    bounding_boxes: list[tuple[int, int, int, int]]
-
-
-class SAMObjectIdentifierField(BaseModel):
-    points: list[SAMPoint] | None = Field(None, description="The points of the feature")
-    bounding_box: BoundingBoxField | None = Field(None, description="The bounding box of the feature")
-
-    @model_validator(mode="after")
-    def check_points_or_bounding_box(self):
-        if self.points is None and self.bounding_box is None:
-            raise ValueError("Either points or bounding_box must be provided.")
-        return self
-
-    def get_input_data(self) -> SAMObjectIdentifierInputData:
-        input_data: SAMObjectIdentifierInputData = {"points": [], "labels": [], "bounding_box": None}
-
-        if self.points is not None:
-            for point in self.points:
-                input_data["points"].append((point.x, point.y))
-                input_data["labels"].append(point.label.value)
-
-        if self.bounding_box is not None:
-            input_data["bounding_box"] = (
-                self.bounding_box.x_min,
-                self.bounding_box.y_min,
-                self.bounding_box.x_max,
-                self.bounding_box.y_max,
-            )
-
-        return input_data
+    def to_list(self) -> list[list[int]]:
+        return [[point.x, point.y, point.label.value] for point in self.points]


 @invocation(
-    "segment_anything_object_identifier",
+    "segment_anything",
     title="Segment Anything",
     tags=["prompt", "segmentation"],
     category="segmentation",
-    version="1.0.0",
+    version="1.1.0",
 )
-class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
+class SegmentAnythingInvocation(BaseInvocation):
     """Runs a Segment Anything Model."""

     # Reference:
@@ -93,8 +61,12 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):

     model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
     image: ImageField = InputField(description="The image to segment.")
-    object_identifiers: list[SAMObjectIdentifierField] = InputField(
-        description="The bounding boxes to prompt the SAM model with."
+    bounding_boxes: list[BoundingBoxField] | None = InputField(
+        default=None, description="The bounding boxes to prompt the SAM model with."
+    )
+    point_lists: list[SAMPointsField] | None = InputField(
+        default=None,
+        description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
     )
     apply_polygon_refinement: bool = InputField(
         description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
@@ -105,15 +77,32 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
         default="all",
     )

+    @model_validator(mode="after")
+    def check_point_lists_or_bounding_box(self):
+        if self.point_lists is None and self.bounding_boxes is None:
+            raise ValueError("Either point_lists or bounding_boxes must be provided.")
+        elif self.point_lists is not None and self.bounding_boxes is not None:
+            raise ValueError("Only one of point_lists or bounding_boxes can be provided.")
+        return self
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

-        mask = self._segment(context=context, image=image_pil)[0]
+        if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
+            not self.point_lists or len(self.point_lists) == 0
+        ):
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
+        else:
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)

-        mask_tensor_name = context.tensors.save(mask)
-        height, width = mask.shape
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)
+
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
         return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)

     @staticmethod
@@ -133,27 +122,23 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):

     def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
-
-        input_kwargs: SAMObjectInputKwargs = {"points": [], "labels": [], "bounding_boxes": []}
-
-        for obj_id in self.object_identifiers:
-            input_data = obj_id.get_input_data()
-            if input_data["points"]:
-                input_kwargs["points"].append(input_data["points"])
-                input_kwargs["labels"].append(input_data["labels"])
-            if input_data["bounding_box"]:
-                input_kwargs["bounding_boxes"].append(input_data["bounding_box"])
+        # Convert the bounding boxes and points to the SAM input format.
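+        # For example (hypothetical values): one bounding box might become [[10, 20, 110, 220]], and one point list
+        # containing a single foreground point might become [[[30, 40, 1]]] (each point is [x, y, label]).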
+        sam_bounding_boxes = (
+            [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
+        )
+        sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None

         with (
             context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
-                loader=SegmentAnythingObjectIdentifierInvocation._load_sam_model,
+                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
             ) as sam_pipeline,
         ):
             assert isinstance(sam_pipeline, SegmentAnythingPipeline)
-            masks = sam_pipeline.segment(image=image, **input_kwargs, multimask_output=False)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)

         masks = self._process_masks(masks)
+        if self.apply_polygon_refinement:
+            masks = self._apply_polygon_refinement(masks)

         return masks

@@ -167,140 +152,51 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
         # Split the first dimension into a list of masks.
         return list(masks.cpu().unbind(dim=0))

+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
+        """Apply polygon refinement to the masks.

-# @invocation(
-#     "segment_anything",
-#     title="Segment Anything",
-#     tags=["prompt", "segmentation"],
-#     category="segmentation",
-#     version="1.0.0",
-# )
-# class SegmentAnythingInvocation(BaseInvocation):
-#     """Runs a Segment Anything Model."""
+        Convert each mask to a polygon, then back to a mask. This has the following effect:
+        - Smooth the edges of the mask slightly.
+        - Ensure that each mask consists of a single closed polygon
+        - Removes small mask pieces.
+        - Removes holes from the mask.
+        """
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]

-#     # Reference:
-#     # - https://arxiv.org/pdf/2304.02643
-#     # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
-#     # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
+            shape = mask.shape
+            assert len(shape) == 2  # Assert length to satisfy type checker.
+            polygon = mask_to_polygon(mask)
+            mask = polygon_to_mask(polygon, shape)
+            np_masks[idx] = mask

-#     model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
-#     image: ImageField = InputField(description="The image to segment.")
-#     bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
-#     apply_polygon_refinement: bool = InputField(
-#         description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
-#         default=True,
-#     )
-#     mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
-#         description="The filtering to apply to the detected masks before merging them into a final output.",
-#         default="all",
-#     )
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]

-#     @torch.no_grad()
-#     def invoke(self, context: InvocationContext) -> MaskOutput:
-#         # The models expect a 3-channel RGB image.
-#         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+        return masks

-#         if len(self.bounding_boxes) == 0:
-#             combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
-#         else:
-#             masks = self._segment(context=context, image=image_pil)
-#             masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
+    def _filter_masks(
+        self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField] | None
+    ) -> list[torch.Tensor]:
+        """Filter the detected masks based on the specified mask filter."""

-#             # masks contains bool values, so we merge them via max-reduce.
-#             combined_mask, _ = torch.stack(masks).max(dim=0)
-
-#         mask_tensor_name = context.tensors.save(combined_mask)
-#         height, width = combined_mask.shape
-#         return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
-
-#     @staticmethod
-#     def _load_sam_model(model_path: Path):
-#         sam_model = AutoModelForMaskGeneration.from_pretrained(
-#             model_path,
-#             local_files_only=True,
-#             # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
-#             # model, and figure out how to make it work in the pipeline.
-#             # torch_dtype=TorchDevice.choose_torch_dtype(),
-#         )
-#         assert isinstance(sam_model, SamModel)
-
-#         sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
-#         assert isinstance(sam_processor, SamProcessor)
-#         return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
-
-#     def _segment(
-#         self,
-#         context: InvocationContext,
-#         image: Image.Image,
-#     ) -> list[torch.Tensor]:
-#         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
-#         # Convert the bounding boxes to the SAM input format.
-#         sam_bounding_boxes = [(bb.x_min, bb.y_min, bb.x_max, bb.y_max) for bb in self.bounding_boxes]
-
-#         with (
-#             context.models.load_remote_model(
-#                 source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
-#             ) as sam_pipeline,
-#         ):
-#             assert isinstance(sam_pipeline, SegmentAnythingPipeline)
-#             masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
-
-#         masks = self._process_masks(masks)
-#         if self.apply_polygon_refinement:
-#             masks = self._apply_polygon_refinement(masks)
-
-#         return masks
-
-#     def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
-#         """Convert the tensor output from the Segment Anything model from a tensor of shape
-#         [num_masks, channels, height, width] to a list of tensors of shape [height, width].
-#         """
-#         assert masks.dtype == torch.bool
-#         # [num_masks, channels, height, width] -> [num_masks, height, width]
-#         masks, _ = masks.max(dim=1)
-#         # Split the first dimension into a list of masks.
-#         return list(masks.cpu().unbind(dim=0))
-
-#     def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
-#         """Apply polygon refinement to the masks.
-
-#         Convert each mask to a polygon, then back to a mask. This has the following effect:
-#         - Smooth the edges of the mask slightly.
-#         - Ensure that each mask consists of a single closed polygon
-#         - Removes small mask pieces.
-#         - Removes holes from the mask.
-#         """
-#         # Convert tensor masks to np masks.
-#         np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
-
-#         # Apply polygon refinement.
-#         for idx, mask in enumerate(np_masks):
-#             shape = mask.shape
-#             assert len(shape) == 2  # Assert length to satisfy type checker.
-#             polygon = mask_to_polygon(mask)
-#             mask = polygon_to_mask(polygon, shape)
-#             np_masks[idx] = mask
-
-#         # Convert np masks back to tensor masks.
-#         masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
-
-#         return masks
-
-#     def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
-#         """Filter the detected masks based on the specified mask filter."""
-#         assert len(masks) == len(bounding_boxes)
-
-#         if self.mask_filter == "all":
-#             return masks
-#         elif self.mask_filter == "largest":
-#             # Find the largest mask.
-#             return [max(masks, key=lambda x: float(x.sum()))]
-#         elif self.mask_filter == "highest_box_score":
-#             # Find the index of the bounding box with the highest score.
-#             # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
-#             # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
-#             # reasonable fallback since the expected score range is [0.0, 1.0].
-#             max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
-#             return [masks[max_score_idx]]
-#         else:
-#             raise ValueError(f"Invalid mask filter: {self.mask_filter}")
+        if self.mask_filter == "all":
+            return masks
+        elif self.mask_filter == "largest":
+            # Find the largest mask.
+            return [max(masks, key=lambda x: float(x.sum()))]
+        elif self.mask_filter == "highest_box_score":
+            assert (
+                bounding_boxes is not None
+            ), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
+            assert len(masks) == len(bounding_boxes)
+            # Find the index of the bounding box with the highest score.
+            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
+        else:
+            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
diff --git a/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py b/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py
index 93d9c285ec..e10df5d180 100644
--- a/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py
+++ b/invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, TypeAlias

 import torch
 from PIL import Image
@@ -7,6 +7,14 @@
 from transformers.models.sam.processing_sam import SamProcessor

 from invokeai.backend.raw_model import RawModel

+# Type aliases for the inputs to the SAM model.
+ListOfBoundingBoxes: TypeAlias = list[list[int]]
+"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
+ListOfPoints: TypeAlias = list[list[int]]
+"""A list of points. Each point is in the format [x, y]."""
+ListOfPointLabels: TypeAlias = list[int]
+"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
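+# For example (hypothetical values): a ListOfBoundingBoxes might look like [[10, 20, 110, 220]], a ListOfPoints like
+# [[30, 40], [50, 60]], and the matching ListOfPointLabels like [1, 0].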
+

 class SegmentAnythingPipeline(RawModel):
     """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -30,36 +38,49 @@ class SegmentAnythingPipeline(RawModel):
     def segment(
         self,
         image: Image.Image,
-        bounding_boxes: list[tuple[int, int, int, int]] | None = None,
-        points: list[list[tuple[int, int]]] | None = None,
-        labels: list[list[int]] | None = None,
-        multimask_output: bool = True,
+        bounding_boxes: list[list[int]] | None = None,
+        point_lists: list[list[list[int]]] | None = None,
     ) -> torch.Tensor:
         """Run the SAM model.

+        Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
+        point_lists will be ignored.
+
         Args:
             image (Image.Image): The image to segment.
             bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
                 [xmin, ymin, xmax, ymax].
+            point_lists (list[list[list[int]]]): The point prompts. Each point is in the format [x, y, label].
+                `label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.

         Returns:
             torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
         """
-        input_kwargs = {}
+        # Prep the inputs:
+        # - Create a list of bounding boxes or points and labels.
+        # - Add a batch dimension of 1 to the inputs.
         if bounding_boxes:
-            # Add batch dimension of 1 to the inputs.
-            input_kwargs["input_boxes"] = [bounding_boxes]
-        if points and labels and len(points) == len(labels):
-            # Add batch dimension of 1 to the inputs.
-            input_kwargs["input_points"] = [points]
-            input_kwargs["input_labels"] = [labels]
+            input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
+            input_points: list[ListOfPoints] | None = None
+            input_labels: list[ListOfPointLabels] | None = None
+        elif point_lists:
+            input_boxes: list[ListOfBoundingBoxes] | None = None
+            input_points: list[ListOfPoints] | None = []
+            input_labels: list[ListOfPointLabels] | None = []
+            for point_list in point_lists:
+                input_points.append([[p[0], p[1]] for p in point_list])
+                input_labels.append([p[2] for p in point_list])
+
+        else:
+            raise ValueError("Either bounding_boxes or point_lists must be provided.")

         inputs = self._sam_processor(
             images=image,
-            **input_kwargs,
+            input_boxes=input_boxes,
+            input_points=input_points,
+            input_labels=input_labels,
             return_tensors="pt",
-            multimask_output=multimask_output,
         ).to(self._sam_model.device)

         outputs = self._sam_model(**inputs)
         masks = self._sam_processor.post_process_masks(