From da213e4638c3ebf657e84e570b02709f46794d57 Mon Sep 17 00:00:00 2001 From: Mary Hipp Date: Thu, 12 Dec 2024 21:07:16 -0500 Subject: [PATCH] feat(ui): add control loras to control adapter model options, add default settings for preprocessor in probe --- invokeai/backend/model_manager/config.py | 11 ++++++----- invokeai/backend/model_manager/probe.py | 6 +++--- .../ControlLayer/ControlLayerControlAdapter.tsx | 4 ++-- .../ControlLayer/ControlLayerControlAdapterModel.tsx | 10 +++++----- .../src/features/controlLayers/store/canvasSlice.ts | 4 ++-- .../web/src/features/controlLayers/store/filters.ts | 4 ++-- .../web/src/features/controlLayers/store/types.ts | 8 +++++++- .../web/src/services/api/hooks/modelsByType.ts | 4 ++-- invokeai/frontend/web/src/services/api/schema.ts | 2 ++ invokeai/frontend/web/src/services/api/types.ts | 4 ++++ tests/test_model_probe.py | 10 +++++----- 11 files changed, 40 insertions(+), 27 deletions(-) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index c5eadfa037..65255c9f18 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -273,8 +273,12 @@ class LoRALyCORISConfig(LoRAConfigBase): def get_tag() -> Tag: return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}") +class ControlAdapterConfigBase(BaseModel): + default_settings: Optional[ControlAdapterDefaultSettings] = Field( + description="Default settings for this model", default=None + ) -class ControlLoRALyCORISConfig(ModelConfigBase): +class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase): """Model config for Control LoRA models.""" type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa @@ -317,10 +321,7 @@ class VAEDiffusersConfig(ModelConfigBase): return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}") -class ControlAdapterConfigBase(BaseModel): - default_settings: Optional[ControlAdapterDefaultSettings] = Field( - description="Default settings for this model", default=None - ) + class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase): diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index a8d4b6a270..268c94f410 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -200,8 +200,8 @@ class ModelProbe(object): fields["default_settings"] = fields.get("default_settings") if not fields["default_settings"]: - if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}: - fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"]) + if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}: + fields["default_settings"] = get_default_settings_control_adapters(fields["name"]) elif fields["type"] is ModelType.Main: fields["default_settings"] = get_default_settings_main(fields["base"]) @@ -510,7 +510,7 @@ MODEL_NAME_TO_PREPROCESSOR = { } -def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]: +def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]: for k, v in MODEL_NAME_TO_PREPROCESSOR.items(): model_name_lower = model_name.lower() if k in model_name_lower: diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx index 3bc79602d5..046567596b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx @@ -26,7 +26,7 @@ import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actio import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi'; -import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types'; +import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types'; const buildSelectControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => createMemoizedAppSelector(selectCanvasSlice, (canvas) => { @@ -66,7 +66,7 @@ export const ControlLayerControlAdapter = memo(() => { ); const onChangeModel = useCallback( - (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => { + (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => { dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig })); // When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the // filter config. diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel.tsx index aaed9031ed..2cc8662dbb 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel.tsx @@ -4,22 +4,22 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { selectBase } from 'features/controlLayers/store/paramsSlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType'; -import type { AnyModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; +import { useControlLayerModels } from 'services/api/hooks/modelsByType'; +import type { AnyModelConfig, ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; type Props = { modelKey: string | null; - onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void; + onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => void; }; export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onChangeModel }: Props) => { const { t } = useTranslation(); const currentBaseModel = useAppSelector(selectBase); - const [modelConfigs, { isLoading }] = useControlNetAndT2IAdapterModels(); + const [modelConfigs, { isLoading }] = useControlLayerModels(); const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]); const _onChange = useCallback( - (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => { + (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => { if (!modelConfig) { return; } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 500bc92dcf..93f1108fe6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -34,7 +34,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par import type { IRect } from 'konva/lib/types'; import { merge } from 'lodash-es'; import type { UndoableOptions } from 'redux-undo'; -import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; +import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import type { @@ -436,7 +436,7 @@ export const canvasSlice = createSlice({ action: PayloadAction< EntityIdentifierPayload< { - modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null; + modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null; }, 'control_layer' > diff --git a/invokeai/frontend/web/src/features/controlLayers/store/filters.ts b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts index b3d186197b..d6560d453e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/filters.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/filters.ts @@ -2,7 +2,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { ImageWithDims } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; -import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; +import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; import { z } from 'zod'; @@ -454,7 +454,7 @@ const PROCESSOR_TO_FILTER_MAP: Record = { * Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default * filter for the model type will be used. */ -export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => { +export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => { if (!modelConfig) { // No model return null; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 45665b262b..a04637b3af 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -296,6 +296,12 @@ const zT2IAdapterConfig = z.object({ }); export type T2IAdapterConfig = z.infer; +const zControlLoRAConfig = z.object({ + type: z.literal('control_lora'), + model: zServerValidatedModelIdentifierField.nullable(), +}); +export type ControlLoRAConfig = z.infer; + export const zCanvasRasterLayerState = zCanvasEntityBase.extend({ type: z.literal('raster_layer'), position: zCoordinate, @@ -307,7 +313,7 @@ export type CanvasRasterLayerState = z.infer; const zCanvasControlLayerState = zCanvasRasterLayerState.extend({ type: z.literal('control_layer'), withTransparencyEffect: z.boolean(), - controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]), + controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig, zControlLoRAConfig]), }); export type CanvasControlLayerState = z.infer; diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 6d0e2ead99..2f5b72593f 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -13,7 +13,7 @@ import { isCLIPVisionModelConfig, isControlLoRAModelConfig, isControlNetModelConfig, - isControlNetOrT2IAdapterModelConfig, + isControlLayerModelConfig, isFluxMainModelModelConfig, isFluxVAEModelConfig, isIPAdapterModelConfig, @@ -60,7 +60,7 @@ export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig); export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig); export const useLoRAModels = buildModelsHook(isLoRAModelConfig); export const useControlLoRAModel = buildModelsHook(isControlLoRAModelConfig); -export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig); +export const useControlLayerModels = buildModelsHook(isControlLayerModelConfig); export const useControlNetModels = buildModelsHook(isControlNetModelConfig); export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig); export const useT5EncoderModels = (args?: ModelHookArgs) => diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index e927b58adc..37e5db14e6 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -4359,6 +4359,8 @@ export type components = { * @description Model config for Control LoRA models. */ ControlLoRALyCORISConfig: { + /** @description Default settings for this model */ + default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Key * @description A unique key for this model. diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 8574b4cba2..1adf089a7c 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -145,6 +145,10 @@ export const isControlNetModelConfig = (config: AnyModelConfig): config is Contr return config.type === 'controlnet'; }; +export const isControlLayerModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig => { + return config.type === 'controlnet' || config.type === "t2i_adapter" || config.type === "control_lora"; +}; + export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdapterModelConfig => { return config.type === 'ip_adapter'; }; diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index f41bec165b..24826f3203 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -10,7 +10,7 @@ from invokeai.backend.model_manager.probe import ( CkptType, ModelProbe, VaeFolderProbe, - get_default_settings_controlnet_t2i_adapter, + get_default_settings_control_adapters, get_default_settings_main, ) @@ -40,12 +40,12 @@ def test_repo_variant(datadir: Path): def test_controlnet_t2i_default_settings(): - assert get_default_settings_controlnet_t2i_adapter("some_canny_model").preprocessor == "canny_image_processor" + assert get_default_settings_control_adapters("some_canny_model").preprocessor == "canny_image_processor" assert ( - get_default_settings_controlnet_t2i_adapter("some_depth_model").preprocessor == "depth_anything_image_processor" + get_default_settings_control_adapters("some_depth_model").preprocessor == "depth_anything_image_processor" ) - assert get_default_settings_controlnet_t2i_adapter("some_pose_model").preprocessor == "dw_openpose_image_processor" - assert get_default_settings_controlnet_t2i_adapter("i like turtles") is None + assert get_default_settings_control_adapters("some_pose_model").preprocessor == "dw_openpose_image_processor" + assert get_default_settings_control_adapters("i like turtles") is None def test_default_settings_main():