Tidy handling of SCHEDULER_NAME_VALUES to help with circular import errors.

This commit is contained in:
Ryan Dick 2024-07-02 15:12:59 -04:00
parent 44f62944ee
commit 798e73969c
8 changed files with 16 additions and 14 deletions

View File

@ -1,6 +1,5 @@
from typing import Literal from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8 LATENT_SCALE_FACTOR = 8
@ -11,9 +10,6 @@ factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1. The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
""" """
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke""" """A literal type for PIL image modes supported by Invoke"""

View File

@ -17,7 +17,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
ConditioningField, ConditioningField,
@ -53,7 +53,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData, TextConditioningData,
TextConditioningRegions, TextConditioningRegions,
) )
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.mask import to_standard_float_mask

View File

@ -1,5 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
InputField, InputField,
@ -7,6 +6,7 @@ from invokeai.app.invocations.fields import (
UIType, UIType,
) )
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@invocation_output("scheduler_output") @invocation_output("scheduler_output")

View File

@ -8,7 +8,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@ -29,6 +29,7 @@ from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline, MultiDiffusionPipeline,
MultiDiffusionRegionConditioning, MultiDiffusionRegionConditioning,
) )
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import ( from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap, calc_tiles_min_overlap,
) )

View File

@ -30,9 +30,9 @@ from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from ..raw_model import RawModel from ..raw_model import RawModel

View File

@ -1,3 +0,0 @@
from .schedulers import SCHEDULER_MAP # noqa: F401
__all__ = ["SCHEDULER_MAP"]

View File

@ -1,3 +1,5 @@
from typing import Literal
from diffusers import ( from diffusers import (
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
@ -43,3 +45,9 @@ SCHEDULER_MAP = {
"lcm": (LCMScheduler, {}), "lcm": (LCMScheduler, {}),
"tcd": (TCDScheduler, {}), "tcd": (TCDScheduler, {}),
} }
# HACK(ryand): Passing a tuple of keys to Literal works at runtime, but not at type-check time. See the docs here for
# more info: https://typing.readthedocs.io/en/latest/spec/literal.html#parameters-at-runtime. For now, we are ignoring
# this error. In the future, we should fix this type handling.
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] # type: ignore

View File

@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
BoardField, BoardField,
ColorField, ColorField,
@ -78,6 +77,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData, ConditioningFieldData,
SDXLConditioningInfo, SDXLConditioningInfo,
) )
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device
from invokeai.version import __version__ from invokeai.version import __version__
@ -163,7 +163,7 @@ __all__ = [
"BaseModelType", "BaseModelType",
"ModelType", "ModelType",
"SubModelType", "SubModelType",
# invokeai.app.invocations.constants # invokeai.backend.stable_diffusion.schedulers.schedulers
"SCHEDULER_NAME_VALUES", "SCHEDULER_NAME_VALUES",
# invokeai.version # invokeai.version
"__version__", "__version__",