diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py index 826112156d..9176bf1f49 100644 --- a/invokeai/backend/ip_adapter/ip_adapter.py +++ b/invokeai/backend/ip_adapter/ip_adapter.py @@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module): return clip_extra_context_tokens +class MLPProjModel(torch.nn.Module): + """MLP that projects CLIP image embeddings to the cross-attention dimension.""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim), + ) + + @classmethod + def from_state_dict(cls, state_dict: dict[torch.Tensor]): + """Initialize an MLPProjModel from a state_dict. + + The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict. + + Args: + state_dict (dict[torch.Tensor]): The state_dict of model weights. 
+ + Returns: + MLPProjModel + """ + cross_attention_dim = state_dict["proj.3.weight"].shape[0] + clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] + + model = cls(cross_attention_dim, clip_embeddings_dim) + + model.load_state_dict(state_dict) + return model + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + class IPAdapter: """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf""" @@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter): return image_prompt_embeds, uncond_image_prompt_embeds +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter Plus with full features.""" + + def _init_image_proj_model(self, state_dict: dict[torch.Tensor]): + return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype) + + class IPAdapterPlusXL(IPAdapterPlus): """IP-Adapter Plus for SDXL.""" @@ -149,11 +194,9 @@ def build_ip_adapter( ) -> Union[IPAdapter, IPAdapterPlus]: state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu") - # Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it - # contains. - is_plus = "proj.weight" not in state_dict["image_proj"] - - if is_plus: + if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel). + return IPAdapter(state_dict, device=device, dtype=dtype) + elif "proj_in.weight" in state_dict["image_proj"]: # IPAdapterPlus or IPAdapterPlusXL (with Resampler). cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] if cross_attention_dim == 768: # SD1 IP-Adapter Plus @@ -163,5 +206,7 @@ def build_ip_adapter( return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) else: raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.") + elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel). 
+ return IPAdapterFull(state_dict, device=device, dtype=dtype) else: - return IPAdapter(state_dict, device=device, dtype=dtype) + raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.") diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 6712196778..6a3ec510a2 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -37,6 +37,14 @@ def build_dummy_sd15_unet_input(torch_device): "unet_model_id": "runwayml/stable-diffusion-v1-5", "unet_model_name": "stable-diffusion-v1-5", }, + # SD1.5, IPAdapterFull + { + "ip_adapter_model_id": "InvokeAI/ip-adapter-full-face_sd15", + "ip_adapter_model_name": "ip-adapter-full-face_sd15", + "base_model": BaseModelType.StableDiffusion1, + "unet_model_id": "runwayml/stable-diffusion-v1-5", + "unet_model_name": "stable-diffusion-v1-5", + }, ], ) @pytest.mark.slow