fix migrate script and type mismatches in probe, config and loader
@@ -25,6 +25,7 @@ from pydantic import BaseSettings
 class PagingArgumentParser(argparse.ArgumentParser):
     """
     A custom ArgumentParser that uses pydoc to page its output.
+
     It also supports reading defaults from an init file.
     """

@@ -226,9 +227,7 @@ class InvokeAISettings(BaseSettings):


 def int_or_float_or_str(value: str) -> Union[int, float, str]:
-    """
-    Workaround for argparse type checking.
-    """
+    """Workaround for argparse type checking."""
     try:
         return int(value)
     except Exception as e:  # noqa F841
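The surviving one-line docstring documents a helper that lets argparse accept values which may be an int, a float, or a plain string (the attention_slice_size option below takes exactly such a mix). A minimal sketch of how the coercion plugs into argparse; the float-then-string fallback after the truncated except is an assumption, since the hunk cuts off:

    import argparse
    from typing import Union

    def int_or_float_or_str(value: str) -> Union[int, float, str]:
        """Workaround for argparse type checking."""
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return value  # fall back to the raw string, e.g. "auto"

    parser = argparse.ArgumentParser()
    # argparse applies the coercion function to the raw command-line string
    parser.add_argument("--attention_slice_size", type=int_or_float_or_str, default="auto")
    args = parser.parse_args(["--attention_slice_size", "4"])
    print(args.attention_slice_size)  # 4 (an int, not the string "4")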
@@ -257,7 +257,6 @@ class InvokeAIAppConfig(InvokeAISettings):
     attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", category="Generation", )
     attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
-    force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
+    force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)

     # QUEUE
     max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )
@@ -454,9 +453,7 @@ class InvokeAIAppConfig(InvokeAISettings):


 def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
-    """
-    Legacy function which returns InvokeAIAppConfig.get_config()
-    """
+    """Legacy function which returns InvokeAIAppConfig.get_config()."""
     return InvokeAIAppConfig.get_config(**kwargs)


@@ -55,20 +55,10 @@ class CacheStats(object):
     loaded_model_sizes: Dict[str, int] = field(default_factory=dict)


-class ModelLocker(object):
-    "Forward declaration"
-    pass
-
-
-class ModelCache(object):
-    "Forward declaration"
-    pass
-
-
 class _CacheRecord:
     size: int
     model: Any
-    cache: ModelCache
+    cache: "ModelCache"
     _locks: int

     def __init__(self, cache, model: Any, size: int):
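The deleted ModelLocker and ModelCache placeholders existed only so that _CacheRecord could name ModelCache before its definition. Quoting the annotation achieves the same thing without dummy classes, because string annotations are evaluated lazily. A standalone sketch of the pattern (stand-in classes, not the real InvokeAI ones):

    from typing import Any, List

    class _CacheRecord:
        size: int
        model: Any
        cache: "ModelCache"  # quoted: resolved later, no forward-declaration class needed
        _locks: int

    class ModelCache:
        def __init__(self) -> None:
            self._records: List[_CacheRecord] = []

Adding from __future__ import annotations at the top of a module makes every annotation lazy and lets the quotes be dropped.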
@@ -130,7 +120,7 @@ class ModelCache(object):
         self.logger = logger

         # used for stats collection
-        self.stats = None
+        self.stats: Optional[CacheStats] = None

         self._cached_models = dict()
         self._cache_stack = list()
@@ -160,7 +150,7 @@ class ModelCache(object):

         if model_info_key not in self.model_infos:
             self.model_infos[model_info_key] = model_class(
-                model_path,
+                model_path.as_posix(),
                 base_model,
                 model_type,
             )
@@ -118,7 +118,7 @@ class ModelConfigBase(BaseModel):
     base_model: BaseModelType
     model_type: ModelType
     model_format: ModelFormat
-    key: Optional[str] = Field(None)  # this will get added by the store
+    key: str = Field(description="hash key for model", default="<NOKEY>")  # this will get added by the store
     description: Optional[str] = Field(None)
     author: Optional[str] = Field(description="Model author")
     license: Optional[str] = Field(description="License string")
@@ -233,6 +233,16 @@ class IPAdapterConfig(ModelConfigBase):
     model_format: Literal[ModelFormat.InvokeAI]


+AnyModelConfig = Union[
+    MainCheckpointConfig,
+    MainDiffusersConfig,
+    LoRAConfig,
+    TextualInversionConfig,
+    ONNXSD1Config,
+    ONNXSD2Config,
+]
+
+
 class ModelConfigFactory(object):
     """Class for parsing config dicts into StableDiffusion Config objects."""

@@ -279,14 +289,7 @@ class ModelConfigFactory(object):
         model_data: Union[dict, ModelConfigBase],
         key: Optional[str] = None,
         dest_class: Optional[Type] = None,
-    ) -> Union[
-        MainCheckpointConfig,
-        MainDiffusersConfig,
-        LoRAConfig,
-        TextualInversionConfig,
-        ONNXSD1Config,
-        ONNXSD2Config,
-    ]:
+    ) -> AnyModelConfig:
         """
         Return the appropriate config object from raw dict values.

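Naming the union once as AnyModelConfig keeps the factory's return annotation in sync with the set of config classes. The idea, sketched with hypothetical stand-in configs (not the real ones):

    from typing import Union
    from pydantic import BaseModel

    class CheckpointConfig(BaseModel):
        path: str

    class DiffusersConfig(BaseModel):
        repo_id: str

    AnyConfig = Union[CheckpointConfig, DiffusersConfig]

    def make_config(data: dict) -> AnyConfig:
        # dispatch on a discriminating field instead of repeating the union inline
        return CheckpointConfig(**data) if "path" in data else DiffusersConfig(**data)

    print(make_config({"path": "sd-v1-5.ckpt"}))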
@@ -60,7 +60,7 @@ from pydantic import Field
 from pydantic.networks import AnyHttpUrl

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.util import Chdir, InvokeAILogger
+from invokeai.backend.util import Chdir, InvokeAILogger, Logger

 from .config import (
     BaseModelType,
@@ -321,7 +321,7 @@ class ModelInstall(ModelInstallBase):
     """Model installer class handles installation from a local path."""

     _app_config: InvokeAIAppConfig
-    _logger: InvokeAILogger
+    _logger: Logger
     _store: ModelConfigStore
     _download_queue: DownloadQueueBase
     _async_installs: Dict[str, str]
@@ -355,7 +355,7 @@ class ModelInstall(ModelInstallBase):
         self,
         store: Optional[ModelConfigStore] = None,
         config: Optional[InvokeAIAppConfig] = None,
-        logger: Optional[InvokeAILogger] = None,
+        logger: Optional[Logger] = None,
         download: Optional[DownloadQueueBase] = None,
         event_handlers: Optional[List[DownloadEventHandler]] = None,
     ):  # noqa D107 - use base class docstrings
@@ -5,16 +5,16 @@ import hashlib
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import torch

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.util import InvokeAILogger, choose_precision, choose_torch_device
+from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device

-from .cache import CacheStats, ModelCache, ModelLocker
+from .cache import CacheStats, ModelCache
 from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType
-from .download import DownloadEventHandler
+from .download import DownloadEventHandler, DownloadQueueBase
 from .install import ModelInstall, ModelInstallBase
 from .models import MODEL_CLASSES, InvalidModelException, ModelBase
 from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store
@@ -24,10 +24,10 @@ from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_c
 class ModelInfo:
     """This is a context manager object that is used to intermediate access to a model."""

-    context: ModelLocker
+    context: ModelCache.ModelLocker
     name: str
     base_model: BaseModelType
-    type: ModelType
+    type: Union[ModelType, SubModelType]
     key: str
     location: Union[Path, str]
     precision: torch.dtype
@@ -73,7 +73,7 @@ class ModelLoadBase(ABC):

     @property
     @abstractmethod
-    def logger(self) -> InvokeAILogger:
+    def logger(self) -> Logger:
         """Return the current logger."""
         pass
@@ -84,7 +84,7 @@ class ModelLoadBase(ABC):

     @property
     @abstractmethod
-    def queue(self) -> str:
+    def queue(self) -> DownloadQueueBase:
         """Return the download queue object used by this object."""
         pass
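The two abstract properties now advertise their real types (Logger, DownloadQueueBase) instead of str placeholders. The decorator stacking shown, @property above @abstractmethod, is the order Python requires. A self-contained sketch with a stand-in queue type (hypothetical names, not the InvokeAI classes):

    from abc import ABC, abstractmethod

    class QueueBase:  # stand-in for DownloadQueueBase
        pass

    class LoaderBase(ABC):
        @property
        @abstractmethod
        def queue(self) -> QueueBase:
            """Return the download queue object used by this object."""

    class Loader(LoaderBase):
        def __init__(self, queue: QueueBase) -> None:
            self._queue = queue

        @property
        def queue(self) -> QueueBase:
            return self._queue

    print(isinstance(Loader(QueueBase()).queue, QueueBase))  # True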
@@ -106,10 +106,7 @@ class ModelLoadBase(ABC):

     @abstractmethod
     def sync_to_config(self):
-        """
-        Reinitialize the store to sync in-memory and in-disk
-        versions.
-        """
+        """Reinitialize the store to sync in-memory and in-disk versions."""
         pass

@@ -120,7 +117,7 @@ class ModelLoad(ModelLoadBase):
     _store: ModelConfigStore
     _installer: ModelInstallBase
     _cache: ModelCache
-    _logger: InvokeAILogger
+    _logger: Logger
     _cache_keys: dict
     _models_file: Path

@@ -195,7 +192,7 @@ class ModelLoad(ModelLoadBase):
         return self._installer

     @property
-    def logger(self) -> InvokeAILogger:
+    def logger(self) -> Logger:
         """Return the current logger."""
         return self._logger

@@ -205,7 +202,7 @@ class ModelLoad(ModelLoadBase):
         return self._app_config

     @property
-    def queue(self) -> str:
+    def queue(self) -> DownloadQueueBase:
         """Return the download queue object used by this object."""
         return self._installer.queue

@@ -267,6 +264,7 @@ class ModelLoad(ModelLoadBase):
         )

+
     def collect_cache_stats(self, cache_stats: CacheStats):
         """Save CacheStats object for stats collecting."""
         self._cache.stats = cache_stats

     def resolve_model_path(self, path: Union[Path, str]) -> Path:
@@ -283,12 +281,12 @@ class ModelLoad(ModelLoadBase):

     def _get_model_path(
         self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
-    ) -> (Path, bool):
+    ) -> Tuple[Path, bool]:
         """Extract a model's filesystem path from its config.

         :return: The fully qualified Path of the module (or submodule).
         """
-        model_path = model_config.path
+        model_path = Path(model_config.path)
         is_submodel_override = False

         # Does the config explicitly override the submodel?
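The old annotation -> (Path, bool) is not a type at all: it evaluates to a tuple of two classes, which type checkers reject, hence the switch to Tuple[Path, bool]. For instance:

    from pathlib import Path
    from typing import Tuple

    def get_model_path(base: Path, override: bool) -> Tuple[Path, bool]:
        # a two-element return needs Tuple[Path, bool] (or tuple[Path, bool] on 3.9+)
        return base / "model.safetensors", override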
@@ -302,5 +300,6 @@ class ModelLoad(ModelLoadBase):
         return model_path, is_submodel_override

     def sync_to_config(self):
         """Synchronize models on disk to those in memory."""
+        self._store = get_config_store(self._models_file)
         self.installer.scan_models_directory()
@@ -177,6 +177,15 @@ class ModelBase(metaclass=ABCMeta):
     ) -> Any:
         raise NotImplementedError()

+    @classmethod
+    @abstractmethod
+    def convert_if_required(
+        cls,
+        model_config: ModelConfigBase,
+        output_path: str,
+    ) -> str:
+        raise NotImplementedError()
+
+
 class DiffusersModel(ModelBase):
     # child_types: Dict[str, Type]
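The new convert_if_required hook stacks @classmethod over @abstractmethod, so concrete model classes must supply it before they can be instantiated. A minimal sketch, using a plain dict where the real hook takes a ModelConfigBase:

    from abc import ABCMeta, abstractmethod

    class ModelBase(metaclass=ABCMeta):
        @classmethod
        @abstractmethod
        def convert_if_required(cls, model_config: dict, output_path: str) -> str:
            raise NotImplementedError()

    class DiffusersModel(ModelBase):
        @classmethod
        def convert_if_required(cls, model_config: dict, output_path: str) -> str:
            return output_path  # diffusers folders need no conversion in this sketch

    print(DiffusersModel.convert_if_required({}, "/tmp/out"))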
@@ -10,7 +10,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional

 import safetensors.torch
 import torch
@@ -33,8 +33,8 @@ class ModelProbeInfo(BaseModel):
     base_type: BaseModelType
     format: ModelFormat
     hash: str
-    variant_type: Optional[ModelVariantType] = "normal"
-    prediction_type: Optional[SchedulerPredictionType] = "v_prediction"
+    variant_type: Optional[ModelVariantType] = ModelVariantType("normal")
+    prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction")
     upcast_attention: Optional[bool] = False
     image_size: Optional[int] = None

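Wrapping the bare strings in their enum constructors makes the defaults real enum members at class-definition time, and an invalid value would fail immediately instead of lingering as a stray string. The effect, reproduced with a stand-in enum:

    from enum import Enum
    from typing import Optional
    from pydantic import BaseModel

    class ModelVariantType(str, Enum):
        Normal = "normal"
        Inpaint = "inpaint"

    class ProbeInfo(BaseModel):
        # ModelVariantType("normal") resolves to ModelVariantType.Normal immediately
        variant_type: Optional[ModelVariantType] = ModelVariantType("normal")

    print(ProbeInfo().variant_type is ModelVariantType.Normal)  # True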
@@ -63,7 +63,7 @@ class ProbeBase(ABC):
     """Base model for probing checkpoint and diffusers-style models."""

     @abstractmethod
-    def get_base_type(self) -> BaseModelType:
+    def get_base_type(self) -> Optional[BaseModelType]:
         """Return the BaseModelType for the model."""
         pass
@@ -71,7 +71,7 @@ class ProbeBase(ABC):
         """Return the ModelVariantType for the model."""
         pass

-    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
+    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
         """Return the SchedulerPredictionType for the model."""
         pass
@@ -83,7 +83,7 @@ class ProbeBase(ABC):
 class ModelProbe(ModelProbeBase):
     """Class to probe a checkpoint, safetensors or diffusers folder."""

-    PROBES = {
+    PROBES: Dict[str, dict] = {
         "diffusers": {},
         "checkpoint": {},
         "onnx": {},
@@ -252,7 +252,7 @@ class ModelProbe(ModelProbeBase):
         # scan model
         scan_result = scan_file_path(model)
         if scan_result.infected_files != 0:
-            raise "The model {model_name} is potentially infected by malware. Aborting import."
+            raise InvalidModelException("The model {model_name} is potentially infected by malware. Aborting import.")


 # ##################################################3
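raise with a string has been a runtime TypeError since Python 3 (only BaseException subclasses can be raised), so the old scan branch could never actually report the infection; wrapping the message in InvalidModelException fixes that. Note the message still lacks an f-prefix, so {model_name} is not interpolated. To see the failure mode:

    class InvalidModelException(Exception):
        """Raised when a model file fails validation."""

    try:
        raise "The model is potentially infected by malware."
    except TypeError as err:
        print(err)  # exceptions must derive from BaseException

    model_name = "suspect.ckpt"
    # the corrected form; an f-string is needed for the name to interpolate
    print(InvalidModelException(f"The model {model_name} is potentially infected by malware."))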
@@ -263,15 +263,13 @@ class ModelProbe(ModelProbeBase):
 class CheckpointProbeBase(ProbeBase):
     """Base class for probing checkpoint-style models."""

-    def __init__(
-        self, model: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None
-    ) -> BaseModelType:
+    def __init__(self, checkpoint_path: Path, helper: Optional[Callable[[Path], SchedulerPredictionType]] = None):
         """Initialize the CheckpointProbeBase object."""
-        self.checkpoint = ModelProbe._scan_and_load_checkpoint(model)
-        self.model = model
+        self.checkpoint_path = checkpoint_path
+        self.checkpoint = ModelProbe._scan_and_load_checkpoint(checkpoint_path)
         self.helper = helper

-    def get_base_type(self) -> BaseModelType:
+    def get_base_type(self) -> Optional[BaseModelType]:
         """Return the BaseModelType of a checkpoint-style model."""
         pass
@@ -281,7 +279,7 @@ class CheckpointProbeBase(ProbeBase):

     def get_variant_type(self) -> ModelVariantType:
         """Return the ModelVariantType of a checkpoint-style model."""
-        model_type = ModelProbe.get_model_type_from_checkpoint(self.model)
+        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path)
         if model_type != ModelType.Main:
             return ModelVariantType.Normal
         state_dict = self.checkpoint.get("state_dict") or self.checkpoint
@@ -378,13 +376,13 @@ class LoRACheckpointProbe(CheckpointProbeBase):
         elif token_vector_length == 2048:
             return BaseModelType.StableDiffusionXL
         else:
-            raise InvalidModelException(f"Unsupported LoRA type: {self.model}")
+            raise InvalidModelException(f"Unsupported LoRA type: {self.checkpoint_path}")


 class TextualInversionCheckpointProbe(CheckpointProbeBase):
     """TextualInversion checkpoint prober."""

-    def get_format(self) -> Optional[str]:
+    def get_format(self) -> str:
         """Return the format of a TextualInversion embedding."""
         return ModelFormat.EmbeddingFile
@@ -401,8 +399,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
             return BaseModelType.StableDiffusion1
         elif token_dim == 1024:
             return BaseModelType.StableDiffusion2
-        else:
-            return None
+        raise InvalidModelException("Unknown base model for {self.checkpoint_path}")


 class ControlNetCheckpointProbe(CheckpointProbeBase):
@@ -421,18 +418,22 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
             return BaseModelType.StableDiffusion1
         elif checkpoint[key_name].shape[-1] == 1024:
             return BaseModelType.StableDiffusion2
+        elif self.checkpoint_path and self.helper:
+            return self.helper(self.checkpoint_path)
         raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")


 class IPAdapterCheckpointProbe(CheckpointProbeBase):
     """Probe IP adapter models."""

     def get_base_type(self) -> BaseModelType:
+        """Probe base type."""
         raise NotImplementedError()


 class CLIPVisionCheckpointProbe(CheckpointProbeBase):
     """Probe ClipVision adapter models."""

     def get_base_type(self) -> BaseModelType:
+        """Probe base type."""
         raise NotImplementedError()
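When shape inspection is inconclusive, the ControlNet probe now defers to the caller-supplied helper callback before giving up. Roughly the shape of that fallback, with a made-up detection table (illustrative names only):

    from pathlib import Path
    from typing import Callable, Dict, Optional

    KNOWN: Dict[str, str] = {"control_v11p_sd15": "sd-1"}  # made-up detection table

    def probe_base_type(path: Path, helper: Optional[Callable[[Path], str]] = None) -> str:
        if path.stem in KNOWN:
            return KNOWN[path.stem]
        if helper is not None:
            return helper(path)  # e.g. prompt the user or read a sidecar config
        raise ValueError(f"Unable to determine base type for {path}")

    print(probe_base_type(Path("mystery.safetensors"), helper=lambda p: "sd-2"))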
@@ -511,6 +512,7 @@ class VaeFolderProbe(FolderProbeBase):
     """Class for probing folder-style models."""

     def get_base_type(self) -> BaseModelType:
+        """Get base type of model."""
         if self._config_looks_like_sdxl():
             return BaseModelType.StableDiffusionXL
         elif self._name_looks_like_sdxl():
@@ -542,7 +544,7 @@ class VaeFolderProbe(FolderProbeBase):
 class TextualInversionFolderProbe(FolderProbeBase):
     """Probe a HuggingFace-style TextualInversion folder."""

-    def get_format(self) -> Optional[str]:
+    def get_format(self) -> str:
         """Return the format of the TextualInversion."""
         return ModelFormat.EmbeddingFolder
@@ -616,9 +618,11 @@ class IPAdapterFolderProbe(FolderProbeBase):
     """Class for probing IP-Adapter models."""

     def get_format(self) -> str:
+        """Get format of ip adapter."""
         return ModelFormat.InvokeAI.value

     def get_base_type(self) -> BaseModelType:
+        """Get base type of ip adapter."""
         model_file = self.folder_path / "ip_adapter.bin"
         if not model_file.exists():
             raise InvalidModelException("Unknown IP-Adapter model format.")
@@ -636,7 +640,10 @@ class IPAdapterFolderProbe(FolderProbeBase):


 class CLIPVisionFolderProbe(FolderProbeBase):
+    """Probe for folder-based CLIPVision models."""
+
     def get_base_type(self) -> BaseModelType:
+        """Get base type."""
         return BaseModelType.Any

@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import List, Optional, Set, Union

-from ..config import BaseModelType, ModelConfigBase, ModelType
+from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType

 # should match the InvokeAI version when this is first released.
 CONFIG_FILE_VERSION = "3.2"
@@ -76,9 +76,9 @@ class ModelConfigStore(ABC):
         pass

     @abstractmethod
-    def get_model(self, key: str) -> ModelConfigBase:
+    def get_model(self, key: str) -> AnyModelConfig:
         """
-        Retrieve the ModelConfigBase instance for the indicated model.
+        Retrieve the configuration for the indicated model.

         :param key: Key of model config to be fetched.
@@ -8,12 +8,14 @@ from omegaconf import OmegaConf
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger

+from ..config import BaseModelType, MainCheckpointConfig, MainConfig, ModelType
 from .base import CONFIG_FILE_VERSION


 def migrate_models_store(config: InvokeAIAppConfig):
     """Migrate models from v1 models.yaml to v3.2 models.yaml."""
     # avoid circular import
-    from invokeai.backend.model_manager import DuplicateModelException, InvalidModelException, ModelInstall
+    from invokeai.backend.model_manager import DuplicateModelException, ModelInstall
     from invokeai.backend.model_manager.storage import get_config_store

     app_config = InvokeAIAppConfig.get_config()
@@ -33,14 +35,17 @@ def migrate_models_store(config: InvokeAIAppConfig):
             ), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
             continue

-        base_type, model_type, model_name = model_key.split("/")
+        base_type, model_type, model_name = str(model_key).split("/")
         new_key = "<NOKEY>"

         try:
             path = app_config.models_path / stanza["path"]
             new_key = installer.register_path(path)
         except DuplicateModelException:
             # if model already installed, then we just update its info
-            models = store.search_by_name(model_name=model_name, base_model=base_type, model_type=model_type)
+            models = store.search_by_name(
+                model_name=model_name, base_model=BaseModelType(base_type), model_type=ModelType(model_type)
+            )
             if len(models) != 1:
                 continue
             new_key = models[0].key
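Two coercions carry the migrate fix here: str(model_key) normalizes the OmegaConf key object before splitting, and the enum constructors turn the path segments into real BaseModelType/ModelType members, raising at once on an unknown value. Roughly (stand-in enums, not the full InvokeAI definitions):

    from enum import Enum

    class BaseModelType(str, Enum):
        StableDiffusion1 = "sd-1"
        StableDiffusion2 = "sd-2"

    class ModelType(str, Enum):
        Main = "main"
        Vae = "vae"

    model_key = "sd-1/main/stable-diffusion-v1-5"  # stanza key from a v1 models.yaml
    base_type, model_type, model_name = str(model_key).split("/")
    print(BaseModelType(base_type), ModelType(model_type), model_name)
    # an unknown segment such as BaseModelType("sdxl") here would raise ValueError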
@@ -48,9 +53,9 @@ def migrate_models_store(config: InvokeAIAppConfig):
             print(str(excp))

         model_info = store.get_model(new_key)
-        if vae := stanza.get("vae"):
+        if vae := stanza.get("vae") and isinstance(model_info, MainConfig):
             model_info.vae = (app_config.models_path / vae).as_posix()
-        if model_config := stanza.get("config"):
+        if model_config := stanza.get("config") and isinstance(model_info, MainCheckpointConfig):
             model_info.config = (app_config.root_path / model_config).as_posix()
         model_info.description = stanza.get("description")
         store.update_model(new_key, model_info)
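One caveat about the new guards: := binds more loosely than and, so vae here captures the value of the whole stanza.get("vae") and isinstance(...) expression, which is a boolean whenever the lookup succeeds, not the path string. A parenthesized walrus preserves the fetched value:

    stanza = {"vae": "sdxl_vae.safetensors"}
    is_main = True

    # as written in the hunk above: vae captures the whole boolean expression
    if vae := stanza.get("vae") and is_main:
        print(vae)  # True, not the path

    # parenthesizing keeps the looked-up value
    if (vae := stanza.get("vae")) and is_main:
        print(vae)  # sdxl_vae.safetensors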
@@ -46,7 +46,7 @@ import threading
 from pathlib import Path
 from typing import List, Optional, Set, Union

-from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
+from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
 from .base import CONFIG_FILE_VERSION, DuplicateModelException, ModelConfigStore, UnknownModelException

@@ -350,7 +350,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
             self._lock.release()
         return self.get_model(key)

-    def get_model(self, key: str) -> ModelConfigBase:
+    def get_model(self, key: str) -> AnyModelConfig:
         """
         Retrieve the ModelConfigBase instance for the indicated model.

@@ -49,7 +49,7 @@ import yaml
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig

-from ..config import BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
+from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
 from .base import (
     CONFIG_FILE_VERSION,
     ConfigFileVersionMismatchException,
@@ -170,7 +170,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
         self._lock.release()
         return self.get_model(key)

-    def get_model(self, key: str) -> ModelConfigBase:
+    def get_model(self, key: str) -> AnyModelConfig:
         """
         Retrieve the ModelConfigBase instance for the indicated model.

@@ -1,6 +1,8 @@
 """
 Initialization file for invokeai.backend.util
 """
+from logging import Logger  # noqa: F401
+
 from .attention import auto_detect_slice_size  # noqa: F401
 from .devices import (  # noqa: F401
     CPU_DEVICE,
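Re-exporting logging.Logger from the util package (the noqa: F401 silences the unused-import warning) is what lets the annotations earlier in this commit say _logger: Logger. InvokeAILogger is a factory that hands out stdlib loggers rather than a Logger subclass, so it was the wrong type to annotate with; the loggers cache below gets the matching Dict[str, logging.Logger] annotation. The distinction, sketched with a stand-in factory:

    import logging
    from logging import Logger
    from typing import Dict

    class LoggerFactory:
        """Stand-in for InvokeAILogger: hands out stdlib Logger objects."""

        loggers: Dict[str, logging.Logger] = {}

        @classmethod
        def get_logger(cls, name: str) -> Logger:
            return cls.loggers.setdefault(name, logging.getLogger(name))

    log: Logger = LoggerFactory.get_logger("invokeai")  # annotate with the real type
    print(isinstance(log, Logger), isinstance(log, LoggerFactory))  # True False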
@@ -180,7 +180,7 @@ import socket
 import urllib.parse
 from abc import abstractmethod
 from pathlib import Path
-from typing import Optional
+from typing import Dict

 from invokeai.app.services.config import InvokeAIAppConfig

@@ -345,7 +345,7 @@ LOG_FORMATTERS = {


 class InvokeAILogger(object):
-    loggers = dict()
+    loggers: Dict[str, logging.Logger] = dict()

     @classmethod
     def get_logger(
@@ -12,6 +12,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     populate_graph,
     prepare_values_to_insert,
 )

+from .test_nodes import PromptTestInvocation
