mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-27 09:12:34 +02:00
feat(ui): add control loras to control adapter model options, add default settings for preprocessor in probe
This commit is contained in:
parent
246b59f148
commit
da213e4638
@ -273,8 +273,12 @@ class LoRALyCORISConfig(LoRAConfigBase):
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||
|
||||
class ControlAdapterConfigBase(BaseModel):
|
||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
class ControlLoRALyCORISConfig(ModelConfigBase):
|
||||
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
|
||||
"""Model config for Control LoRA models."""
|
||||
|
||||
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
|
||||
@ -317,10 +321,7 @@ class VAEDiffusersConfig(ModelConfigBase):
|
||||
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class ControlAdapterConfigBase(BaseModel):
|
||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
|
||||
@ -200,8 +200,8 @@ class ModelProbe(object):
|
||||
fields["default_settings"] = fields.get("default_settings")
|
||||
|
||||
if not fields["default_settings"]:
|
||||
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
|
||||
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
|
||||
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
|
||||
fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
|
||||
elif fields["type"] is ModelType.Main:
|
||||
fields["default_settings"] = get_default_settings_main(fields["base"])
|
||||
|
||||
@ -510,7 +510,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
|
||||
}
|
||||
|
||||
|
||||
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||
def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
|
||||
model_name_lower = model_name.lower()
|
||||
if k in model_name_lower:
|
||||
|
||||
@ -26,7 +26,7 @@ import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actio
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
const buildSelectControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) =>
|
||||
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
|
||||
@ -66,7 +66,7 @@ export const ControlLayerControlAdapter = memo(() => {
|
||||
);
|
||||
|
||||
const onChangeModel = useCallback(
|
||||
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
|
||||
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => {
|
||||
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
|
||||
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
|
||||
// filter config.
|
||||
|
||||
@ -4,22 +4,22 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { useControlLayerModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
modelKey: string | null;
|
||||
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
|
||||
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => void;
|
||||
};
|
||||
|
||||
export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onChangeModel }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const currentBaseModel = useAppSelector(selectBase);
|
||||
const [modelConfigs, { isLoading }] = useControlNetAndT2IAdapterModels();
|
||||
const [modelConfigs, { isLoading }] = useControlLayerModels();
|
||||
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
|
||||
|
||||
const _onChange = useCallback(
|
||||
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
|
||||
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -34,7 +34,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import { merge } from 'lodash-es';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type {
|
||||
@ -436,7 +436,7 @@ export const canvasSlice = createSlice({
|
||||
action: PayloadAction<
|
||||
EntityIdentifierPayload<
|
||||
{
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
|
||||
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null;
|
||||
},
|
||||
'control_layer'
|
||||
>
|
||||
|
||||
@ -2,7 +2,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { ImageWithDims } from 'features/controlLayers/store/types';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import { z } from 'zod';
|
||||
|
||||
@ -454,7 +454,7 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
|
||||
* Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default
|
||||
* filter for the model type will be used.
|
||||
*/
|
||||
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
|
||||
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
|
||||
if (!modelConfig) {
|
||||
// No model
|
||||
return null;
|
||||
|
||||
@ -296,6 +296,12 @@ const zT2IAdapterConfig = z.object({
|
||||
});
|
||||
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
|
||||
|
||||
const zControlLoRAConfig = z.object({
|
||||
type: z.literal('control_lora'),
|
||||
model: zServerValidatedModelIdentifierField.nullable(),
|
||||
});
|
||||
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
|
||||
|
||||
export const zCanvasRasterLayerState = zCanvasEntityBase.extend({
|
||||
type: z.literal('raster_layer'),
|
||||
position: zCoordinate,
|
||||
@ -307,7 +313,7 @@ export type CanvasRasterLayerState = z.infer<typeof zCanvasRasterLayerState>;
|
||||
const zCanvasControlLayerState = zCanvasRasterLayerState.extend({
|
||||
type: z.literal('control_layer'),
|
||||
withTransparencyEffect: z.boolean(),
|
||||
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]),
|
||||
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig, zControlLoRAConfig]),
|
||||
});
|
||||
export type CanvasControlLayerState = z.infer<typeof zCanvasControlLayerState>;
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ import {
|
||||
isCLIPVisionModelConfig,
|
||||
isControlLoRAModelConfig,
|
||||
isControlNetModelConfig,
|
||||
isControlNetOrT2IAdapterModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isFluxMainModelModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
@ -60,7 +60,7 @@ export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
|
||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useControlLoRAModel = buildModelsHook(isControlLoRAModelConfig);
|
||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||
export const useControlLayerModels = buildModelsHook(isControlLayerModelConfig);
|
||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||
export const useT5EncoderModels = (args?: ModelHookArgs) =>
|
||||
|
||||
@ -4359,6 +4359,8 @@ export type components = {
|
||||
* @description Model config for Control LoRA models.
|
||||
*/
|
||||
ControlLoRALyCORISConfig: {
|
||||
/** @description Default settings for this model */
|
||||
default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null;
|
||||
/**
|
||||
* Key
|
||||
* @description A unique key for this model.
|
||||
|
||||
@ -145,6 +145,10 @@ export const isControlNetModelConfig = (config: AnyModelConfig): config is Contr
|
||||
return config.type === 'controlnet';
|
||||
};
|
||||
|
||||
export const isControlLayerModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig => {
|
||||
return config.type === 'controlnet' || config.type === "t2i_adapter" || config.type === "control_lora";
|
||||
};
|
||||
|
||||
export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdapterModelConfig => {
|
||||
return config.type === 'ip_adapter';
|
||||
};
|
||||
|
||||
@ -10,7 +10,7 @@ from invokeai.backend.model_manager.probe import (
|
||||
CkptType,
|
||||
ModelProbe,
|
||||
VaeFolderProbe,
|
||||
get_default_settings_controlnet_t2i_adapter,
|
||||
get_default_settings_control_adapters,
|
||||
get_default_settings_main,
|
||||
)
|
||||
|
||||
@ -40,12 +40,12 @@ def test_repo_variant(datadir: Path):
|
||||
|
||||
|
||||
def test_controlnet_t2i_default_settings():
|
||||
assert get_default_settings_controlnet_t2i_adapter("some_canny_model").preprocessor == "canny_image_processor"
|
||||
assert get_default_settings_control_adapters("some_canny_model").preprocessor == "canny_image_processor"
|
||||
assert (
|
||||
get_default_settings_controlnet_t2i_adapter("some_depth_model").preprocessor == "depth_anything_image_processor"
|
||||
get_default_settings_control_adapters("some_depth_model").preprocessor == "depth_anything_image_processor"
|
||||
)
|
||||
assert get_default_settings_controlnet_t2i_adapter("some_pose_model").preprocessor == "dw_openpose_image_processor"
|
||||
assert get_default_settings_controlnet_t2i_adapter("i like turtles") is None
|
||||
assert get_default_settings_control_adapters("some_pose_model").preprocessor == "dw_openpose_image_processor"
|
||||
assert get_default_settings_control_adapters("i like turtles") is None
|
||||
|
||||
|
||||
def test_default_settings_main():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user