mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-21 22:21:28 +02:00
Add utils for detecting XLabs ControlNet vs. InstantX ControlNet from
state dict.
This commit is contained in:
parent
1751c380db
commit
4be3a33744
41
invokeai/backend/flux/controlnet/state_dict_utils.py
Normal file
41
invokeai/backend/flux/controlnet/state_dict_utils.py
Normal file
@ -0,0 +1,41 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
|
||||
"""Is the state dict for an XLabs ControlNet model?
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
|
||||
"""
|
||||
# If all of the expected keys are present, then this is very likely an XLabs ControlNet model.
|
||||
expected_keys = {
|
||||
"controlnet_blocks.0.bias",
|
||||
"controlnet_blocks.0.weight",
|
||||
"input_hint_block.0.bias",
|
||||
"input_hint_block.0.weight",
|
||||
"pos_embed_input.bias",
|
||||
"pos_embed_input.weight",
|
||||
}
|
||||
|
||||
if expected_keys.issubset(sd.keys()):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
|
||||
"""Is the state dict for an InstantX ControlNet model?
|
||||
|
||||
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
|
||||
"""
|
||||
# If all of the expected keys are present, then this is very likely an InstantX ControlNet model.
|
||||
expected_keys = {
|
||||
"controlnet_blocks.0.bias",
|
||||
"controlnet_blocks.0.weight",
|
||||
"controlnet_single_blocks.0.bias",
|
||||
"controlnet_single_blocks.0.weight",
|
||||
"controlnet_x_embedder.bias",
|
||||
"controlnet_x_embedder.weight",
|
||||
}
|
||||
|
||||
if expected_keys.issubset(sd.keys()):
|
||||
return True
|
||||
return False
|
||||
@ -1,7 +1,7 @@
|
||||
# State dict keys for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests.
|
||||
# These keys were extracted from:
|
||||
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors
|
||||
state_dict_keys = [
|
||||
instantx_state_dict_keys = [
|
||||
"context_embedder.bias",
|
||||
"context_embedder.weight",
|
||||
"controlnet_blocks.0.bias",
|
||||
|
||||
34
tests/backend/flux/controlnet/test_state_dict_utils.py
Normal file
34
tests/backend/flux/controlnet/test_state_dict_utils.py
Normal file
@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_state_dict_keys
|
||||
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_state_dict_keys
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["sd_keys", "expected"],
|
||||
[
|
||||
(xlabs_state_dict_keys, True),
|
||||
(instantx_state_dict_keys, False),
|
||||
(["foo"], False),
|
||||
],
|
||||
)
|
||||
def test_is_state_dict_xlabs_controlnet(sd_keys: list[str], expected: bool):
|
||||
sd = {k: None for k in sd_keys}
|
||||
assert is_state_dict_xlabs_controlnet(sd) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["sd_keys", "expected"],
|
||||
[
|
||||
(instantx_state_dict_keys, True),
|
||||
(xlabs_state_dict_keys, False),
|
||||
(["foo"], False),
|
||||
],
|
||||
)
|
||||
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
|
||||
sd = {k: None for k in sd_keys}
|
||||
assert is_state_dict_instantx_controlnet(sd) == expected
|
||||
@ -1,7 +1,7 @@
|
||||
# State dict keys for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
|
||||
# These keys were extracted from:
|
||||
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
|
||||
state_dict_keys = [
|
||||
xlabs_state_dict_keys = [
|
||||
"controlnet_blocks.0.bias",
|
||||
"controlnet_blocks.0.weight",
|
||||
"controlnet_blocks.1.bias",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user