diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index deeb484aea..3851caa647 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -24,7 +24,6 @@ from invokeai.app.invocations.fields import ( Input, InputField, LatentsField, - OutputField, UIType, ) from invokeai.app.invocations.ip_adapter import IPAdapterField @@ -56,38 +55,13 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( ) from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import TorchDevice -from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from .baseinvocation import BaseInvocation, invocation from .controlnet_image_processors import ControlField from .model import ModelIdentifierField, UNetField DEFAULT_PRECISION = TorchDevice.choose_torch_dtype() -@invocation_output("scheduler_output") -class SchedulerOutput(BaseInvocationOutput): - scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) - - -@invocation( - "scheduler", - title="Scheduler", - tags=["scheduler"], - category="latents", - version="1.0.0", -) -class SchedulerInvocation(BaseInvocation): - """Selects a scheduler.""" - - scheduler: SCHEDULER_NAME_VALUES = InputField( - default="euler", - description=FieldDescriptions.scheduler, - ui_type=UIType.Scheduler, - ) - - def invoke(self, context: InvocationContext) -> SchedulerOutput: - return SchedulerOutput(scheduler=self.scheduler) - - def get_scheduler( context: InvocationContext, scheduler_info: ModelIdentifierField, diff --git a/invokeai/app/invocations/scheduler.py b/invokeai/app/invocations/scheduler.py new file mode 100644 index 0000000000..52af20378e --- /dev/null +++ b/invokeai/app/invocations/scheduler.py @@ -0,0 +1,34 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES +from invokeai.app.invocations.fields import ( + FieldDescriptions, + InputField, + OutputField, + UIType, +) +from invokeai.app.services.shared.invocation_context import InvocationContext + + +@invocation_output("scheduler_output") +class SchedulerOutput(BaseInvocationOutput): + scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + + +@invocation( + "scheduler", + title="Scheduler", + tags=["scheduler"], + category="latents", + version="1.0.0", +) +class SchedulerInvocation(BaseInvocation): + """Selects a scheduler.""" + + scheduler: SCHEDULER_NAME_VALUES = InputField( + default="euler", + description=FieldDescriptions.scheduler, + ui_type=UIType.Scheduler, + ) + + def invoke(self, context: InvocationContext) -> SchedulerOutput: + return SchedulerOutput(scheduler=self.scheduler)