mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-21 14:15:53 +02:00
Add unit test for FLUX XLabs IP-Adapter V2 model format.
This commit is contained in:
parent
8bd4207a27
commit
9a77e951d2
@ -10,36 +10,51 @@ from invokeai.backend.flux.ip_adapter.state_dict_utils import (
|
||||
)
|
||||
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
|
||||
XlabsIpAdapterFlux,
|
||||
XlabsIpAdapterParams,
|
||||
)
|
||||
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_sd_shapes
|
||||
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_flux_ip_adapter_sd_shapes
|
||||
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_v2_state_dict import xlabs_flux_ip_adapter_v2_sd_shapes
|
||||
|
||||
|
||||
def test_is_state_dict_xlabs_ip_adapter():
|
||||
@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes])
|
||||
def test_is_state_dict_xlabs_ip_adapter(sd_shapes: dict[str, list[int]]):
|
||||
# Construct a dummy state_dict.
|
||||
sd = {k: None for k in xlabs_sd_shapes}
|
||||
sd = {k: None for k in sd_shapes}
|
||||
|
||||
assert is_state_dict_xlabs_ip_adapter(sd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
|
||||
def test_infer_xlabs_ip_adapter_params_from_state_dict():
|
||||
@pytest.mark.parametrize(
|
||||
["sd_shapes", "expected_params"],
|
||||
[
|
||||
(
|
||||
xlabs_flux_ip_adapter_sd_shapes,
|
||||
XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768),
|
||||
),
|
||||
(
|
||||
xlabs_flux_ip_adapter_v2_sd_shapes,
|
||||
XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_infer_xlabs_ip_adapter_params_from_state_dict(
|
||||
sd_shapes: dict[str, list[int]], expected_params: XlabsIpAdapterParams
|
||||
):
|
||||
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
|
||||
with torch.device("meta"):
|
||||
sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
|
||||
sd = {k: torch.zeros(v) for k, v in sd_shapes.items()}
|
||||
|
||||
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
|
||||
|
||||
assert params.num_double_blocks == 19
|
||||
assert params.context_dim == 4096
|
||||
assert params.hidden_dim == 3072
|
||||
assert params.clip_embeddings_dim == 768
|
||||
assert params == expected_params
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
|
||||
def test_initialize_xlabs_ip_adapter_flux_from_state_dict():
|
||||
@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes])
|
||||
def test_initialize_xlabs_ip_adapter_flux_from_state_dict(sd_shapes: dict[str, list[int]]):
|
||||
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
|
||||
with torch.device("meta"):
|
||||
sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
|
||||
sd = {k: torch.zeros(v) for k, v in sd_shapes.items()}
|
||||
|
||||
# Initialize the XLabs IP-Adapter from the state_dict.
|
||||
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests.
|
||||
# These keys were extracted from:
|
||||
# https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors
|
||||
xlabs_sd_shapes = {
|
||||
xlabs_flux_ip_adapter_sd_shapes = {
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
# State dict keys and shapes for an XLabs FLUX IP-Adapter V2 model. Intended to be used for unit tests.
|
||||
# These keys were extracted from:
|
||||
# https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/blob/main/ip_adapter.safetensors
|
||||
xlabs_flux_ip_adapter_v2_sd_shapes = {
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.1.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.1.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.1.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.1.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.10.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.10.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.10.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.10.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.11.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.11.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.11.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.11.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.12.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.12.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.12.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.12.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.13.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.13.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.13.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.13.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.14.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.14.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.14.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.14.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.15.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.15.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.15.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.15.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.16.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.16.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.16.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.16.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.17.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.17.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.17.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.17.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.18.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.18.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.18.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.18.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.2.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.2.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.2.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.2.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.3.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.3.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.3.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.3.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.4.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.4.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.4.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.4.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.5.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.5.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.5.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.5.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.6.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.6.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.6.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.6.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.7.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.7.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.7.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.7.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.8.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.8.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.8.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.8.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"double_blocks.9.processor.ip_adapter_double_stream_k_proj.bias": [3072],
|
||||
"double_blocks.9.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
|
||||
"double_blocks.9.processor.ip_adapter_double_stream_v_proj.bias": [3072],
|
||||
"double_blocks.9.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096],
|
||||
"ip_adapter_proj_model.norm.bias": [4096],
|
||||
"ip_adapter_proj_model.norm.weight": [4096],
|
||||
"ip_adapter_proj_model.proj.bias": [65536],
|
||||
"ip_adapter_proj_model.proj.weight": [65536, 768],
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user