From 24d44ca5591880584352b5decf132b7becdf0ba7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 30 Aug 2023 22:35:47 +1000 Subject: [PATCH] feat(nodes): add scheduler invocation --- invokeai/app/invocations/latent.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 31dde74a09..8585fbe531 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -49,12 +49,15 @@ from ...backend.util.devices import choose_precision, choose_torch_device from ..models.image import ImageCategory, ResourceOrigin from .baseinvocation import ( BaseInvocation, + BaseInvocationOutput, FieldDescriptions, Input, InputField, InvocationContext, + OutputField, UIType, invocation, + invocation_output, ) from .compel import ConditioningField from .controlnet_image_processors import ControlField @@ -66,6 +69,23 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] +@invocation_output("scheduler_output") +class SchedulerOutput(BaseInvocationOutput): + scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + + +@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents") +class SchedulerInvocation(BaseInvocation): + """Selects a scheduler.""" + + scheduler: SAMPLER_NAME_VALUES = InputField( + default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler + ) + + def invoke(self, context: InvocationContext) -> SchedulerOutput: + return SchedulerOutput(scheduler=self.scheduler) + + @invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents") class CreateDenoiseMaskInvocation(BaseInvocation): """Creates mask for denoising model run."""