From 9a77e951d2f622e08ed5bd5853de96d99bb4b778 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 15 Nov 2024 19:59:01 +0000 Subject: [PATCH] Add unit test for FLUX XLabs IP-Adapter V2 model format. --- .../ip_adapter/test_xlabs_ip_adapter_flux.py | 39 ++++++--- .../xlabs_flux_ip_adapter_state_dict.py | 2 +- .../xlabs_flux_ip_adapter_v2_state_dict.py | 85 +++++++++++++++++++ 3 files changed, 113 insertions(+), 13 deletions(-) create mode 100644 tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_v2_state_dict.py diff --git a/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py b/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py index 93012684b7..6b75fbe4ea 100644 --- a/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py +++ b/tests/backend/flux/ip_adapter/test_xlabs_ip_adapter_flux.py @@ -10,36 +10,51 @@ from invokeai.backend.flux.ip_adapter.state_dict_utils import ( ) from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import ( XlabsIpAdapterFlux, + XlabsIpAdapterParams, ) -from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_sd_shapes +from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_state_dict import xlabs_flux_ip_adapter_sd_shapes +from tests.backend.flux.ip_adapter.xlabs_flux_ip_adapter_v2_state_dict import xlabs_flux_ip_adapter_v2_sd_shapes -def test_is_state_dict_xlabs_ip_adapter(): +@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes]) +def test_is_state_dict_xlabs_ip_adapter(sd_shapes: dict[str, list[int]]): # Construct a dummy state_dict. - sd = {k: None for k in xlabs_sd_shapes} + sd = {k: None for k in sd_shapes} assert is_state_dict_xlabs_ip_adapter(sd) @pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS") -def test_infer_xlabs_ip_adapter_params_from_state_dict(): +@pytest.mark.parametrize( + ["sd_shapes", "expected_params"], + [ + ( + xlabs_flux_ip_adapter_sd_shapes, + XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768), + ), + ( + xlabs_flux_ip_adapter_v2_sd_shapes, + XlabsIpAdapterParams(num_double_blocks=19, context_dim=4096, hidden_dim=3072, clip_embeddings_dim=768), + ), + ], +) +def test_infer_xlabs_ip_adapter_params_from_state_dict( + sd_shapes: dict[str, list[int]], expected_params: XlabsIpAdapterParams +): # Construct a dummy state_dict with tensors of the correct shape on the meta device. with torch.device("meta"): - sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()} + sd = {k: torch.zeros(v) for k, v in sd_shapes.items()} params = infer_xlabs_ip_adapter_params_from_state_dict(sd) - - assert params.num_double_blocks == 19 - assert params.context_dim == 4096 - assert params.hidden_dim == 3072 - assert params.clip_embeddings_dim == 768 + assert params == expected_params @pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS") -def test_initialize_xlabs_ip_adapter_flux_from_state_dict(): +@pytest.mark.parametrize("sd_shapes", [xlabs_flux_ip_adapter_sd_shapes, xlabs_flux_ip_adapter_v2_sd_shapes]) +def test_initialize_xlabs_ip_adapter_flux_from_state_dict(sd_shapes: dict[str, list[int]]): # Construct a dummy state_dict with tensors of the correct shape on the meta device. with torch.device("meta"): - sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()} + sd = {k: torch.zeros(v) for k, v in sd_shapes.items()} # Initialize the XLabs IP-Adapter from the state_dict. params = infer_xlabs_ip_adapter_params_from_state_dict(sd) diff --git a/tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_state_dict.py b/tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_state_dict.py index d0d2d550af..6406b9c364 100644 --- a/tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_state_dict.py +++ b/tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_state_dict.py @@ -1,7 +1,7 @@ # State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests. # These keys were extracted from: # https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors -xlabs_sd_shapes = { +xlabs_flux_ip_adapter_sd_shapes = { "double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072], "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], "double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias": [3072], diff --git a/tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_v2_state_dict.py b/tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_v2_state_dict.py new file mode 100644 index 0000000000..59a33c8cb1 --- /dev/null +++ b/tests/backend/flux/ip_adapter/xlabs_flux_ip_adapter_v2_state_dict.py @@ -0,0 +1,85 @@ +# State dict keys and shapes for an XLabs FLUX IP-Adapter V2 model. Intended to be used for unit tests. +# These keys were extracted from: +# https://huggingface.co/XLabs-AI/flux-ip-adapter-v2/blob/main/ip_adapter.safetensors +xlabs_flux_ip_adapter_v2_sd_shapes = { + "double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.0.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.0.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.1.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.1.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.1.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.1.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.10.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.10.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.10.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.10.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.11.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.11.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.11.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.11.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.12.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.12.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.12.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.12.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.13.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.13.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.13.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.13.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.14.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.14.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.14.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.14.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.15.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.15.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.15.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.15.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.16.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.16.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.16.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.16.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.17.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.17.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.17.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.17.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.18.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.18.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.18.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.18.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.2.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.2.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.2.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.2.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.3.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.3.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.3.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.3.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.4.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.4.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.4.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.4.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.5.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.5.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.5.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.5.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.6.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.6.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.6.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.6.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.7.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.7.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.7.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.7.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.8.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.8.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.8.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.8.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "double_blocks.9.processor.ip_adapter_double_stream_k_proj.bias": [3072], + "double_blocks.9.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096], + "double_blocks.9.processor.ip_adapter_double_stream_v_proj.bias": [3072], + "double_blocks.9.processor.ip_adapter_double_stream_v_proj.weight": [3072, 4096], + "ip_adapter_proj_model.norm.bias": [4096], + "ip_adapter_proj_model.norm.weight": [4096], + "ip_adapter_proj_model.proj.bias": [65536], + "ip_adapter_proj_model.proj.weight": [65536, 768], +}