diff --git a/invokeai/backend/flux/ip_adapter/xlabs_ip_adapter_flux.py b/invokeai/backend/flux/ip_adapter/xlabs_ip_adapter_flux.py
new file mode 100644
index 0000000000..63fd121221
--- /dev/null
+++ b/invokeai/backend/flux/ip_adapter/xlabs_ip_adapter_flux.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+
+import torch
+
+from invokeai.backend.ip_adapter.ip_adapter import ImageProjModel
+
+
+class IPDoubleStreamBlock(torch.nn.Module):
+    """Extra K/V projection layers that an XLabs IP-Adapter attaches to one FLUX double-stream block."""
+
+    def __init__(self, context_dim: int, hidden_dim: int):
+        super().__init__()
+
+        self.context_dim = context_dim
+        self.hidden_dim = hidden_dim
+
+        self.ip_adapter_double_stream_k_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
+        self.ip_adapter_double_stream_v_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
+
+
+class XlabsIpAdapterFlux:
+    """An XLabs FLUX IP-Adapter: an image projection model plus per-double-block adapter layers."""
+
+    def __init__(self, image_proj: ImageProjModel, double_blocks: list[IPDoubleStreamBlock]):
+        self.image_proj = image_proj
+        self.double_blocks = double_blocks
+
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "XlabsIpAdapterFlux":
+        # TODO: construct image_proj and double_blocks from state_dict (see
+        # infer_xlabs_ip_adapter_params_from_state_dict for the inferred hyperparameters).
+        # Raise explicitly rather than `return cls()`, which would crash with a confusing
+        # TypeError about missing __init__ arguments.
+        raise NotImplementedError("XlabsIpAdapterFlux.from_state_dict is not implemented yet.")
+
+
+@dataclass
+class XlabsIpAdapterParams:
+    """Hyperparameters inferred from an XLabs FLUX IP-Adapter state dict."""
+
+    num_double_blocks: int
+    context_dim: int
+    hidden_dim: int
+
+    clip_embeddings_dim: int
+
+
+def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
+    """Infer model hyperparameters from the keys and tensor shapes of an XLabs IP-Adapter state dict."""
+    # Count the number of double blocks by probing consecutive per-block keys.
+    num_double_blocks = 0
+    while f"double_blocks.{num_double_blocks}.processor.ip_adapter_double_stream_k_proj.weight" in state_dict:
+        num_double_blocks += 1
+
+    # torch.nn.Linear stores its weight as (out_features, in_features).
+    hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
+    context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
+    clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
+
+    return XlabsIpAdapterParams(
+        num_double_blocks=num_double_blocks,
+        context_dim=context_dim,
+        hidden_dim=hidden_dim,
+        clip_embeddings_dim=clip_embeddings_dim,
+    )
diff --git a/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py b/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py
new file mode 100644
index 0000000000..a4ca8180d0
--- /dev/null
+++ b/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py
@@ -0,0 +1,17 @@
+import torch
+
+from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import infer_xlabs_ip_adapter_params_from_state_dict
+from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_sd_shapes
+
+
+def test_infer_xlabs_ip_adapter_params_from_state_dict():
+    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
+    with torch.device("meta"):
+        sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
+
+    params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
+
+    assert params.num_double_blocks == 19
+    assert params.context_dim == 4096
+    assert params.hidden_dim == 3072
+    assert params.clip_embeddings_dim == 768