Add utils for detecting XLabs ControlNet vs. InstantX ControlNet from

state dict.
This commit is contained in:
Ryan Dick 2024-10-04 19:19:56 +00:00
parent 1751c380db
commit 4be3a33744
4 changed files with 77 additions and 2 deletions

View File

@ -0,0 +1,41 @@
from typing import Any, Dict
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an XLabs ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an XLabs ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"input_hint_block.0.bias",
"input_hint_block.0.weight",
"pos_embed_input.bias",
"pos_embed_input.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an InstantX ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an InstantX ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"controlnet_single_blocks.0.bias",
"controlnet_single_blocks.0.weight",
"controlnet_x_embedder.bias",
"controlnet_x_embedder.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False

View File

@ -1,7 +1,7 @@
# State dict keys for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors
state_dict_keys = [
instantx_state_dict_keys = [
"context_embedder.bias",
"context_embedder.weight",
"controlnet_blocks.0.bias",

View File

@ -0,0 +1,34 @@
import pytest
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_state_dict_keys
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_state_dict_keys
@pytest.mark.parametrize(
["sd_keys", "expected"],
[
(xlabs_state_dict_keys, True),
(instantx_state_dict_keys, False),
(["foo"], False),
],
)
def test_is_state_dict_xlabs_controlnet(sd_keys: list[str], expected: bool):
sd = {k: None for k in sd_keys}
assert is_state_dict_xlabs_controlnet(sd) == expected
@pytest.mark.parametrize(
["sd_keys", "expected"],
[
(instantx_state_dict_keys, True),
(xlabs_state_dict_keys, False),
(["foo"], False),
],
)
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
sd = {k: None for k in sd_keys}
assert is_state_dict_instantx_controlnet(sd) == expected

View File

@ -1,7 +1,7 @@
# State dict keys for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
state_dict_keys = [
xlabs_state_dict_keys = [
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"controlnet_blocks.1.bias",