Mirror of https://github.com/invoke-ai/InvokeAI
feat(nodes): update SAM backend and nodes to work with SAM points
parent 790846297a
commit ff72315db2
@@ -1,7 +1,8 @@
from enum import Enum
from pathlib import Path
from typing import Literal, TypedDict
from typing import Literal

import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
@@ -13,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline

SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
@@ -35,55 +37,21 @@ class SAMPoint(BaseModel):
    label: SAMPointLabel = Field(..., description="The label of the point")


class SAMObjectIdentifierInputData(TypedDict):
    points: list[tuple[int, int]]
    labels: list[int]
    bounding_box: tuple[int, int, int, int] | None
class SAMPointsField(BaseModel):
    points: list[SAMPoint] = Field(..., description="The points of the object")


class SAMObjectInputKwargs(TypedDict):
    points: list[list[tuple[int, int]]]
    labels: list[list[int]]
    bounding_boxes: list[tuple[int, int, int, int]]


class SAMObjectIdentifierField(BaseModel):
    points: list[SAMPoint] | None = Field(None, description="The points of the feature")
    bounding_box: BoundingBoxField | None = Field(None, description="The bounding box of the feature")

    @model_validator(mode="after")
    def check_points_or_bounding_box(self):
        if self.points is None and self.bounding_box is None:
            raise ValueError("Either points or bounding_box must be provided.")
        return self

    def get_input_data(self) -> SAMObjectIdentifierInputData:
        input_data: SAMObjectIdentifierInputData = {"points": [], "labels": [], "bounding_box": None}

        if self.points is not None:
            for point in self.points:
                input_data["points"].append((point.x, point.y))
                input_data["labels"].append(point.label.value)

        if self.bounding_box is not None:
            input_data["bounding_box"] = (
                self.bounding_box.x_min,
                self.bounding_box.y_min,
                self.bounding_box.x_max,
                self.bounding_box.y_max,
            )

        return input_data
    def to_list(self) -> list[list[int]]:
        return [[point.x, point.y, point.label.value] for point in self.points]


@invocation(
    "segment_anything_object_identifier",
    "segment_anything",
    title="Segment Anything",
    tags=["prompt", "segmentation"],
    category="segmentation",
    version="1.0.0",
    version="1.1.0",
)
class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
class SegmentAnythingInvocation(BaseInvocation):
    """Runs a Segment Anything Model."""

    # Reference:
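
For reference, a minimal sketch of how the new SAMPointsField is built and what to_list() returns (illustrative values, not part of this diff; assumes SAMPointLabel uses the -1/0/1 values documented in the SAM pipeline docstrings):

# Two foreground points describing one object.
points_field = SAMPointsField(
    points=[
        SAMPoint(x=120, y=80, label=SAMPointLabel(1)),  # 1 = foreground
        SAMPoint(x=200, y=95, label=SAMPointLabel(1)),
    ]
)
# to_list() flattens each point to [x, y, label] for the SAM pipeline.
assert points_field.to_list() == [[120, 80, 1], [200, 95, 1]]
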
@@ -93,8 +61,12 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):

    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
    image: ImageField = InputField(description="The image to segment.")
    object_identifiers: list[SAMObjectIdentifierField] = InputField(
        description="The bounding boxes to prompt the SAM model with."
    bounding_boxes: list[BoundingBoxField] | None = InputField(
        default=None, description="The bounding boxes to prompt the SAM model with."
    )
    point_lists: list[SAMPointsField] | None = InputField(
        default=None,
        description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
    )
    apply_polygon_refinement: bool = InputField(
        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
@@ -105,15 +77,32 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
        default="all",
    )

    @model_validator(mode="after")
    def check_point_lists_or_bounding_box(self):
        if self.point_lists is None and self.bounding_boxes is None:
            raise ValueError("Either point_lists or bounding_box must be provided.")
        elif self.point_lists is not None and self.bounding_boxes is not None:
            raise ValueError("Only one of point_lists or bounding_box can be provided.")
        return self

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> MaskOutput:
        # The models expect a 3-channel RGB image.
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        mask = self._segment(context=context, image=image_pil)[0]
        if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
            not self.point_lists or len(self.point_lists) == 0
        ):
            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
        else:
            masks = self._segment(context=context, image=image_pil)
            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)

        mask_tensor_name = context.tensors.save(mask)
        height, width = mask.shape
            # masks contains bool values, so we merge them via max-reduce.
            combined_mask, _ = torch.stack(masks).max(dim=0)

        mask_tensor_name = context.tensors.save(combined_mask)
        height, width = combined_mask.shape
        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)

    @staticmethod
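
For reference, the max-reduce merge used in invoke() is an element-wise OR over boolean masks (illustrative, not part of this diff):

import torch

mask_a = torch.tensor([[True, False], [False, False]])
mask_b = torch.tensor([[False, False], [False, True]])

# Stack to [num_masks, height, width], then reduce over the mask dimension.
combined, _ = torch.stack([mask_a, mask_b]).max(dim=0)
assert combined.tolist() == [[True, False], [False, True]]

The empty-prompt branch builds its [height, width] mask with torch.zeros(image_pil.size[::-1], ...) because PIL reports size as (width, height).
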
@@ -133,27 +122,23 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):

    def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""

        input_kwargs: SAMObjectInputKwargs = {"points": [], "labels": [], "bounding_boxes": []}

        for obj_id in self.object_identifiers:
            input_data = obj_id.get_input_data()
            if input_data["points"]:
                input_kwargs["points"].append(input_data["points"])
                input_kwargs["labels"].append(input_data["labels"])
            if input_data["bounding_box"]:
                input_kwargs["bounding_boxes"].append(input_data["bounding_box"])
        # Convert the bounding boxes to the SAM input format.
        sam_bounding_boxes = (
            [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
        )
        sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None

        with (
            context.models.load_remote_model(
                source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
                loader=SegmentAnythingObjectIdentifierInvocation._load_sam_model,
                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
            ) as sam_pipeline,
        ):
            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
            masks = sam_pipeline.segment(image=image, **input_kwargs, multimask_output=False)
            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)

        masks = self._process_masks(masks)
        if self.apply_polygon_refinement:
            masks = self._apply_polygon_refinement(masks)

        return masks
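
For reference, the prompt shapes that the rewritten _segment hands to the pipeline look like this (illustrative values, not part of this diff):

# A BoundingBoxField(x_min=10, y_min=20, x_max=300, y_max=400) becomes one [x_min, y_min, x_max, y_max] row.
sam_bounding_boxes = [[10, 20, 300, 400]]

# A SAMPointsField with two foreground points becomes one list of [x, y, label] rows (one inner list per object).
sam_points = [[[120, 80, 1], [200, 95, 1]]]
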
@@ -167,140 +152,51 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
        # Split the first dimension into a list of masks.
        return list(masks.cpu().unbind(dim=0))

    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply polygon refinement to the masks.

# @invocation(
#     "segment_anything",
#     title="Segment Anything",
#     tags=["prompt", "segmentation"],
#     category="segmentation",
#     version="1.0.0",
# )
# class SegmentAnythingInvocation(BaseInvocation):
#     """Runs a Segment Anything Model."""
        Convert each mask to a polygon, then back to a mask. This has the following effect:
        - Smooth the edges of the mask slightly.
        - Ensure that each mask consists of a single closed polygon
        - Removes small mask pieces.
        - Removes holes from the mask.
        """
        # Convert tensor masks to np masks.
        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]

#     # Reference:
#     # - https://arxiv.org/pdf/2304.02643
#     # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
#     # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
        # Apply polygon refinement.
        for idx, mask in enumerate(np_masks):
            shape = mask.shape
            assert len(shape) == 2  # Assert length to satisfy type checker.
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            np_masks[idx] = mask

#     model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
#     image: ImageField = InputField(description="The image to segment.")
#     bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
#     apply_polygon_refinement: bool = InputField(
#         description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
#         default=True,
#     )
#     mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
#         description="The filtering to apply to the detected masks before merging them into a final output.",
#         default="all",
#     )
        # Convert np masks back to tensor masks.
        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]

#     @torch.no_grad()
#     def invoke(self, context: InvocationContext) -> MaskOutput:
#         # The models expect a 3-channel RGB image.
#         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
        return masks

#         if len(self.bounding_boxes) == 0:
#             combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
#         else:
#             masks = self._segment(context=context, image=image_pil)
#             masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
    def _filter_masks(
        self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField] | None
    ) -> list[torch.Tensor]:
        """Filter the detected masks based on the specified mask filter."""

#         # masks contains bool values, so we merge them via max-reduce.
#         combined_mask, _ = torch.stack(masks).max(dim=0)

#         mask_tensor_name = context.tensors.save(combined_mask)
#         height, width = combined_mask.shape
#         return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)

#     @staticmethod
#     def _load_sam_model(model_path: Path):
#         sam_model = AutoModelForMaskGeneration.from_pretrained(
#             model_path,
#             local_files_only=True,
#             # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
#             # model, and figure out how to make it work in the pipeline.
#             # torch_dtype=TorchDevice.choose_torch_dtype(),
#         )
#         assert isinstance(sam_model, SamModel)

#         sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
#         assert isinstance(sam_processor, SamProcessor)
#         return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)

#     def _segment(
#         self,
#         context: InvocationContext,
#         image: Image.Image,
#     ) -> list[torch.Tensor]:
#         """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
#         # Convert the bounding boxes to the SAM input format.
#         sam_bounding_boxes = [(bb.x_min, bb.y_min, bb.x_max, bb.y_max) for bb in self.bounding_boxes]

#         with (
#             context.models.load_remote_model(
#                 source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
#             ) as sam_pipeline,
#         ):
#             assert isinstance(sam_pipeline, SegmentAnythingPipeline)
#             masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)

#         masks = self._process_masks(masks)
#         if self.apply_polygon_refinement:
#             masks = self._apply_polygon_refinement(masks)

#         return masks

#     def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
#         """Convert the tensor output from the Segment Anything model from a tensor of shape
#         [num_masks, channels, height, width] to a list of tensors of shape [height, width].
#         """
#         assert masks.dtype == torch.bool
#         # [num_masks, channels, height, width] -> [num_masks, height, width]
#         masks, _ = masks.max(dim=1)
#         # Split the first dimension into a list of masks.
#         return list(masks.cpu().unbind(dim=0))

#     def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
#         """Apply polygon refinement to the masks.

#         Convert each mask to a polygon, then back to a mask. This has the following effect:
#         - Smooth the edges of the mask slightly.
#         - Ensure that each mask consists of a single closed polygon
#         - Removes small mask pieces.
#         - Removes holes from the mask.
#         """
#         # Convert tensor masks to np masks.
#         np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]

#         # Apply polygon refinement.
#         for idx, mask in enumerate(np_masks):
#             shape = mask.shape
#             assert len(shape) == 2  # Assert length to satisfy type checker.
#             polygon = mask_to_polygon(mask)
#             mask = polygon_to_mask(polygon, shape)
#             np_masks[idx] = mask

#         # Convert np masks back to tensor masks.
#         masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]

#         return masks

#     def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
#         """Filter the detected masks based on the specified mask filter."""
#         assert len(masks) == len(bounding_boxes)

#         if self.mask_filter == "all":
#             return masks
#         elif self.mask_filter == "largest":
#             # Find the largest mask.
#             return [max(masks, key=lambda x: float(x.sum()))]
#         elif self.mask_filter == "highest_box_score":
#             # Find the index of the bounding box with the highest score.
#             # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
#             # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
#             # reasonable fallback since the expected score range is [0.0, 1.0].
#             max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
#             return [masks[max_score_idx]]
#         else:
#             raise ValueError(f"Invalid mask filter: {self.mask_filter}")
        if self.mask_filter == "all":
            return masks
        elif self.mask_filter == "largest":
            # Find the largest mask.
            return [max(masks, key=lambda x: float(x.sum()))]
        elif self.mask_filter == "highest_box_score":
            assert (
                bounding_boxes is not None
            ), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
            assert len(masks) == len(bounding_boxes)
            # Find the index of the bounding box with the highest score.
            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
            # reasonable fallback since the expected score range is [0.0, 1.0].
            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
            return [masks[max_score_idx]]
        else:
            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
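
For reference, a sketch of how the mask_filter options behave on plain tensors (illustrative, not part of this diff):

import torch

masks = [
    torch.ones(4, 4, dtype=torch.bool),   # 16 foreground pixels
    torch.zeros(4, 4, dtype=torch.bool),  # 0 foreground pixels
]

# mask_filter="largest" keeps the mask with the most foreground pixels.
largest = max(masks, key=lambda m: float(m.sum()))
assert largest is masks[0]

# mask_filter="highest_box_score" indexes by the prompt box with the highest score instead,
# which is why the new code asserts that bounding_boxes is not None in that branch.
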
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, TypeAlias

import torch
from PIL import Image
@@ -7,6 +7,14 @@ from transformers.models.sam.processing_sam import SamProcessor

from invokeai.backend.raw_model import RawModel

# Type aliases for the inputs to the SAM model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""
ListOfPoints: TypeAlias = list[list[int]]
"""A list of points. Each point is in the format [x, y]."""
ListOfPointLabels: TypeAlias = list[int]
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""


class SegmentAnythingPipeline(RawModel):
    """A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
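
Illustrative values for the three aliases, matching the documented formats (not part of this diff):

boxes: ListOfBoundingBoxes = [[10, 20, 300, 400]]  # one box as [xmin, ymin, xmax, ymax]
points: ListOfPoints = [[120, 80], [200, 95]]      # two points as [x, y]
labels: ListOfPointLabels = [1, 1]                 # 1 = foreground, 0 = neutral, -1 = background
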
@@ -30,36 +38,49 @@ class SegmentAnythingPipeline(RawModel):
    def segment(
        self,
        image: Image.Image,
        bounding_boxes: list[tuple[int, int, int, int]] | None = None,
        points: list[list[tuple[int, int]]] | None = None,
        labels: list[list[int]] | None = None,
        multimask_output: bool = True,
        bounding_boxes: list[list[int]] | None = None,
        point_lists: list[list[list[int]]] | None = None,
    ) -> torch.Tensor:
        """Run the SAM model.

        Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
        point_lists will be ignored.

        Args:
            image (Image.Image): The image to segment.
            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
                [xmin, ymin, xmax, ymax].
            point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
                `label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.

        Returns:
            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
        """
        input_kwargs = {}

        # Prep the inputs:
        # - Create a list of bounding boxes or points and labels.
        # - Add a batch dimension of 1 to the inputs.
        if bounding_boxes:
            # Add batch dimension of 1 to the inputs.
            input_kwargs["input_boxes"] = [bounding_boxes]
        if points and labels and len(points) == len(labels):
            # Add batch dimension of 1 to the inputs.
            input_kwargs["input_points"] = [points]
            input_kwargs["input_labels"] = [labels]
            input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
            input_points: list[ListOfPoints] | None = None
            input_labels: list[ListOfPointLabels] | None = None
        elif point_lists:
            input_boxes: list[ListOfBoundingBoxes] | None = None
            input_points: list[ListOfPoints] | None = []
            input_labels: list[ListOfPointLabels] | None = []
            for point_list in point_lists:
                input_points.append([[p[0], p[1]] for p in point_list])
                input_labels.append([p[2] for p in point_list])

        else:
            raise ValueError("Either bounding_boxes or points and labels must be provided.")

        inputs = self._sam_processor(
            images=image,
            **input_kwargs,
            input_boxes=input_boxes,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
            multimask_output=multimask_output,
        ).to(self._sam_model.device)
        outputs = self._sam_model(**inputs)
        masks = self._sam_processor.post_process_masks(
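
For reference, a minimal usage sketch of the updated segment() API (not part of this diff; the checkpoint ID and the direct construction of the pipeline are illustrative — in InvokeAI the model manager loads it via _load_sam_model):

from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor

sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")  # illustrative checkpoint
sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
pipeline = SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)

image = Image.open("photo.png").convert("RGB")  # any RGB image

# Point prompting: one object described by two foreground points ([x, y, label]).
masks = pipeline.segment(image=image, point_lists=[[[120, 80, 1], [200, 95, 1]]])

# Box prompting: if bounding_boxes is given, it takes precedence and point_lists is ignored.
masks = pipeline.segment(image=image, bounding_boxes=[[10, 20, 300, 400]])

# masks: torch.bool tensor of shape [num_masks, channels, height, width]
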