diff --git a/invokeai/app/invocations/constants.py b/invokeai/app/invocations/constants.py index 95b16f0d05..795e7a3b60 100644 --- a/invokeai/app/invocations/constants.py +++ b/invokeai/app/invocations/constants.py @@ -1,3 +1,7 @@ +from typing import Literal + +from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP + LATENT_SCALE_FACTOR = 8 """ HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to @@ -5,3 +9,6 @@ be addressed if future models use a different latent scale factor. Also, note th factor is hard-coded to a literal '8' rather than using this constant. The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. """ + +SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +"""A literal type representing the valid scheduler names.""" diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index fedfc38402..69e3f055ca 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -23,7 +23,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, @@ -78,12 +78,10 @@ if choose_torch_device() == torch.device("mps"): DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] - @invocation_output("scheduler_output") class SchedulerOutput(BaseInvocationOutput): - scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) + scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) @invocation( @@ -96,7 +94,7 @@ class SchedulerOutput(BaseInvocationOutput): class SchedulerInvocation(BaseInvocation): """Selects a scheduler.""" - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler, @@ -234,7 +232,7 @@ class DenoiseLatentsInvocation(BaseInvocation): description=FieldDescriptions.denoising_start, ) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) - scheduler: SAMPLER_NAME_VALUES = InputField( + scheduler: SCHEDULER_NAME_VALUES = InputField( default="euler", description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler,