feat(nodes): update SAM backend and nodes to work with SAM points

psychedelicious 2024-10-23 13:47:44 +10:00
parent 790846297a
commit ff72315db2
2 changed files with 122 additions and 205 deletions

View File

@@ -1,7 +1,8 @@
from enum import Enum
from pathlib import Path
from typing import Literal, TypedDict
from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
@@ -13,6 +14,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
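For context: these keys are resolved to remote checkpoints via SEGMENT_ANYTHING_MODEL_IDS, which _segment() below indexes by self.model. The mapping itself is outside this diff; a plausible sketch, with the Hugging Face checkpoint IDs assumed rather than confirmed:

# Hypothetical sketch -- the real SEGMENT_ANYTHING_MODEL_IDS is defined outside this diff.
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
    "segment-anything-base": "facebook/sam-vit-base",  # assumed checkpoint ID
    "segment-anything-large": "facebook/sam-vit-large",  # assumed checkpoint ID
    "segment-anything-huge": "facebook/sam-vit-huge",  # assumed checkpoint ID
}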
@@ -35,55 +37,21 @@ class SAMPoint(BaseModel):
    label: SAMPointLabel = Field(..., description="The label of the point")


class SAMObjectIdentifierInputData(TypedDict):
    points: list[tuple[int, int]]
    labels: list[int]
    bounding_box: tuple[int, int, int, int] | None


class SAMPointsField(BaseModel):
    points: list[SAMPoint] = Field(..., description="The points of the object")


class SAMObjectInputKwargs(TypedDict):
    points: list[list[tuple[int, int]]]
    labels: list[list[int]]
    bounding_boxes: list[tuple[int, int, int, int]]


class SAMObjectIdentifierField(BaseModel):
    points: list[SAMPoint] | None = Field(None, description="The points of the feature")
    bounding_box: BoundingBoxField | None = Field(None, description="The bounding box of the feature")
    @model_validator(mode="after")
    def check_points_or_bounding_box(self):
        if self.points is None and self.bounding_box is None:
            raise ValueError("Either points or bounding_box must be provided.")
        return self

    def get_input_data(self) -> SAMObjectIdentifierInputData:
        input_data: SAMObjectIdentifierInputData = {"points": [], "labels": [], "bounding_box": None}

        if self.points is not None:
            for point in self.points:
                input_data["points"].append((point.x, point.y))
                input_data["labels"].append(point.label.value)

        if self.bounding_box is not None:
            input_data["bounding_box"] = (
                self.bounding_box.x_min,
                self.bounding_box.y_min,
                self.bounding_box.x_max,
                self.bounding_box.y_max,
            )

        return input_data

    def to_list(self) -> list[list[int]]:
        return [[point.x, point.y, point.label.value] for point in self.points]
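A minimal sketch of the serialization above: each SAMPoint becomes an [x, y, label] triplet, the format the backend's point_lists argument expects. Constructing the label via SAMPointLabel(1) assumes the enum's int values match the backend docs (-1 background, 0 neutral, 1 foreground):

# Sketch only; assumes SAMPointLabel's int values are -1/0/1 as documented in the backend.
field = SAMPointsField(points=[SAMPoint(x=120, y=80, label=SAMPointLabel(1))])
assert field.to_list() == [[120, 80, 1]]  # one [x, y, label] triplet per point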
@invocation(
    "segment_anything_object_identifier",
    "segment_anything",
    title="Segment Anything",
    tags=["prompt", "segmentation"],
    category="segmentation",
    version="1.0.0",
    version="1.1.0",
)
class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
class SegmentAnythingInvocation(BaseInvocation):
    """Runs a Segment Anything Model."""

    # Reference:
@@ -93,8 +61,12 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
    image: ImageField = InputField(description="The image to segment.")
    object_identifiers: list[SAMObjectIdentifierField] = InputField(
        description="The bounding boxes to prompt the SAM model with."
    bounding_boxes: list[BoundingBoxField] | None = InputField(
        default=None, description="The bounding boxes to prompt the SAM model with."
    )
    point_lists: list[SAMPointsField] | None = InputField(
        default=None,
        description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
    )
    apply_polygon_refinement: bool = InputField(
        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
@@ -105,15 +77,32 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
default="all",
)
@model_validator(mode="after")
def check_point_lists_or_bounding_box(self):
if self.point_lists is None and self.bounding_boxes is None:
raise ValueError("Either point_lists or bounding_box must be provided.")
elif self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
return self
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> MaskOutput:
        # The models expect a 3-channel RGB image.
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        mask = self._segment(context=context, image=image_pil)[0]
        if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
            not self.point_lists or len(self.point_lists) == 0
        ):
            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
        else:
            masks = self._segment(context=context, image=image_pil)
            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)

        mask_tensor_name = context.tensors.save(mask)
        height, width = mask.shape

            # masks contains bool values, so we merge them via max-reduce.
            combined_mask, _ = torch.stack(masks).max(dim=0)

        mask_tensor_name = context.tensors.save(combined_mask)
        height, width = combined_mask.shape
        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
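The max-reduce merge in invoke() is a logical OR across the per-object boolean masks; a self-contained sketch:

import torch

# Two single-object masks of the same [height, width].
a = torch.tensor([[True, False], [False, False]])
b = torch.tensor([[False, True], [False, False]])
# stack -> [num_masks, height, width]; max over dim 0 ORs the masks together.
combined, _ = torch.stack([a, b]).max(dim=0)
assert combined.tolist() == [[True, True], [False, False]]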
    @staticmethod
@@ -133,27 +122,23 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
    def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes or points."""
        input_kwargs: SAMObjectInputKwargs = {"points": [], "labels": [], "bounding_boxes": []}
        for obj_id in self.object_identifiers:
            input_data = obj_id.get_input_data()
            if input_data["points"]:
                input_kwargs["points"].append(input_data["points"])
                input_kwargs["labels"].append(input_data["labels"])
            if input_data["bounding_box"]:
                input_kwargs["bounding_boxes"].append(input_data["bounding_box"])
        # Convert the bounding boxes to the SAM input format.
        sam_bounding_boxes = (
            [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
        )
        sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None

        with (
            context.models.load_remote_model(
                source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
                loader=SegmentAnythingObjectIdentifierInvocation._load_sam_model,
                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
            ) as sam_pipeline,
        ):
            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
            masks = sam_pipeline.segment(image=image, **input_kwargs, multimask_output=False)
            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)

        masks = self._process_masks(masks)
        if self.apply_polygon_refinement:
            masks = self._apply_polygon_refinement(masks)

        return masks
@@ -167,140 +152,51 @@ class SegmentAnythingObjectIdentifierInvocation(BaseInvocation):
        # Split the first dimension into a list of masks.
        return list(masks.cpu().unbind(dim=0))
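Shape walkthrough for _process_masks, which produces the mask list consumed above:

import torch

masks = torch.zeros((3, 4, 8, 8), dtype=torch.bool)  # [num_masks, channels, height, width]
masks, _ = masks.max(dim=1)  # -> [3, 8, 8]: OR over the channel dim
mask_list = list(masks.cpu().unbind(dim=0))  # -> 3 tensors, each of shape [8, 8]
assert len(mask_list) == 3 and mask_list[0].shape == (8, 8)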
    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply polygon refinement to the masks.
# @invocation(
# "segment_anything",
# title="Segment Anything",
# tags=["prompt", "segmentation"],
# category="segmentation",
# version="1.0.0",
# )
# class SegmentAnythingInvocation(BaseInvocation):
# """Runs a Segment Anything Model."""
        Convert each mask to a polygon, then back to a mask. This has the following effects:
        - Smooths the edges of the mask slightly.
        - Ensures that each mask consists of a single closed polygon.
        - Removes small mask pieces.
        - Removes holes from the mask.
        """
        # Convert tensor masks to np masks.
        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
# # Reference:
# # - https://arxiv.org/pdf/2304.02643
# # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
# # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
        # Apply polygon refinement.
        for idx, mask in enumerate(np_masks):
            shape = mask.shape
            assert len(shape) == 2  # Assert length to satisfy type checker.
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            np_masks[idx] = mask
# model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
# image: ImageField = InputField(description="The image to segment.")
# bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
# apply_polygon_refinement: bool = InputField(
# description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
# default=True,
# )
# mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
# description="The filtering to apply to the detected masks before merging them into a final output.",
# default="all",
# )
        # Convert np masks back to tensor masks.
        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
# @torch.no_grad()
# def invoke(self, context: InvocationContext) -> MaskOutput:
# # The models expect a 3-channel RGB image.
# image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
        return masks
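A round-trip sketch of the refinement step, using the mask_to_polygon/polygon_to_mask helpers imported at the top of this file (their exact signatures are inferred from the usage above):

import numpy as np
import torch

mask = torch.zeros((64, 64), dtype=torch.bool)
mask[8:32, 8:32] = True
np_mask = mask.cpu().numpy().astype(np.uint8)  # tensor -> 2D uint8 numpy mask
polygon = mask_to_polygon(np_mask)  # mask -> single closed polygon
refined = polygon_to_mask(polygon, np_mask.shape)  # polygon -> 2D mask
refined_mask = torch.tensor(refined, dtype=torch.bool)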
# if len(self.bounding_boxes) == 0:
# combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
# else:
# masks = self._segment(context=context, image=image_pil)
# masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
    def _filter_masks(
        self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField] | None
    ) -> list[torch.Tensor]:
        """Filter the detected masks based on the specified mask filter."""
# # masks contains bool values, so we merge them via max-reduce.
# combined_mask, _ = torch.stack(masks).max(dim=0)
# mask_tensor_name = context.tensors.save(combined_mask)
# height, width = combined_mask.shape
# return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
# @staticmethod
# def _load_sam_model(model_path: Path):
# sam_model = AutoModelForMaskGeneration.from_pretrained(
# model_path,
# local_files_only=True,
# # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# # model, and figure out how to make it work in the pipeline.
# # torch_dtype=TorchDevice.choose_torch_dtype(),
# )
# assert isinstance(sam_model, SamModel)
# sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
# assert isinstance(sam_processor, SamProcessor)
# return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
# def _segment(
# self,
# context: InvocationContext,
# image: Image.Image,
# ) -> list[torch.Tensor]:
# """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
# # Convert the bounding boxes to the SAM input format.
# sam_bounding_boxes = [(bb.x_min, bb.y_min, bb.x_max, bb.y_max) for bb in self.bounding_boxes]
# with (
# context.models.load_remote_model(
# source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
# ) as sam_pipeline,
# ):
# assert isinstance(sam_pipeline, SegmentAnythingPipeline)
# masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
# masks = self._process_masks(masks)
# if self.apply_polygon_refinement:
# masks = self._apply_polygon_refinement(masks)
# return masks
# def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
# """Convert the tensor output from the Segment Anything model from a tensor of shape
# [num_masks, channels, height, width] to a list of tensors of shape [height, width].
# """
# assert masks.dtype == torch.bool
# # [num_masks, channels, height, width] -> [num_masks, height, width]
# masks, _ = masks.max(dim=1)
# # Split the first dimension into a list of masks.
# return list(masks.cpu().unbind(dim=0))
# def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
# """Apply polygon refinement to the masks.
# Convert each mask to a polygon, then back to a mask. This has the following effect:
# - Smooth the edges of the mask slightly.
# - Ensure that each mask consists of a single closed polygon
# - Removes small mask pieces.
# - Removes holes from the mask.
# """
# # Convert tensor masks to np masks.
# np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
# # Apply polygon refinement.
# for idx, mask in enumerate(np_masks):
# shape = mask.shape
# assert len(shape) == 2 # Assert length to satisfy type checker.
# polygon = mask_to_polygon(mask)
# mask = polygon_to_mask(polygon, shape)
# np_masks[idx] = mask
# # Convert np masks back to tensor masks.
# masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
# return masks
# def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
# """Filter the detected masks based on the specified mask filter."""
# assert len(masks) == len(bounding_boxes)
# if self.mask_filter == "all":
# return masks
# elif self.mask_filter == "largest":
# # Find the largest mask.
# return [max(masks, key=lambda x: float(x.sum()))]
# elif self.mask_filter == "highest_box_score":
# # Find the index of the bounding box with the highest score.
# # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
# # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
# # reasonable fallback since the expected score range is [0.0, 1.0].
# max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
# return [masks[max_score_idx]]
# else:
# raise ValueError(f"Invalid mask filter: {self.mask_filter}")
        if self.mask_filter == "all":
            return masks
        elif self.mask_filter == "largest":
            # Find the largest mask.
            return [max(masks, key=lambda x: float(x.sum()))]
        elif self.mask_filter == "highest_box_score":
            assert (
                bounding_boxes is not None
            ), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
            assert len(masks) == len(bounding_boxes)
            # Find the index of the bounding box with the highest score.
            # Note that we fall back to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
            # reasonable fallback since the expected score range is [0.0, 1.0].
            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
            return [masks[max_score_idx]]
        else:
            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
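The "highest_box_score" branch in isolation (BoundingBoxField fields as used above; a None score falls back to -1.0):

boxes = [
    BoundingBoxField(x_min=0, y_min=0, x_max=16, y_max=16, score=0.42),
    BoundingBoxField(x_min=8, y_min=8, x_max=48, y_max=48, score=0.91),
]
best_idx = max(range(len(boxes)), key=lambda i: boxes[i].score or -1.0)
assert best_idx == 1  # the second box has the highest score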

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, TypeAlias
import torch
from PIL import Image
@@ -7,6 +7,14 @@ from transformers.models.sam.processing_sam import SamProcessor
from invokeai.backend.raw_model import RawModel
# Type aliases for the inputs to the SAM model.
ListOfBoundingBoxes: TypeAlias = list[list[int]]
"""A list of bounding boxes. Each bounding box is in the format [xmin, ymin, xmax, ymax]."""

ListOfPoints: TypeAlias = list[list[int]]
"""A list of points. Each point is in the format [x, y]."""

ListOfPointLabels: TypeAlias = list[int]
"""A list of SAM point labels. Each label is an integer where -1 is background, 0 is neutral, and 1 is foreground."""
class SegmentAnythingPipeline(RawModel):
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
@@ -30,36 +38,49 @@ class SegmentAnythingPipeline(RawModel):
    def segment(
        self,
        image: Image.Image,
        bounding_boxes: list[tuple[int, int, int, int]] | None = None,
        points: list[list[tuple[int, int]]] | None = None,
        labels: list[list[int]] | None = None,
        multimask_output: bool = True,
        bounding_boxes: list[list[int]] | None = None,
        point_lists: list[list[list[int]]] | None = None,
    ) -> torch.Tensor:
"""Run the SAM model.
Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
point_lists will be ignored.
Args:
image (Image.Image): The image to segment.
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
[xmin, ymin, xmax, ymax].
point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
`label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
Returns:
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
"""
        input_kwargs = {}

        # Prep the inputs:
        # - Create a list of bounding boxes or points and labels.
        # - Add a batch dimension of 1 to the inputs.
        if bounding_boxes:
            # Add batch dimension of 1 to the inputs.
            input_kwargs["input_boxes"] = [bounding_boxes]
        if points and labels and len(points) == len(labels):
            # Add batch dimension of 1 to the inputs.
            input_kwargs["input_points"] = [points]
            input_kwargs["input_labels"] = [labels]
            input_boxes: list[ListOfBoundingBoxes] | None = [bounding_boxes]
            input_points: list[ListOfPoints] | None = None
            input_labels: list[ListOfPointLabels] | None = None
        elif point_lists:
            input_boxes: list[ListOfBoundingBoxes] | None = None
            input_points: list[ListOfPoints] | None = []
            input_labels: list[ListOfPointLabels] | None = []
            for point_list in point_lists:
                input_points.append([[p[0], p[1]] for p in point_list])
                input_labels.append([p[2] for p in point_list])
        else:
            raise ValueError("Either bounding_boxes or point_lists must be provided.")
        inputs = self._sam_processor(
            images=image,
            **input_kwargs,
            input_boxes=input_boxes,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
            multimask_output=multimask_output,
        ).to(self._sam_model.device)
        outputs = self._sam_model(**inputs)
        masks = self._sam_processor.post_process_masks(