Add initial logic for inferring FLUX IP-Adapter params from a state_dict.

This commit is contained in:
Ryan Dick 2024-10-11 17:15:10 +00:00
parent ac7441e606
commit 95c30f6a8b
2 changed files with 77 additions and 0 deletions

View File

@ -0,0 +1,60 @@
from dataclasses import dataclass
import torch
from invokeai.backend.ip_adapter.ip_adapter import ImageProjModel
class IPDoubleStreamBlock(torch.nn.Module):
def __init__(self, context_dim: int, hidden_dim: int):
super().__init__()
self.context_dim = context_dim
self.hidden_dim = hidden_dim
self.ip_adapter_double_stream_k_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
self.ip_adapter_double_stream_v_proj = torch.nn.Linear(context_dim, hidden_dim, bias=True)
class XlabsIpAdapterFlux:
def __init__(self, image_proj: ImageProjModel, double_blocks: list[IPDoubleStreamBlock]):
self.image_proj = image_proj
self.double_blocks = double_blocks
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "XlabsIpAdapterFlux":
# TODO
return cls()
@dataclass
class XlabsIpAdapterParams:
num_double_blocks: int
context_dim: int
hidden_dim: int
clip_embeddings_dim: int
def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams:
num_double_blocks = 0
context_dim = 0
hidden_dim = 0
# Count the number of double blocks.
double_block_index = 0
while f"double_blocks.{double_block_index}.processor.ip_adapter_double_stream_k_proj.weight" in state_dict:
double_block_index += 1
num_double_blocks = double_block_index
hidden_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[0]
context_dim = state_dict["double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight"].shape[1]
clip_embeddings_dim = state_dict["ip_adapter_proj_model.proj.weight"].shape[1]
return XlabsIpAdapterParams(
num_double_blocks=num_double_blocks,
context_dim=context_dim,
hidden_dim=hidden_dim,
clip_embeddings_dim=clip_embeddings_dim,
)

View File

@ -0,0 +1,17 @@
import torch
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import infer_xlabs_ip_adapter_params_from_state_dict
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_sd_shapes
def test_infer_xlabs_ip_adapter_params_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
assert params.num_double_blocks == 19
assert params.context_dim == 4096
assert params.hidden_dim == 3072
assert params.clip_embeddings_dim == 768