chore(nodes): "SAMPLER_NAME_VALUES" -> "SCHEDULER_NAME_VALUES"

This was named inaccurately.
This commit is contained in:
psychedelicious 2024-02-11 09:51:25 +11:00 committed by Brandon Rising
parent 54d92cb246
commit acc50d9bd2
2 changed files with 11 additions and 6 deletions

View File

@ -1,3 +1,7 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
LATENT_SCALE_FACTOR = 8 LATENT_SCALE_FACTOR = 8
""" """
HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
@ -5,3 +9,6 @@ be addressed if future models use a different latent scale factor. Also, note th
factor is hard-coded to a literal '8' rather than using this constant. factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
""" """
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""

View File

@ -23,7 +23,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import field_validator from pydantic import field_validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
ConditioningField, ConditioningField,
DenoiseMaskField, DenoiseMaskField,
@ -78,12 +78,10 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device()) DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
@invocation_output("scheduler_output") @invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput): class SchedulerOutput(BaseInvocationOutput):
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler) scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
@invocation( @invocation(
@ -96,7 +94,7 @@ class SchedulerOutput(BaseInvocationOutput):
class SchedulerInvocation(BaseInvocation): class SchedulerInvocation(BaseInvocation):
"""Selects a scheduler.""" """Selects a scheduler."""
scheduler: SAMPLER_NAME_VALUES = InputField( scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler", default="euler",
description=FieldDescriptions.scheduler, description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler, ui_type=UIType.Scheduler,
@ -234,7 +232,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
description=FieldDescriptions.denoising_start, description=FieldDescriptions.denoising_start,
) )
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SAMPLER_NAME_VALUES = InputField( scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler", default="euler",
description=FieldDescriptions.scheduler, description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler, ui_type=UIType.Scheduler,