mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-03-12 18:00:19 +01:00
Feature: add prompt template node (#8680)
* feat(nodes): add Prompt Template node
Add a new node that applies Style Preset templates to prompts in workflows.
The node takes a style preset ID and positive/negative prompts as inputs,
then replaces {prompt} placeholders in the template with the provided prompts.
This makes Style Preset templates accessible in Workflow mode, enabling
users to apply consistent styling across their workflow-based generations.
* feat(nodes): add StylePresetField for database-driven preset selection
Adds a new StylePresetField type that enables dropdown selection of
style presets from the database in the workflow editor.
Changes:
- Add StylePresetField to backend (fields.py)
- Update Prompt Template node to use StylePresetField instead of string ID
- Add frontend field type definitions (zod schemas, type guards)
- Create StylePresetFieldInputComponent with Combobox
- Register field in InputFieldRenderer and nodesSlice
- Add translations for preset selection
* fix schema.ts on windows.
* chore(api): regenerate schema.ts after merge
---------
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
21138e5d52
commit
4cb9b8d97d
@ -243,6 +243,12 @@ class BoardField(BaseModel):
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
class StylePresetField(BaseModel):
|
||||
"""A style preset primitive field"""
|
||||
|
||||
style_preset_id: str = Field(description="The id of the style preset")
|
||||
|
||||
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
|
||||
57
invokeai/app/invocations/prompt_template.py
Normal file
57
invokeai/app/invocations/prompt_template.py
Normal file
@ -0,0 +1,57 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import InputField, OutputField, StylePresetField, UIComponent
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("prompt_template_output")
|
||||
class PromptTemplateOutput(BaseInvocationOutput):
|
||||
"""Output for the Prompt Template node"""
|
||||
|
||||
positive_prompt: str = OutputField(description="The positive prompt with the template applied")
|
||||
negative_prompt: str = OutputField(description="The negative prompt with the template applied")
|
||||
|
||||
|
||||
@invocation(
|
||||
"prompt_template",
|
||||
title="Prompt Template",
|
||||
tags=["prompt", "template", "style", "preset"],
|
||||
category="prompt",
|
||||
version="1.0.0",
|
||||
)
|
||||
class PromptTemplateInvocation(BaseInvocation):
|
||||
"""Applies a Style Preset template to positive and negative prompts.
|
||||
|
||||
Select a Style Preset and provide positive/negative prompts. The node replaces
|
||||
{prompt} placeholders in the template with your input prompts.
|
||||
"""
|
||||
|
||||
style_preset: StylePresetField = InputField(
|
||||
description="The Style Preset to use as a template",
|
||||
)
|
||||
positive_prompt: str = InputField(
|
||||
default="",
|
||||
description="The positive prompt to insert into the template's {prompt} placeholder",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
negative_prompt: str = InputField(
|
||||
default="",
|
||||
description="The negative prompt to insert into the template's {prompt} placeholder",
|
||||
ui_component=UIComponent.Textarea,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> PromptTemplateOutput:
|
||||
# Fetch the style preset from the database
|
||||
style_preset = context._services.style_preset_records.get(self.style_preset.style_preset_id)
|
||||
|
||||
# Get the template prompts
|
||||
positive_template = style_preset.preset_data.positive_prompt
|
||||
negative_template = style_preset.preset_data.negative_prompt
|
||||
|
||||
# Replace {prompt} placeholder with the input prompts
|
||||
rendered_positive = positive_template.replace("{prompt}", self.positive_prompt)
|
||||
rendered_negative = negative_template.replace("{prompt}", self.negative_prompt)
|
||||
|
||||
return PromptTemplateOutput(
|
||||
positive_prompt=rendered_positive,
|
||||
negative_prompt=rendered_negative,
|
||||
)
|
||||
@ -2676,7 +2676,9 @@
|
||||
"useForTemplate": "Use For Prompt Template",
|
||||
"viewList": "View Template List",
|
||||
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box.",
|
||||
"togglePromptPreviews": "Toggle Prompt Previews"
|
||||
"togglePromptPreviews": "Toggle Prompt Previews",
|
||||
"selectPreset": "Select Style Preset",
|
||||
"noMatchingPresets": "No matching presets"
|
||||
},
|
||||
|
||||
"ui": {
|
||||
|
||||
@ -55,6 +55,8 @@ import {
|
||||
isStringFieldInputTemplate,
|
||||
isStringGeneratorFieldInputInstance,
|
||||
isStringGeneratorFieldInputTemplate,
|
||||
isStylePresetFieldInputInstance,
|
||||
isStylePresetFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { memo } from 'react';
|
||||
@ -67,6 +69,7 @@ import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
|
||||
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
|
||||
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import StylePresetFieldInputComponent from './inputs/StylePresetFieldInputComponent';
|
||||
|
||||
type Props = {
|
||||
nodeId: string;
|
||||
@ -206,6 +209,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
|
||||
return <BoardFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isStylePresetFieldInputTemplate(template)) {
|
||||
if (!isStylePresetFieldInputInstance(field)) {
|
||||
return null;
|
||||
}
|
||||
return <StylePresetFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
|
||||
}
|
||||
|
||||
if (isModelIdentifierFieldInputTemplate(template)) {
|
||||
if (!isModelIdentifierFieldInputInstance(field)) {
|
||||
return null;
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldStylePresetValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { StylePresetFieldInputInstance, StylePresetFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListStylePresetsQuery } from 'services/api/endpoints/stylePresets';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const StylePresetFieldInputComponent = (
|
||||
props: FieldComponentProps<StylePresetFieldInputInstance, StylePresetFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { data: stylePresets, isLoading } = useListStylePresetsQuery();
|
||||
|
||||
const options = useMemo<ComboboxOption[]>(() => {
|
||||
const _options: ComboboxOption[] = [];
|
||||
if (stylePresets) {
|
||||
for (const preset of stylePresets) {
|
||||
_options.push({
|
||||
label: preset.name,
|
||||
value: preset.id,
|
||||
});
|
||||
}
|
||||
}
|
||||
return _options;
|
||||
}, [stylePresets]);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldStylePresetValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: { style_preset_id: v.value },
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const value = useMemo(() => {
|
||||
const _value = field.value;
|
||||
if (!_value) {
|
||||
return null;
|
||||
}
|
||||
return options.find((o) => o.value === _value.style_preset_id) ?? null;
|
||||
}, [field.value, options]);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('stylePresets.noMatchingPresets'), [t]);
|
||||
|
||||
return (
|
||||
<Combobox
|
||||
className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`}
|
||||
value={value}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
placeholder={isLoading ? t('common.loading') : t('stylePresets.selectPreset')}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(StylePresetFieldInputComponent);
|
||||
@ -41,6 +41,7 @@ import type {
|
||||
StringFieldCollectionValue,
|
||||
StringFieldValue,
|
||||
StringGeneratorFieldValue,
|
||||
StylePresetFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
zBoardFieldValue,
|
||||
@ -62,6 +63,7 @@ import {
|
||||
zStringFieldCollectionValue,
|
||||
zStringFieldValue,
|
||||
zStringGeneratorFieldValue,
|
||||
zStylePresetFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
@ -438,6 +440,9 @@ const slice = createSlice({
|
||||
fieldBoardValueChanged: (state, action: FieldValueAction<BoardFieldValue>) => {
|
||||
fieldValueReducer(state, action, zBoardFieldValue);
|
||||
},
|
||||
fieldStylePresetValueChanged: (state, action: FieldValueAction<StylePresetFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStylePresetFieldValue);
|
||||
},
|
||||
fieldImageValueChanged: (state, action: FieldValueAction<ImageFieldValue>) => {
|
||||
fieldValueReducer(state, action, zImageFieldValue);
|
||||
},
|
||||
@ -588,6 +593,7 @@ export const {
|
||||
fieldBoardValueChanged,
|
||||
fieldBooleanValueChanged,
|
||||
fieldColorValueChanged,
|
||||
fieldStylePresetValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldImageCollectionValueChanged,
|
||||
|
||||
@ -16,6 +16,10 @@ export const zBoardField = z.object({
|
||||
});
|
||||
export type BoardField = z.infer<typeof zBoardField>;
|
||||
|
||||
export const zStylePresetField = z.object({
|
||||
style_preset_id: z.string().trim().min(1),
|
||||
});
|
||||
|
||||
export const zColorField = z.object({
|
||||
r: z.number().int().min(0).max(255),
|
||||
g: z.number().int().min(0).max(255),
|
||||
|
||||
@ -35,6 +35,7 @@ export const NO_PAN_CLASS = 'nopan';
|
||||
export const FIELD_COLORS: { [key: string]: string } = {
|
||||
BoardField: 'purple.500',
|
||||
BooleanField: 'green.500',
|
||||
StylePresetField: 'purple.400',
|
||||
CLIPField: 'green.500',
|
||||
ColorField: 'pink.300',
|
||||
ConditioningField: 'cyan.500',
|
||||
|
||||
@ -19,6 +19,7 @@ import {
|
||||
zModelIdentifierField,
|
||||
zModelType,
|
||||
zSchedulerField,
|
||||
zStylePresetField,
|
||||
} from './common';
|
||||
|
||||
/**
|
||||
@ -169,6 +170,11 @@ const zBoardFieldType = zFieldTypeBase.extend({
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
|
||||
const zStylePresetFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StylePresetField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
|
||||
const zColorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ColorField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@ -205,6 +211,7 @@ const zStatefulFieldType = z.union([
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zBoardFieldType,
|
||||
zStylePresetFieldType,
|
||||
zModelIdentifierFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
@ -607,6 +614,27 @@ export const isBoardFieldInputInstance = buildInstanceTypeGuard(zBoardFieldInput
|
||||
export const isBoardFieldInputTemplate = buildTemplateTypeGuard<BoardFieldInputTemplate>('BoardField');
|
||||
// #endregion
|
||||
|
||||
// #region StylePresetField
|
||||
export const zStylePresetFieldValue = zStylePresetField.optional();
|
||||
const zStylePresetFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStylePresetFieldValue,
|
||||
});
|
||||
const zStylePresetFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zStylePresetFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStylePresetFieldValue,
|
||||
});
|
||||
const zStylePresetFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStylePresetFieldType,
|
||||
});
|
||||
export type StylePresetFieldValue = z.infer<typeof zStylePresetFieldValue>;
|
||||
export type StylePresetFieldInputInstance = z.infer<typeof zStylePresetFieldInputInstance>;
|
||||
export type StylePresetFieldInputTemplate = z.infer<typeof zStylePresetFieldInputTemplate>;
|
||||
export const isStylePresetFieldInputInstance = buildInstanceTypeGuard(zStylePresetFieldInputInstance);
|
||||
export const isStylePresetFieldInputTemplate =
|
||||
buildTemplateTypeGuard<StylePresetFieldInputTemplate>('StylePresetField');
|
||||
// #endregion
|
||||
|
||||
// #region ColorField
|
||||
export const zColorFieldValue = zColorField.optional();
|
||||
const zColorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
@ -1257,6 +1285,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zImageFieldValue,
|
||||
zImageFieldCollectionValue,
|
||||
zBoardFieldValue,
|
||||
zStylePresetFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
@ -1284,6 +1313,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zImageFieldInputInstance,
|
||||
zImageFieldCollectionInputInstance,
|
||||
zBoardFieldInputInstance,
|
||||
zStylePresetFieldInputInstance,
|
||||
zModelIdentifierFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
@ -1310,6 +1340,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zImageFieldInputTemplate,
|
||||
zImageFieldCollectionInputTemplate,
|
||||
zBoardFieldInputTemplate,
|
||||
zStylePresetFieldInputTemplate,
|
||||
zModelIdentifierFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
@ -1337,6 +1368,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zImageFieldOutputTemplate,
|
||||
zImageFieldCollectionOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zStylePresetFieldOutputTemplate,
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zColorFieldOutputTemplate,
|
||||
zSchedulerFieldOutputTemplate,
|
||||
|
||||
@ -12,6 +12,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
ModelIdentifierField: undefined,
|
||||
SchedulerField: 'dpmpp_3m_k',
|
||||
StringField: '',
|
||||
StylePresetField: undefined,
|
||||
FloatGeneratorField: undefined,
|
||||
IntegerGeneratorField: undefined,
|
||||
StringGeneratorField: undefined,
|
||||
|
||||
@ -23,6 +23,7 @@ import type {
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
StringGeneratorFieldInputTemplate,
|
||||
StylePresetFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
getFloatGeneratorArithmeticSequenceDefaults,
|
||||
@ -289,6 +290,20 @@ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTem
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStylePresetFieldInputTemplate: FieldInputTemplateBuilder<StylePresetFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: StylePresetFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageFieldInputTemplate: FieldInputTemplateBuilder<ImageFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -460,6 +475,7 @@ const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplate
|
||||
ModelIdentifierField: buildModelIdentifierFieldInputTemplate,
|
||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
StylePresetField: buildStylePresetFieldInputTemplate,
|
||||
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
|
||||
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
|
||||
StringGeneratorField: buildStringGeneratorFieldInputTemplate,
|
||||
|
||||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user