From 2f16a2c35d3f7cc213bbcb954581960de077ed53 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 29 Sep 2023 00:09:07 -0400 Subject: [PATCH] fix migrate script and type mismatches in probe, config and loader --- invokeai/app/services/config/base.py | 5 +- .../app/services/config/invokeai_config.py | 5 +- invokeai/backend/model_manager/cache.py | 16 ++---- invokeai/backend/model_manager/config.py | 21 ++++---- invokeai/backend/model_manager/install.py | 6 +-- invokeai/backend/model_manager/loader.py | 33 ++++++------- invokeai/backend/model_manager/models/base.py | 9 ++++ invokeai/backend/model_manager/probe.py | 49 +++++++++++-------- .../backend/model_manager/storage/base.py | 6 +-- .../backend/model_manager/storage/migrate.py | 15 ++++-- invokeai/backend/model_manager/storage/sql.py | 4 +- .../backend/model_manager/storage/yaml.py | 4 +- invokeai/backend/util/__init__.py | 2 + invokeai/backend/util/logging.py | 4 +- tests/AA_nodes/test_session_queue.py | 1 + 15 files changed, 96 insertions(+), 84 deletions(-) diff --git a/invokeai/app/services/config/base.py b/invokeai/app/services/config/base.py index f24879af05..7268a93b82 100644 --- a/invokeai/app/services/config/base.py +++ b/invokeai/app/services/config/base.py @@ -25,6 +25,7 @@ from pydantic import BaseSettings class PagingArgumentParser(argparse.ArgumentParser): """ A custom ArgumentParser that uses pydoc to page its output. + It also supports reading defaults from an init file. """ @@ -226,9 +227,7 @@ class InvokeAISettings(BaseSettings): def int_or_float_or_str(value: str) -> Union[int, float, str]: - """ - Workaround for argparse type checking. - """ + """Workaround for argparse type checking.""" try: return int(value) except Exception as e: # noqa F841 diff --git a/invokeai/app/services/config/invokeai_config.py b/invokeai/app/services/config/invokeai_config.py index c2fc5a9a49..e7bd5b1177 100644 --- a/invokeai/app/services/config/invokeai_config.py +++ b/invokeai/app/services/config/invokeai_config.py @@ -257,7 +257,6 @@ class InvokeAIAppConfig(InvokeAISettings): attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", ) attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", ) force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) - force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",) # QUEUE max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", ) @@ -454,9 +453,7 @@ class InvokeAIAppConfig(InvokeAISettings): def get_invokeai_config(**kwargs) -> InvokeAIAppConfig: - """ - Legacy function which returns InvokeAIAppConfig.get_config() - """ + """Legacy function which returns InvokeAIAppConfig.get_config().""" return InvokeAIAppConfig.get_config(**kwargs) diff --git a/invokeai/backend/model_manager/cache.py b/invokeai/backend/model_manager/cache.py index c11b51d2cd..4c357634a9 100644 --- a/invokeai/backend/model_manager/cache.py +++ b/invokeai/backend/model_manager/cache.py @@ -55,20 +55,10 @@ class CacheStats(object): loaded_model_sizes: Dict[str, int] = field(default_factory=dict) -class ModelLocker(object): - "Forward declaration" - pass - - -class ModelCache(object): - "Forward declaration" - pass - - class _CacheRecord: size: int model: Any - cache: ModelCache + cache: "ModelCache" _locks: int def __init__(self, cache, model: Any, size: int): @@ -130,7 +120,7 @@ class ModelCache(object): self.logger = logger # used for stats collection - self.stats = None + self.stats: Optional[CacheStats] = None self._cached_models = dict() self._cache_stack = list() @@ -160,7 +150,7 @@ class ModelCache(object): if model_info_key not in self.model_infos: self.model_infos[model_info_key] = model_class( - model_path, + model_path.as_posix(), base_model, model_type, ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b64ab875a4..201b270d60 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -118,7 +118,7 @@ class ModelConfigBase(BaseModel): base_model: BaseModelType model_type: ModelType model_format: ModelFormat - key: Optional[str] = Field(None) # this will get added by the store + key: str = Field(description="hash key for model", default="") # this will get added by the store description: Optional[str] = Field(None) author: Optional[str] = Field(description="Model author") license: Optional[str] = Field(description="License string") @@ -233,6 +233,16 @@ class IPAdapterConfig(ModelConfigBase): model_format: Literal[ModelFormat.InvokeAI] +AnyModelConfig = Union[ + MainCheckpointConfig, + MainDiffusersConfig, + LoRAConfig, + TextualInversionConfig, + ONNXSD1Config, + ONNXSD2Config, +] + + class ModelConfigFactory(object): """Class for parsing config dicts into StableDiffusion Config obects.""" @@ -279,14 +289,7 @@ class ModelConfigFactory(object): model_data: Union[dict, ModelConfigBase], key: Optional[str] = None, dest_class: Optional[Type] = None, - ) -> Union[ - MainCheckpointConfig, - MainDiffusersConfig, - LoRAConfig, - TextualInversionConfig, - ONNXSD1Config, - ONNXSD2Config, - ]: + ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 411e7937a0..9e0e849fa0 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -60,7 +60,7 @@ from pydantic import Field from pydantic.networks import AnyHttpUrl from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util import Chdir, InvokeAILogger +from invokeai.backend.util import Chdir, InvokeAILogger, Logger from .config import ( BaseModelType, @@ -321,7 +321,7 @@ class ModelInstall(ModelInstallBase): """Model installer class handles installation from a local path.""" _app_config: InvokeAIAppConfig - _logger: InvokeAILogger + _logger: Logger _store: ModelConfigStore _download_queue: DownloadQueueBase _async_installs: Dict[str, str] @@ -355,7 +355,7 @@ class ModelInstall(ModelInstallBase): self, store: Optional[ModelConfigStore] = None, config: Optional[InvokeAIAppConfig] = None, - logger: Optional[InvokeAILogger] = None, + logger: Optional[Logger] = None, download: Optional[DownloadQueueBase] = None, event_handlers: Optional[List[DownloadEventHandler]] = None, ): # noqa D107 - use base class docstrings diff --git a/invokeai/backend/model_manager/loader.py b/invokeai/backend/model_manager/loader.py index fc4175f8fb..a7560e9567 100644 --- a/invokeai/backend/model_manager/loader.py +++ b/invokeai/backend/model_manager/loader.py @@ -5,16 +5,16 @@ import hashlib from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import torch from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.util import InvokeAILogger, choose_precision, choose_torch_device +from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device -from .cache import CacheStats, ModelCache, ModelLocker +from .cache import CacheStats, ModelCache from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType -from .download import DownloadEventHandler +from .download import DownloadEventHandler, DownloadQueueBase from .install import ModelInstall, ModelInstallBase from .models import MODEL_CLASSES, InvalidModelException, ModelBase from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store @@ -24,10 +24,10 @@ from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_c class ModelInfo: """This is a context manager object that is used to intermediate access to a model.""" - context: ModelLocker + context: ModelCache.ModelLocker name: str base_model: BaseModelType - type: ModelType + type: Union[ModelType, SubModelType] key: str location: Union[Path, str] precision: torch.dtype @@ -73,7 +73,7 @@ class ModelLoadBase(ABC): @property @abstractmethod - def logger(self) -> InvokeAILogger: + def logger(self) -> Logger: """Return the current logger.""" pass @@ -84,7 +84,7 @@ class ModelLoadBase(ABC): @property @abstractmethod - def queue(self) -> str: + def queue(self) -> DownloadQueueBase: """Return the download queue object used by this object.""" pass @@ -106,10 +106,7 @@ class ModelLoadBase(ABC): @abstractmethod def sync_to_config(self): - """ - Reinitialize the store to sync in-memory and in-disk - versions. - """ + """Reinitialize the store to sync in-memory and in-disk versions.""" pass @@ -120,7 +117,7 @@ class ModelLoad(ModelLoadBase): _store: ModelConfigStore _installer: ModelInstallBase _cache: ModelCache - _logger: InvokeAILogger + _logger: Logger _cache_keys: dict _models_file: Path @@ -195,7 +192,7 @@ class ModelLoad(ModelLoadBase): return self._installer @property - def logger(self) -> InvokeAILogger: + def logger(self) -> Logger: """Return the current logger.""" return self._logger @@ -205,7 +202,7 @@ class ModelLoad(ModelLoadBase): return self._app_config @property - def queue(self) -> str: + def queue(self) -> DownloadQueueBase: """Return the download queue object used by this object.""" return self._installer.queue @@ -267,6 +264,7 @@ class ModelLoad(ModelLoadBase): ) def collect_cache_stats(self, cache_stats: CacheStats): + """Save CacheStats object for stats collecting.""" self._cache.stats = cache_stats def resolve_model_path(self, path: Union[Path, str]) -> Path: @@ -283,12 +281,12 @@ class ModelLoad(ModelLoadBase): def _get_model_path( self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None - ) -> (Path, bool): + ) -> Tuple[Path, bool]: """Extract a model's filesystem path from its config. :return: The fully qualified Path of the module (or submodule). """ - model_path = model_config.path + model_path = Path(model_config.path) is_submodel_override = False # Does the config explicitly override the submodel? @@ -302,5 +300,6 @@ class ModelLoad(ModelLoadBase): return model_path, is_submodel_override def sync_to_config(self): + """Synchronize models on disk to those in memory.""" self._store = get_config_store(self._models_file) self.installer.scan_models_directory() diff --git a/invokeai/backend/model_manager/models/base.py b/invokeai/backend/model_manager/models/base.py index 674f025844..e9c2e54fae 100644 --- a/invokeai/backend/model_manager/models/base.py +++ b/invokeai/backend/model_manager/models/base.py @@ -177,6 +177,15 @@ class ModelBase(metaclass=ABCMeta): ) -> Any: raise NotImplementedError() + @classmethod + @abstractmethod + def convert_if_required( + cls, + model_config: ModelConfigBase, + output_path: str, + ) -> str: + raise NotImplementedError() + class DiffusersModel(ModelBase): # child_types: Dict[str, Type] diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index e801ab55ad..56d0d8f954 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -10,7 +10,7 @@ import json import re from abc import ABC, abstractmethod from pathlib import Path -from typing import Callable, Optional +from typing import Callable, Dict, Optional import safetensors.torch import torch @@ -33,8 +33,8 @@ class ModelProbeInfo(BaseModel): base_type: BaseModelType format: ModelFormat hash: str - variant_type: Optional[ModelVariantType] = "normal" - prediction_type: Optional[SchedulerPredictionType] = "v_prediction" + variant_type: Optional[ModelVariantType] = ModelVariantType("normal") + prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction") upcast_attention: Optional[bool] = False image_size: Optional[int] = None @@ -63,7 +63,7 @@ class ProbeBase(ABC): """Base model for probing checkpoint and diffusers-style models.""" @abstractmethod - def get_base_type(self) -> BaseModelType: + def get_base_type(self) -> Optional[BaseModelType]: """Return the BaseModelType for the model.""" pass @@ -71,7 +71,7 @@ class ProbeBase(ABC): """Return the ModelVariantType for the model.""" pass - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: + def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: """Return the SchedulerPredictionType for the model.""" pass @@ -83,7 +83,7 @@ class ProbeBase(ABC): class ModelProbe(ModelProbeBase): """Class to probe a checkpoint, safetensors or diffusers folder.""" - PROBES = { + PROBES: Dict[str, dict] = { "diffusers": {}, "checkpoint": {}, "onnx": {}, @@ -252,7 +252,7 @@ class ModelProbe(ModelProbeBase): # scan model scan_result = scan_file_path(model) if scan_result.infected_files != 0: - raise "The model {model_name} is potentially infected by malware. Aborting import." + raise InvalidModelException("The model {model_name} is potentially infected by malware. Aborting import.") # ##################################################3 @@ -263,15 +263,13 @@ class ModelProbe(ModelProbeBase): class CheckpointProbeBase(ProbeBase): """Base class for probing checkpoint-style models.""" - def __init__( - self, model: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None - ) -> BaseModelType: + def __init__(self, checkpoint_path: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None): """Initialize the CheckpointProbeBase object.""" - self.checkpoint = ModelProbe._scan_and_load_checkpoint(model) - self.model = model + self.checkpoint_path = checkpoint_path + self.checkpoint = ModelProbe._scan_and_load_checkpoint(checkpoint_path) self.helper = helper - def get_base_type(self) -> BaseModelType: + def get_base_type(self) -> Optional[BaseModelType]: """Return the BaseModelType of a checkpoint-style model.""" pass @@ -281,7 +279,7 @@ class CheckpointProbeBase(ProbeBase): def get_variant_type(self) -> ModelVariantType: """Return the ModelVariantType of a checkpoint-style model.""" - model_type = ModelProbe.get_model_type_from_checkpoint(self.model) + model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path) if model_type != ModelType.Main: return ModelVariantType.Normal state_dict = self.checkpoint.get("state_dict") or self.checkpoint @@ -378,13 +376,13 @@ class LoRACheckpointProbe(CheckpointProbeBase): elif token_vector_length == 2048: return BaseModelType.StableDiffusionXL else: - raise InvalidModelException(f"Unsupported LoRA type: {self.model}") + raise InvalidModelException(f"Unsupported LoRA type: {self.checkpoint_path}") class TextualInversionCheckpointProbe(CheckpointProbeBase): """TextualInversion checkpoint prober.""" - def get_format(self) -> Optional[str]: + def get_format(self) -> str: """Return the format of a TextualInversion emedding.""" return ModelFormat.EmbeddingFile @@ -401,8 +399,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusion1 elif token_dim == 1024: return BaseModelType.StableDiffusion2 - else: - return None + raise InvalidModelException("Unknown base model for {self.checkpoint_path}") class ControlNetCheckpointProbe(CheckpointProbeBase): @@ -421,18 +418,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusion1 elif checkpoint[key_name].shape[-1] == 1024: return BaseModelType.StableDiffusion2 - elif self.checkpoint_path and self.helper: - return self.helper(self.checkpoint_path) raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") class IPAdapterCheckpointProbe(CheckpointProbeBase): + """Probe IP adapter models.""" + def get_base_type(self) -> BaseModelType: + """Probe base type.""" raise NotImplementedError() class CLIPVisionCheckpointProbe(CheckpointProbeBase): + """Probe ClipVision adapter models.""" + def get_base_type(self) -> BaseModelType: + """Probe base type.""" raise NotImplementedError() @@ -511,6 +512,7 @@ class VaeFolderProbe(FolderProbeBase): """Class for probing folder-style models.""" def get_base_type(self) -> BaseModelType: + """Get base type of model.""" if self._config_looks_like_sdxl(): return BaseModelType.StableDiffusionXL elif self._name_looks_like_sdxl(): @@ -542,7 +544,7 @@ class VaeFolderProbe(FolderProbeBase): class TextualInversionFolderProbe(FolderProbeBase): """Probe a HuggingFace-style TextualInversion folder.""" - def get_format(self) -> Optional[str]: + def get_format(self) -> str: """Return the format of the TextualInversion.""" return ModelFormat.EmbeddingFolder @@ -616,9 +618,11 @@ class IPAdapterFolderProbe(FolderProbeBase): """Class for probing IP-Adapter models.""" def get_format(self) -> str: + """Get format of ip adapter.""" return ModelFormat.InvokeAI.value def get_base_type(self) -> BaseModelType: + """Get base type of ip adapter.""" model_file = self.folder_path / "ip_adapter.bin" if not model_file.exists(): raise InvalidModelException("Unknown IP-Adapter model format.") @@ -636,7 +640,10 @@ class IPAdapterFolderProbe(FolderProbeBase): class CLIPVisionFolderProbe(FolderProbeBase): + """Probe for folder-based CLIPVision models.""" + def get_base_type(self) -> BaseModelType: + """Get base type.""" return BaseModelType.Any diff --git a/invokeai/backend/model_manager/storage/base.py b/invokeai/backend/model_manager/storage/base.py index 9597b26862..5a85dc2530 100644 --- a/invokeai/backend/model_manager/storage/base.py +++ b/invokeai/backend/model_manager/storage/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import List, Optional, Set, Union -from ..config import BaseModelType, ModelConfigBase, ModelType +from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType # should match the InvokeAI version when this is first released. CONFIG_FILE_VERSION = "3.2" @@ -76,9 +76,9 @@ class ModelConfigStore(ABC): pass @abstractmethod - def get_model(self, key: str) -> ModelConfigBase: + def get_model(self, key: str) -> AnyModelConfig: """ - Retrieve the ModelConfigBase instance for the indicated model. + Retrieve the configuration for the indicated model. :param key: Key of model config to be fetched. diff --git a/invokeai/backend/model_manager/storage/migrate.py b/invokeai/backend/model_manager/storage/migrate.py index c62e91980d..8f9c6b2f1e 100644 --- a/invokeai/backend/model_manager/storage/migrate.py +++ b/invokeai/backend/model_manager/storage/migrate.py @@ -8,12 +8,14 @@ from omegaconf import OmegaConf from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger +from ..config import BaseModelType, MainCheckpointConfig, MainConfig, ModelType from .base import CONFIG_FILE_VERSION def migrate_models_store(config: InvokeAIAppConfig): + """Migrate models from v1 models.yaml to v3.2 models.yaml.""" # avoid circular import - from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall + from invokeai.backend.model_manager import DuplicateModelException, ModelInstall from invokeai.backend.model_manager.storage import get_config_store app_config = InvokeAIAppConfig.get_config() @@ -33,14 +35,17 @@ def migrate_models_store(config: InvokeAIAppConfig): ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version" continue - base_type, model_type, model_name = model_key.split("/") + base_type, model_type, model_name = str(model_key).split("/") + new_key = "" try: path = app_config.models_path / stanza["path"] new_key = installer.register_path(path) except DuplicateModelException: # if model already installed, then we just update its info - models = store.search_by_name(model_name=model_name, base_model=base_type, model_type=model_type) + models = store.search_by_name( + model_name=model_name, base_model=BaseModelType(base_type), model_type=ModelType(model_type) + ) if len(models) != 1: continue new_key = models[0].key @@ -48,9 +53,9 @@ def migrate_models_store(config: InvokeAIAppConfig): print(str(excp)) model_info = store.get_model(new_key) - if vae := stanza.get("vae"): + if vae := stanza.get("vae") and isinstance(model_info, MainConfig): model_info.vae = (app_config.models_path / vae).as_posix() - if model_config := stanza.get("config"): + if model_config := stanza.get("config") and isinstance(model_info, MainCheckpointConfig): model_info.config = (app_config.root_path / model_config).as_posix() model_info.description = stanza.get("description") store.update_model(new_key, model_info) diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py index 9487b755b5..f692c3214e 100644 --- a/invokeai/backend/model_manager/storage/sql.py +++ b/invokeai/backend/model_manager/storage/sql.py @@ -46,7 +46,7 @@ import threading from pathlib import Path from typing import List, Optional, Set, Union -from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType +from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException @@ -350,7 +350,7 @@ class ModelConfigStoreSQL(ModelConfigStore): self._lock.release() return self.get_model(key) - def get_model(self, key: str) -> ModelConfigBase: + def get_model(self, key: str) -> AnyModelConfig: """ Retrieve the ModelConfigBase instance for the indicated model. diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py index 68ce9f14bb..e9ae374df8 100644 --- a/invokeai/backend/model_manager/storage/yaml.py +++ b/invokeai/backend/model_manager/storage/yaml.py @@ -49,7 +49,7 @@ import yaml from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType +from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType from .base import ( CONFIG_FILE_VERSION, ConfigFileVersionMismatchException, @@ -170,7 +170,7 @@ class ModelConfigStoreYAML(ModelConfigStore): self._lock.release() return self.get_model(key) - def get_model(self, key: str) -> ModelConfigBase: + def get_model(self, key: str) -> AnyModelConfig: """ Retrieve the ModelConfigBase instance for the indicated model. diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 10d839fafd..186d842723 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -1,6 +1,8 @@ """ Initialization file for invokeai.backend.util """ +from logging import Logger # noqa: F401 + from .attention import auto_detect_slice_size # noqa: F401 from .devices import ( # noqa: F401 CPU_DEVICE, diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index d599c73afc..5bc0d5eb80 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -180,7 +180,7 @@ import socket import urllib.parse from abc import abstractmethod from pathlib import Path -from typing import Optional +from typing import Dict from invokeai.app.services.config import InvokeAIAppConfig @@ -345,7 +345,7 @@ LOG_FORMATTERS = { class InvokeAILogger(object): - loggers = dict() + loggers: Dict[str, logging.Logger] = dict() @classmethod def get_logger( diff --git a/tests/AA_nodes/test_session_queue.py b/tests/AA_nodes/test_session_queue.py index 01be0adb69..353615d7d3 100644 --- a/tests/AA_nodes/test_session_queue.py +++ b/tests/AA_nodes/test_session_queue.py @@ -12,6 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( populate_graph, prepare_values_to_insert, ) + from .test_nodes import PromptTestInvocation