Fix static type errors with SCHEDULER_NAME_VALUES. And, avoid bi-directional cross-directory imports, which contribute to circular import issues.

This commit is contained in:
Ryan Dick 2024-07-03 11:13:16 -04:00 committed by Kent Keirsey
parent 3a24d70279
commit 35f8781ea2
8 changed files with 49 additions and 11 deletions

View File

@ -1,6 +1,5 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
@ -11,9 +10,6 @@ factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""

View File

@ -17,7 +17,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
@ -54,6 +54,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask

View File

@ -1,5 +1,4 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
@ -7,6 +6,7 @@ from invokeai.app.invocations.fields import (
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@invocation_output("scheduler_output")

View File

@ -8,7 +8,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
@ -29,6 +29,7 @@ from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)

View File

@ -30,10 +30,10 @@ from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime

View File

@ -1,3 +1,5 @@
from typing import Any, Literal, Type
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@ -16,8 +18,36 @@ from diffusers import (
TCDScheduler,
UniPCMultistepScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin
SCHEDULER_MAP = {
SCHEDULER_NAME_VALUES = Literal[
"ddim",
"ddpm",
"deis",
"lms",
"lms_k",
"pndm",
"heun",
"heun_k",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2s_k",
"dpmpp_2m",
"dpmpp_2m_k",
"dpmpp_2m_sde",
"dpmpp_2m_sde_k",
"dpmpp_sde",
"dpmpp_sde_k",
"unipc",
"lcm",
"tcd",
]
SCHEDULER_MAP: dict[SCHEDULER_NAME_VALUES, tuple[Type[SchedulerMixin], dict[str, Any]]] = {
"ddim": (DDIMScheduler, {}),
"ddpm": (DDPMScheduler, {}),
"deis": (DEISMultistepScheduler, {}),

View File

@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
BoardField,
ColorField,
@ -78,6 +77,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
SDXLConditioningInfo,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device
from invokeai.version import __version__
@ -163,7 +163,7 @@ __all__ = [
"BaseModelType",
"ModelType",
"SubModelType",
# invokeai.app.invocations.constants
# invokeai.backend.stable_diffusion.schedulers.schedulers
"SCHEDULER_NAME_VALUES",
# invokeai.version
"__version__",

View File

@ -0,0 +1,10 @@
from typing import get_args
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES
def test_scheduler_map_has_all_keys():
# Assert that SCHEDULER_MAP has all keys from SCHEDULER_NAME_VALUES.
# TODO(ryand): This feels like it should be a type check, but I couldn't find a clean way to do this and didn't want
# to spend more time on it.
assert set(SCHEDULER_MAP.keys()) == set(get_args(SCHEDULER_NAME_VALUES))