feat(ui): add control loras to control adapter model options, add default settings for preprocessor in probe

This commit is contained in:
Mary Hipp 2024-12-12 21:07:16 -05:00 committed by Kent Keirsey
parent 246b59f148
commit da213e4638
11 changed files with 40 additions and 27 deletions

View File

@ -273,8 +273,12 @@ class LoRALyCORISConfig(LoRAConfigBase):
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class ControlLoRALyCORISConfig(ModelConfigBase):
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
@ -317,10 +321,7 @@ class VAEDiffusersConfig(ModelConfigBase):
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):

View File

@ -200,8 +200,8 @@ class ModelProbe(object):
fields["default_settings"] = fields.get("default_settings")
if not fields["default_settings"]:
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}:
fields["default_settings"] = get_default_settings_control_adapters(fields["name"])
elif fields["type"] is ModelType.Main:
fields["default_settings"] = get_default_settings_main(fields["base"])
@ -510,7 +510,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
}
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
model_name_lower = model_name.lower()
if k in model_name_lower:

View File

@ -26,7 +26,7 @@ import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actio
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiShootingStarFill, PiUploadBold } from 'react-icons/pi';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
const buildSelectControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) =>
createMemoizedAppSelector(selectCanvasSlice, (canvas) => {
@ -66,7 +66,7 @@ export const ControlLayerControlAdapter = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => {
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
// filter config.

View File

@ -4,22 +4,22 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { useControlLayerModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type Props = {
modelKey: string | null;
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => void;
onChange: (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig) => void;
};
export const ControlLayerControlAdapterModel = memo(({ modelKey, onChange: onChangeModel }: Props) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useControlNetAndT2IAdapterModels();
const [modelConfigs, { isLoading }] = useControlLayerModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChange = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
if (!modelConfig) {
return;
}

View File

@ -34,7 +34,7 @@ import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/par
import type { IRect } from 'konva/lib/types';
import { merge } from 'lodash-es';
import type { UndoableOptions } from 'redux-undo';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type {
@ -436,7 +436,7 @@ export const canvasSlice = createSlice({
action: PayloadAction<
EntityIdentifierPayload<
{
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null;
},
'control_layer'
>

View File

@ -2,7 +2,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { ImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlLoRAModelConfig, ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { z } from 'zod';
@ -454,7 +454,7 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
* Gets the default filter for a control model. If the model has a default, it will be used, otherwise the default
* filter for the model type will be used.
*/
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig | null) => {
if (!modelConfig) {
// No model
return null;

View File

@ -296,6 +296,12 @@ const zT2IAdapterConfig = z.object({
});
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
export const zCanvasRasterLayerState = zCanvasEntityBase.extend({
type: z.literal('raster_layer'),
position: zCoordinate,
@ -307,7 +313,7 @@ export type CanvasRasterLayerState = z.infer<typeof zCanvasRasterLayerState>;
const zCanvasControlLayerState = zCanvasRasterLayerState.extend({
type: z.literal('control_layer'),
withTransparencyEffect: z.boolean(),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig, zControlLoRAConfig]),
});
export type CanvasControlLayerState = z.infer<typeof zCanvasControlLayerState>;

View File

@ -13,7 +13,7 @@ import {
isCLIPVisionModelConfig,
isControlLoRAModelConfig,
isControlNetModelConfig,
isControlNetOrT2IAdapterModelConfig,
isControlLayerModelConfig,
isFluxMainModelModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
@ -60,7 +60,7 @@ export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlLoRAModel = buildModelsHook(isControlLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlLayerModels = buildModelsHook(isControlLayerModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useT5EncoderModels = (args?: ModelHookArgs) =>

View File

@ -4359,6 +4359,8 @@ export type components = {
* @description Model config for Control LoRA models.
*/
ControlLoRALyCORISConfig: {
/** @description Default settings for this model */
default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null;
/**
* Key
* @description A unique key for this model.

View File

@ -145,6 +145,10 @@ export const isControlNetModelConfig = (config: AnyModelConfig): config is Contr
return config.type === 'controlnet';
};
export const isControlLayerModelConfig = (config: AnyModelConfig): config is ControlNetModelConfig | T2IAdapterModelConfig | ControlLoRAModelConfig => {
return config.type === 'controlnet' || config.type === "t2i_adapter" || config.type === "control_lora";
};
export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdapterModelConfig => {
return config.type === 'ip_adapter';
};

View File

@ -10,7 +10,7 @@ from invokeai.backend.model_manager.probe import (
CkptType,
ModelProbe,
VaeFolderProbe,
get_default_settings_controlnet_t2i_adapter,
get_default_settings_control_adapters,
get_default_settings_main,
)
@ -40,12 +40,12 @@ def test_repo_variant(datadir: Path):
def test_controlnet_t2i_default_settings():
assert get_default_settings_controlnet_t2i_adapter("some_canny_model").preprocessor == "canny_image_processor"
assert get_default_settings_control_adapters("some_canny_model").preprocessor == "canny_image_processor"
assert (
get_default_settings_controlnet_t2i_adapter("some_depth_model").preprocessor == "depth_anything_image_processor"
get_default_settings_control_adapters("some_depth_model").preprocessor == "depth_anything_image_processor"
)
assert get_default_settings_controlnet_t2i_adapter("some_pose_model").preprocessor == "dw_openpose_image_processor"
assert get_default_settings_controlnet_t2i_adapter("i like turtles") is None
assert get_default_settings_control_adapters("some_pose_model").preprocessor == "dw_openpose_image_processor"
assert get_default_settings_control_adapters("i like turtles") is None
def test_default_settings_main():