Add unit test for FLUX XLabs IP-Adapter V2 model format.

This commit is contained in:
Ryan Dick 2024-11-15 19:59:01 +00:00 committed by psychedelicious
parent 8bd4207a27
commit 9a77e951d2
3 changed files with 113 additions and 13 deletions

View File

@@ -10,36 +10,51 @@ from invokeai.backend.flux.ip_adapter.state_dict_utils import (
)
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import (
XlabsIpAdapterFlux,
XlabsIpAdapterParams,
)
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_sd_shapes
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_flux_ip_adapter_sd_shapes
from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_v2_state_dict import xlabs_flux_ip_adapter_v2_sd_shapes
def test_is_state_dict_xlabs_ip_adapter():
@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes])
def test_is_state_dict_xlabs_ip_adapter(sd_shapes: dict[str, list[int]]):
# Construct a dummy state_dict.
sd = {k: None for k in xlabs_sd_shapes}
sd = {k: None for k in sd_shapes}
assert is_state_dict_xlabs_ip_adapter(sd)
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_xlabs_ip_adapter_params_from_state_dict():
@pytest.mark.parametrize(
["sd_shapes", "expected_params"],
[
(
xlabs_flux_ip_adapter_sd_shapes,
XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768),
),
(
xlabs_flux_ip_adapter_v2_sd_shapes,
XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768),
),
],
)
def test_infer_xlabs_ip_adapter_params_from_state_dict(
sd_shapes: dict[str, list[int]], expected_params: XlabsIpAdapterParams
):
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
sd = {k: torch.zeros(v) for k, v in sd_shapes.items()}
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)
assert params.num_double_blocks == 19
assert params.context_dim == 4096
assert params.hidden_dim == 3072
assert params.clip_embeddings_dim == 768
assert params == expected_params
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_initialize_xlabs_ip_adapter_flux_from_state_dict():
@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes])
def test_initialize_xlabs_ip_adapter_flux_from_state_dict(sd_shapes: dict[str, list[int]]):
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}
sd = {k: torch.zeros(v) for k, v in sd_shapes.items()}
# Initialize the XLabs IP-Adapter from the state_dict.
params = infer_xlabs_ip_adapter_params_from_state_dict(sd)

View File

@@ -1,7 +1,7 @@
# State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors
xlabs_sd_shapes = {
xlabs_flux_ip_adapter_sd_shapes = {
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],
"double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias": [3072],

View File

@@ -0,0 +1,85 @@
# Key names and tensor shapes of an XLabs FLUX IP-Adapter V2 state dict. Used as a fixture by unit tests.
# The keys mirror those extracted from:
# https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/blob/main/ip_adapter.safetensors


def _double_block_sd_shapes() -> dict[str, list[int]]:
    """Return the k/v projection bias and weight entries for double blocks 0 through 18."""
    shapes: dict[str, list[int]] = {}
    # Sorting the indices as strings yields the lexicographic key order
    # 0, 1, 10, 11, ..., 18, 2, 3, ..., 9.
    for block in sorted(str(index) for index in range(19)):
        for proj in ("k", "v"):
            prefix = f"double_blocks.{block}.processor.ip_adapter_double_stream_{proj}_proj"
            shapes[f"{prefix}.bias"] = [3072]
            shapes[f"{prefix}.weight"] = [3072, 4096]
    return shapes


xlabs_flux_ip_adapter_v2_sd_shapes = {
    **_double_block_sd_shapes(),
    # Image-embedding projection model entries.
    "ip_adapter_proj_model.norm.bias": [4096],
    "ip_adapter_proj_model.norm.weight": [4096],
    "ip_adapter_proj_model.proj.bias": [65536],
    "ip_adapter_proj_model.proj.weight": [65536, 768],
}