From 745b6dbd5dc34e91a06fa827a56544eb07066a90 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 7 Oct 2024 17:38:03 +0000 Subject: [PATCH] Add unit test for infer_instantx_num_control_modes_from_state_dict(). --- .../backend/flux/controlnet/test_state_dict_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/backend/flux/controlnet/test_state_dict_utils.py b/tests/backend/flux/controlnet/test_state_dict_utils.py index f5a37e3054..8dfbfa18e5 100644 --- a/tests/backend/flux/controlnet/test_state_dict_utils.py +++ b/tests/backend/flux/controlnet/test_state_dict_utils.py @@ -4,6 +4,7 @@ import torch from invokeai.backend.flux.controlnet.state_dict_utils import ( convert_diffusers_instantx_state_dict_to_bfl_format, infer_flux_params_from_state_dict, + infer_instantx_num_control_modes_from_state_dict, is_state_dict_instantx_controlnet, is_state_dict_xlabs_controlnet, ) @@ -64,3 +65,14 @@ def test_infer_flux_params_from_state_dict(): assert flux_params.theta == 10000 assert flux_params.qkv_bias assert flux_params.guidance_embed == instantx_config["guidance_embeds"] + + +def test_infer_instantx_num_control_modes_from_state_dict(): + # Construct a dummy state_dict with tensor of the correct shape on the meta device. + with torch.device("meta"): + sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()} + + sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd) + num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd) + + assert num_control_modes == instantx_config["num_mode"]