mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-21 06:06:02 +02:00
Add unit test for infer_instantx_num_control_modes_from_state_dict().
This commit is contained in:
parent
c7628945c4
commit
745b6dbd5d
@ -4,6 +4,7 @@ import torch
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
convert_diffusers_instantx_state_dict_to_bfl_format,
|
||||
infer_flux_params_from_state_dict,
|
||||
infer_instantx_num_control_modes_from_state_dict,
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
@ -64,3 +65,14 @@ def test_infer_flux_params_from_state_dict():
|
||||
assert flux_params.theta == 10000
|
||||
assert flux_params.qkv_bias
|
||||
assert flux_params.guidance_embed == instantx_config["guidance_embeds"]
|
||||
|
||||
|
||||
def test_infer_instantx_num_control_modes_from_state_dict():
|
||||
# Construct a dummy state_dict with tensor of the correct shape on the meta device.
|
||||
with torch.device("meta"):
|
||||
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
|
||||
|
||||
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
|
||||
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
|
||||
|
||||
assert num_control_modes == instantx_config["num_mode"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user